1mod connection;
2pub use connection::*;
3
4pub use acp_old::ToolCallId;
5use agent_client_protocol::{self as acp};
6use agentic_coding_protocol as acp_old;
7use anyhow::{Context as _, Result};
8use assistant_tool::ActionLog;
9use buffer_diff::BufferDiff;
10use editor::{Bias, MultiBuffer, PathKey};
11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
12use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
13use itertools::{Itertools, PeekingNext};
14use language::{
15 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
16 text_diff,
17};
18use markdown::Markdown;
19use project::{AgentLocation, Project};
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt::Formatter;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::{App, IconName};
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n{}\n", self.content.to_markdown(cx))
52 }
53}
54
/// A file mention, rendered as a markdown link whose target is `@file:<path>`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Link-target prefix that marks a URL as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        let Self(path) = self;
        path
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<full path>)`; the label is empty when the path
    /// has no final component.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(
114 acp::ContentBlock::Text(acp::TextContent {
115 text: chunk.to_owned().into(),
116 annotations: None,
117 }),
118 language_registry,
119 cx,
120 ),
121 }
122 }
123
124 fn to_markdown(&self, cx: &App) -> String {
125 match self {
126 Self::Message { block } => block.to_markdown(cx).to_string(),
127 Self::Thought { block } => {
128 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
129 }
130 }
131 }
132}
133
134#[derive(Debug)]
135pub enum AgentThreadEntry {
136 UserMessage(UserMessage),
137 AssistantMessage(AssistantMessage),
138 ToolCall(ToolCall),
139}
140
141impl AgentThreadEntry {
142 fn to_markdown(&self, cx: &App) -> String {
143 match self {
144 Self::UserMessage(message) => message.to_markdown(cx),
145 Self::AssistantMessage(message) => message.to_markdown(cx),
146 Self::ToolCall(too_call) => too_call.to_markdown(cx),
147 }
148 }
149
150 pub fn diff(&self) -> Option<&Diff> {
151 if let AgentThreadEntry::ToolCall(ToolCall {
152 content: Some(ToolCallContent::Diff { diff }),
153 ..
154 }) = self
155 {
156 Some(&diff)
157 } else {
158 None
159 }
160 }
161
162 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
163 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
164 Some(locations)
165 } else {
166 None
167 }
168 }
169}
170
171#[derive(Debug)]
172pub struct ToolCall {
173 pub id: acp::ToolCallId,
174 pub label: Entity<Markdown>,
175 pub kind: acp::ToolKind,
176 // todo! Should this be a vec?
177 pub content: Option<ToolCallContent>,
178 pub status: ToolCallStatus,
179 pub locations: Vec<acp::ToolCallLocation>,
180}
181
182impl ToolCall {
183 fn from_acp(
184 tool_call: acp::ToolCall,
185 status: ToolCallStatus,
186 language_registry: Arc<LanguageRegistry>,
187 cx: &mut App,
188 ) -> Self {
189 Self {
190 id: tool_call.id,
191 label: cx.new(|cx| {
192 Markdown::new(
193 tool_call.label.into(),
194 Some(language_registry.clone()),
195 None,
196 cx,
197 )
198 }),
199 kind: tool_call.kind,
200 // todo! Do we assume there is either a coalesced content OR diff?
201 content: ToolCallContent::from_acp_contents(tool_call.content, language_registry, cx)
202 .into_iter()
203 .next(),
204 locations: tool_call.locations,
205 status,
206 }
207 }
208 fn to_markdown(&self, cx: &App) -> String {
209 let mut markdown = format!(
210 "**Tool Call: {}**\nStatus: {}\n\n",
211 self.label.read(cx).source(),
212 self.status
213 );
214 if let Some(content) = &self.content {
215 markdown.push_str(content.to_markdown(cx).as_str());
216 markdown.push_str("\n\n");
217 }
218 markdown
219 }
220}
221
222#[derive(Debug)]
223pub enum ToolCallStatus {
224 WaitingForConfirmation {
225 possible_grants: Vec<acp::Grant>,
226 respond_tx: oneshot::Sender<acp::GrantId>,
227 },
228 Allowed {
229 status: acp::ToolCallStatus,
230 },
231 Rejected,
232 Canceled,
233}
234
235impl Display for ToolCallStatus {
236 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237 write!(
238 f,
239 "{}",
240 match self {
241 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
242 ToolCallStatus::Allowed { status } => match status {
243 acp::ToolCallStatus::InProgress => "In Progress",
244 acp::ToolCallStatus::Completed => "Completed",
245 acp::ToolCallStatus::Failed => "Failed",
246 },
247 ToolCallStatus::Rejected => "Rejected",
248 ToolCallStatus::Canceled => "Canceled",
249 }
250 )
251 }
252}
253
254#[derive(Debug)]
255pub enum ToolCallConfirmation {
256 Edit {
257 description: Option<Entity<Markdown>>,
258 },
259 Execute {
260 command: String,
261 root_command: String,
262 description: Option<Entity<Markdown>>,
263 },
264 Mcp {
265 server_name: String,
266 tool_name: String,
267 tool_display_name: String,
268 description: Option<Entity<Markdown>>,
269 },
270 Fetch {
271 urls: Vec<SharedString>,
272 description: Option<Entity<Markdown>>,
273 },
274 Other {
275 description: Entity<Markdown>,
276 },
277}
278
279impl ToolCallConfirmation {
280 pub fn from_acp(
281 confirmation: acp_old::ToolCallConfirmation,
282 language_registry: Arc<LanguageRegistry>,
283 cx: &mut App,
284 ) -> Self {
285 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
286 cx.new(|cx| {
287 Markdown::new(
288 description.into(),
289 Some(language_registry.clone()),
290 None,
291 cx,
292 )
293 })
294 };
295
296 match confirmation {
297 acp_old::ToolCallConfirmation::Edit { description } => Self::Edit {
298 description: description.map(|description| to_md(description, cx)),
299 },
300 acp_old::ToolCallConfirmation::Execute {
301 command,
302 root_command,
303 description,
304 } => Self::Execute {
305 command,
306 root_command,
307 description: description.map(|description| to_md(description, cx)),
308 },
309 acp_old::ToolCallConfirmation::Mcp {
310 server_name,
311 tool_name,
312 tool_display_name,
313 description,
314 } => Self::Mcp {
315 server_name,
316 tool_name,
317 tool_display_name,
318 description: description.map(|description| to_md(description, cx)),
319 },
320 acp_old::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
321 urls: urls.iter().map(|url| url.into()).collect(),
322 description: description.map(|description| to_md(description, cx)),
323 },
324 acp_old::ToolCallConfirmation::Other { description } => Self::Other {
325 description: to_md(description, cx),
326 },
327 }
328 }
329}
330
331#[derive(Debug, PartialEq, Clone)]
332enum ContentBlock {
333 Empty,
334 Markdown { markdown: Entity<Markdown> },
335}
336
337impl ContentBlock {
338 pub fn new(
339 block: acp::ContentBlock,
340 language_registry: &Arc<LanguageRegistry>,
341 cx: &mut App,
342 ) -> Self {
343 let mut this = Self::Empty;
344 this.append(block, language_registry, cx);
345 this
346 }
347
348 pub fn new_combined(
349 blocks: impl IntoIterator<Item = acp::ContentBlock>,
350 language_registry: &Arc<LanguageRegistry>,
351 cx: &mut App,
352 ) -> Self {
353 let mut this = Self::Empty;
354 for block in blocks {
355 this.append(block, language_registry, cx);
356 }
357 this
358 }
359
360 pub fn append(
361 &mut self,
362 block: acp::ContentBlock,
363 language_registry: &Arc<LanguageRegistry>,
364 cx: &mut App,
365 ) {
366 let new_content = match block {
367 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
368 acp::ContentBlock::ResourceLink(resource_link) => {
369 if let Some(path) = resource_link.uri.strip_prefix("file://") {
370 format!("{}", MentionPath(path.as_ref()))
371 } else {
372 resource_link.uri.clone()
373 }
374 }
375 acp::ContentBlock::Image(_)
376 | acp::ContentBlock::Audio(_)
377 | acp::ContentBlock::Resource(_) => String::new(),
378 };
379
380 match self {
381 ContentBlock::Empty => {
382 *self = ContentBlock::Markdown {
383 markdown: cx.new(|cx| {
384 Markdown::new(
385 new_content.into(),
386 Some(language_registry.clone()),
387 None,
388 cx,
389 )
390 }),
391 };
392 }
393 ContentBlock::Markdown { markdown } => {
394 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
395 }
396 }
397 }
398
399 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
400 match self {
401 ContentBlock::Empty => "",
402 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
403 }
404 }
405}
406
407#[derive(Debug)]
408pub enum ToolCallContent {
409 ContentBlock { content: ContentBlock },
410 Diff { diff: Diff },
411}
412
413impl ToolCallContent {
414 pub fn from_acp(
415 content: acp::ToolCallContent,
416 language_registry: Arc<LanguageRegistry>,
417 cx: &mut App,
418 ) -> Self {
419 match content {
420 acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock {
421 content: ContentBlock::new(content, &language_registry, cx),
422 },
423 acp::ToolCallContent::Diff { diff } => Self::Diff {
424 diff: Diff::from_acp(diff, language_registry, cx),
425 },
426 }
427 }
428
429 pub fn from_acp_contents(
430 content: Vec<acp::ToolCallContent>,
431 language_registry: Arc<LanguageRegistry>,
432 cx: &mut App,
433 ) -> Vec<Self> {
434 content
435 .into_iter()
436 .peekable()
437 .batching(|it| match it.next()? {
438 acp::ToolCallContent::ContentBlock { content } => {
439 let mut block = ContentBlock::new(content, &language_registry, cx);
440 while let Some(acp::ToolCallContent::ContentBlock { content }) =
441 it.peeking_next(|c| matches!(c, acp::ToolCallContent::ContentBlock { .. }))
442 {
443 block.append(content, &language_registry, cx);
444 }
445 Some(ToolCallContent::ContentBlock { content: block })
446 }
447 content @ acp::ToolCallContent::Diff { .. } => Some(ToolCallContent::from_acp(
448 content,
449 language_registry.clone(),
450 cx,
451 )),
452 })
453 .collect()
454 }
455
456 fn to_markdown(&self, cx: &App) -> String {
457 match self {
458 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
459 Self::Diff { diff } => diff.to_markdown(cx),
460 }
461 }
462}
463
/// A read-only diff view of a file change reported in a tool call.
///
/// - `new_buffer` holds the new text.
/// - `old_buffer` holds the previous text, or is empty when none was provided.
/// - `multibuffer` shows excerpts of `new_buffer` around the changed hunks. These are
///   filled in by a spawned task once the diff has been computed.
#[derive(Debug)]
pub struct Diff {
    pub multibuffer: Entity<MultiBuffer>,
    pub path: PathBuf,
    pub new_buffer: Entity<Buffer>,
    pub old_buffer: Entity<Buffer>,
    // Owns the setup task spawned in `from_acp`.
    // NOTE(review): presumably dropping it cancels that setup (gpui `Task` semantics) — confirm.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds the diff view for an ACP diff.
    ///
    /// The buffers are created synchronously. The diff computation, excerpt layout and
    /// language detection (based on `path`) then run in a spawned task.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is treated as an empty file (i.e. the file is new).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        // The diff is computed against the old text as its base.
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                // Hunks are only available once the base text has been diffed.
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Collect the changed ranges of the new buffer, as points.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        // Show only the changed regions, padded with the default context.
                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    // A failed update is logged rather than aborting language detection.
                    .log_err();

                // Language detection is best-effort: lookup errors are only logged.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the diff as a fenced code block headed by the file path.
    ///
    /// NOTE(review): this emits the full text of the buffers in `multibuffer` (only
    /// `new_buffer` is added in `from_acp`), not a line-level diff.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
561
562#[derive(Debug, Default)]
563pub struct Plan {
564 pub entries: Vec<PlanEntry>,
565}
566
567#[derive(Debug)]
568pub struct PlanStats<'a> {
569 pub in_progress_entry: Option<&'a PlanEntry>,
570 pub pending: u32,
571 pub completed: u32,
572}
573
574impl Plan {
575 pub fn is_empty(&self) -> bool {
576 self.entries.is_empty()
577 }
578
579 pub fn stats(&self) -> PlanStats<'_> {
580 let mut stats = PlanStats {
581 in_progress_entry: None,
582 pending: 0,
583 completed: 0,
584 };
585
586 for entry in &self.entries {
587 match &entry.status {
588 acp_old::PlanEntryStatus::Pending => {
589 stats.pending += 1;
590 }
591 acp_old::PlanEntryStatus::InProgress => {
592 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
593 }
594 acp_old::PlanEntryStatus::Completed => {
595 stats.completed += 1;
596 }
597 }
598 }
599
600 stats
601 }
602}
603
604#[derive(Debug)]
605pub struct PlanEntry {
606 pub content: Entity<Markdown>,
607 pub priority: acp_old::PlanEntryPriority,
608 pub status: acp_old::PlanEntryStatus,
609}
610
611impl PlanEntry {
612 pub fn from_acp(entry: acp_old::PlanEntry, cx: &mut App) -> Self {
613 Self {
614 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
615 priority: entry.priority,
616 status: entry.status,
617 }
618 }
619}
620
/// A conversation with an agent.
///
/// It holds the ordered thread entries, the agent's current plan, and handles to the
/// project and action log that the agent reads and edits through.
pub struct AcpThread {
    // Display title, returned by `title()`.
    title: SharedString,
    // User messages, assistant messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    // The most recent plan set via `update_plan`.
    plan: Plan,
    project: Entity<Project>,
    // Notified of buffers read and edited in `read_text_file` / `write_text_file`.
    action_log: Entity<ActionLog>,
    // Last snapshot of each buffer whose text was handed to the agent.
    // `read_text_file` can reuse it, and `write_text_file` diffs against it.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // The in-flight `send`; `Some` means the thread is generating (see `status()`).
    send_task: Option<Task<()>>,
    // Connection used by `request()` to reach the agent.
    connection: Arc<dyn AgentConnection>,
    // Optional task handed out (once) by `child_status()`.
    child_status: Option<Task<Result<()>>>,
}
632
/// Events emitted by an `AcpThread` when its entry list changes.
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the thread.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
639
/// Coarse state of a thread, as reported by `AcpThread::status`.
///
/// This is a public, fieldless enum, so it is cheap to copy and should be printable for
/// logging and assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and the current turn has a tool call awaiting confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is awaiting confirmation.
    Generating,
}
646
647#[derive(Debug, Clone)]
648pub enum LoadError {
649 Unsupported {
650 error_message: SharedString,
651 upgrade_message: SharedString,
652 upgrade_command: String,
653 },
654 Exited(i32),
655 Other(SharedString),
656}
657
658impl Display for LoadError {
659 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
660 match self {
661 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
662 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
663 LoadError::Other(msg) => write!(f, "{}", msg),
664 }
665 }
666}
667
668impl Error for LoadError {}
669
670impl AcpThread {
671 pub fn new(
672 connection: impl AgentConnection + 'static,
673 title: SharedString,
674 child_status: Option<Task<Result<()>>>,
675 project: Entity<Project>,
676 cx: &mut Context<Self>,
677 ) -> Self {
678 let action_log = cx.new(|_| ActionLog::new(project.clone()));
679
680 Self {
681 action_log,
682 shared_buffers: Default::default(),
683 entries: Default::default(),
684 plan: Default::default(),
685 title,
686 project,
687 send_task: None,
688 connection: Arc::new(connection),
689 child_status,
690 }
691 }
692
693 /// Send a request to the agent and wait for a response.
694 pub fn request<R: acp_old::AgentRequest + 'static>(
695 &self,
696 params: R,
697 ) -> impl use<R> + Future<Output = Result<R::Response>> {
698 let params = params.into_any();
699 let result = self.connection.request_any(params);
700 async move {
701 let result = result.await?;
702 Ok(R::response_from_any(result)?)
703 }
704 }
705
706 pub fn action_log(&self) -> &Entity<ActionLog> {
707 &self.action_log
708 }
709
710 pub fn project(&self) -> &Entity<Project> {
711 &self.project
712 }
713
714 pub fn title(&self) -> SharedString {
715 self.title.clone()
716 }
717
718 pub fn entries(&self) -> &[AgentThreadEntry] {
719 &self.entries
720 }
721
722 pub fn status(&self) -> ThreadStatus {
723 if self.send_task.is_some() {
724 if self.waiting_for_tool_confirmation() {
725 ThreadStatus::WaitingForToolConfirmation
726 } else {
727 ThreadStatus::Generating
728 }
729 } else {
730 ThreadStatus::Idle
731 }
732 }
733
734 pub fn has_pending_edit_tool_calls(&self) -> bool {
735 for entry in self.entries.iter().rev() {
736 match entry {
737 AgentThreadEntry::UserMessage(_) => return false,
738 AgentThreadEntry::ToolCall(ToolCall {
739 status:
740 ToolCallStatus::Allowed {
741 status: acp::ToolCallStatus::InProgress,
742 ..
743 },
744 content: Some(ToolCallContent::Diff { .. }),
745 ..
746 }) => return true,
747 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
748 }
749 }
750
751 false
752 }
753
754 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
755 self.entries.push(entry);
756 cx.emit(AcpThreadEvent::NewEntry);
757 }
758
759 pub fn push_assistant_chunk(
760 &mut self,
761 chunk: acp::ContentBlock,
762 is_thought: bool,
763 cx: &mut Context<Self>,
764 ) {
765 let language_registry = self.project.read(cx).languages().clone();
766 let entries_len = self.entries.len();
767 if let Some(last_entry) = self.entries.last_mut()
768 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
769 {
770 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
771 match (chunks.last_mut(), is_thought) {
772 (Some(AssistantMessageChunk::Message { block }), false)
773 | (Some(AssistantMessageChunk::Thought { block }), true) => {
774 block.append(chunk, &language_registry, cx)
775 }
776 _ => {
777 let block = ContentBlock::new(chunk, &language_registry, cx);
778 if is_thought {
779 chunks.push(AssistantMessageChunk::Thought { block })
780 } else {
781 chunks.push(AssistantMessageChunk::Message { block })
782 }
783 }
784 }
785 } else {
786 let block = ContentBlock::new(chunk, &language_registry, cx);
787 let chunk = if is_thought {
788 AssistantMessageChunk::Thought { block }
789 } else {
790 AssistantMessageChunk::Message { block }
791 };
792
793 self.push_entry(
794 AgentThreadEntry::AssistantMessage(AssistantMessage {
795 chunks: vec![chunk],
796 }),
797 cx,
798 );
799 }
800 }
801
802 pub fn update_tool_call(
803 &mut self,
804 tool_call: acp::ToolCall,
805 cx: &mut Context<Self>,
806 ) -> Result<()> {
807 let language_registry = self.project.read(cx).languages().clone();
808 let status = ToolCallStatus::Allowed {
809 status: tool_call.status,
810 };
811 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
812
813 let location = call.locations.last().cloned();
814
815 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
816 match ¤t_call.status {
817 ToolCallStatus::WaitingForConfirmation { .. } => {
818 anyhow::bail!("Tool call hasn't been authorized yet")
819 }
820 ToolCallStatus::Rejected => {
821 anyhow::bail!("Tool call was rejected and therefore can't be updated")
822 }
823 ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {}
824 }
825
826 *current_call = call;
827
828 cx.emit(AcpThreadEvent::EntryUpdated(ix));
829 } else {
830 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
831 }
832
833 if let Some(location) = location {
834 self.set_project_location(location, cx)
835 }
836
837 Ok(())
838 }
839
840 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
841 self.entries
842 .iter_mut()
843 .enumerate()
844 .rev()
845 .find_map(|(index, tool_call)| {
846 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
847 && &tool_call.id == id
848 {
849 Some((index, tool_call))
850 } else {
851 None
852 }
853 })
854 }
855
856 pub fn request_tool_call_permission(
857 &mut self,
858 tool_call: acp::ToolCall,
859 possible_grants: Vec<acp::Grant>,
860 cx: &mut Context<Self>,
861 ) -> oneshot::Receiver<acp::GrantId> {
862 let (tx, rx) = oneshot::channel();
863
864 let status = ToolCallStatus::WaitingForConfirmation {
865 possible_grants,
866 respond_tx: tx,
867 };
868
869 self.insert_tool_call(tool_call, status, cx);
870 rx
871 }
872
873 fn insert_tool_call(
874 &mut self,
875 tool_call: acp::ToolCall,
876 status: ToolCallStatus,
877 cx: &mut Context<Self>,
878 ) {
879 let language_registry = self.project.read(cx).languages().clone();
880 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
881
882 let location = call.locations.last().cloned();
883 if let Some(location) = location {
884 self.set_project_location(location, cx)
885 }
886
887 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
888 }
889
890 pub fn authorize_tool_call(
891 &mut self,
892 id: acp::ToolCallId,
893 grant: acp::Grant,
894 cx: &mut Context<Self>,
895 ) {
896 let Some((ix, call)) = self.tool_call_mut(&id) else {
897 return;
898 };
899
900 let new_status = if grant.is_allowed {
901 ToolCallStatus::Allowed {
902 status: acp::ToolCallStatus::InProgress,
903 }
904 } else {
905 ToolCallStatus::Rejected
906 };
907
908 let curr_status = mem::replace(&mut call.status, new_status);
909
910 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
911 respond_tx.send(grant.id).log_err();
912 } else if cfg!(debug_assertions) {
913 panic!("tried to authorize an already authorized tool call");
914 }
915
916 cx.emit(AcpThreadEvent::EntryUpdated(ix));
917 }
918
919 pub fn plan(&self) -> &Plan {
920 &self.plan
921 }
922
923 pub fn update_plan(&mut self, request: acp_old::UpdatePlanParams, cx: &mut Context<Self>) {
924 self.plan = Plan {
925 entries: request
926 .entries
927 .into_iter()
928 .map(|entry| PlanEntry::from_acp(entry, cx))
929 .collect(),
930 };
931
932 cx.notify();
933 }
934
935 pub fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
936 self.plan
937 .entries
938 .retain(|entry| !matches!(entry.status, acp_old::PlanEntryStatus::Completed));
939 cx.notify();
940 }
941
942 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
943 self.project.update(cx, |project, cx| {
944 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
945 return;
946 };
947 let buffer = project.open_buffer(path, cx);
948 cx.spawn(async move |project, cx| {
949 let buffer = buffer.await?;
950
951 project.update(cx, |project, cx| {
952 let position = if let Some(line) = location.line {
953 let snapshot = buffer.read(cx).snapshot();
954 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
955 snapshot.anchor_before(point)
956 } else {
957 Anchor::MIN
958 };
959
960 project.set_agent_location(
961 Some(AgentLocation {
962 buffer: buffer.downgrade(),
963 position,
964 }),
965 cx,
966 );
967 })
968 })
969 .detach_and_log_err(cx);
970 });
971 }
972
973 /// Returns true if the last turn is awaiting tool authorization
974 pub fn waiting_for_tool_confirmation(&self) -> bool {
975 for entry in self.entries.iter().rev() {
976 match &entry {
977 AgentThreadEntry::ToolCall(call) => match call.status {
978 ToolCallStatus::WaitingForConfirmation { .. } => return true,
979 ToolCallStatus::Allowed { .. }
980 | ToolCallStatus::Rejected
981 | ToolCallStatus::Canceled => continue,
982 },
983 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
984 // Reached the beginning of the turn
985 return false;
986 }
987 }
988 }
989 false
990 }
991
992 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp_old::InitializeResponse>> {
993 self.request(acp_old::InitializeParams {
994 protocol_version: acp_old::ProtocolVersion::latest(),
995 })
996 }
997
998 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
999 self.request(acp_old::AuthenticateParams)
1000 }
1001
1002 #[cfg(any(test, feature = "test-support"))]
1003 pub fn send_raw(
1004 &mut self,
1005 message: &str,
1006 cx: &mut Context<Self>,
1007 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
1008 self.send(
1009 vec![acp::ContentBlock::Text(acp::TextContent {
1010 text: message.to_string(),
1011 annotations: None,
1012 })],
1013 cx,
1014 )
1015 }
1016
1017 pub fn send(
1018 &mut self,
1019 message: Vec<acp::ContentBlock>,
1020 cx: &mut Context<Self>,
1021 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
1022 let block =
1023 ContentBlock::new_combined(message.clone(), self.project.read(cx).languages(), cx);
1024 self.push_entry(
1025 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1026 cx,
1027 );
1028
1029 let (tx, rx) = oneshot::channel();
1030 let cancel = self.cancel(cx);
1031
1032 self.send_task = Some(cx.spawn(async move |this, cx| {
1033 async {
1034 cancel.await.log_err();
1035
1036 let result = this.update(cx, |this, _| this.request(message))?.await;
1037 tx.send(result).log_err();
1038 this.update(cx, |this, _cx| this.send_task.take())?;
1039 anyhow::Ok(())
1040 }
1041 .await
1042 .log_err();
1043 }));
1044
1045 async move {
1046 match rx.await {
1047 Ok(Err(e)) => Err(e)?,
1048 _ => Ok(()),
1049 }
1050 }
1051 .boxed()
1052 }
1053
1054 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp_old::Error>> {
1055 if self.send_task.take().is_some() {
1056 let request = self.request(acp_old::CancelSendMessageParams);
1057 cx.spawn(async move |this, cx| {
1058 request.await?;
1059 this.update(cx, |this, _cx| {
1060 for entry in this.entries.iter_mut() {
1061 if let AgentThreadEntry::ToolCall(call) = entry {
1062 let cancel = matches!(
1063 call.status,
1064 ToolCallStatus::WaitingForConfirmation { .. }
1065 | ToolCallStatus::Allowed {
1066 status: acp::ToolCallStatus::InProgress
1067 }
1068 );
1069
1070 if cancel {
1071 let curr_status =
1072 mem::replace(&mut call.status, ToolCallStatus::Canceled);
1073
1074 if let ToolCallStatus::WaitingForConfirmation {
1075 respond_tx, ..
1076 } = curr_status
1077 {
1078 // todo! do we need a way to cancel rather than reject?
1079 respond_tx
1080 .send(acp_old::ToolCallConfirmationOutcome::Cancel)
1081 .ok();
1082 }
1083 }
1084 }
1085 }
1086 })?;
1087 Ok(())
1088 })
1089 } else {
1090 Task::ready(Ok(()))
1091 }
1092 }
1093
1094 pub fn read_text_file(
1095 &self,
1096 request: acp_old::ReadTextFileParams,
1097 reuse_shared_snapshot: bool,
1098 cx: &mut Context<Self>,
1099 ) -> Task<Result<String>> {
1100 let project = self.project.clone();
1101 let action_log = self.action_log.clone();
1102 cx.spawn(async move |this, cx| {
1103 let load = project.update(cx, |project, cx| {
1104 let path = project
1105 .project_path_for_absolute_path(&request.path, cx)
1106 .context("invalid path")?;
1107 anyhow::Ok(project.open_buffer(path, cx))
1108 });
1109 let buffer = load??.await?;
1110
1111 let snapshot = if reuse_shared_snapshot {
1112 this.read_with(cx, |this, _| {
1113 this.shared_buffers.get(&buffer.clone()).cloned()
1114 })
1115 .log_err()
1116 .flatten()
1117 } else {
1118 None
1119 };
1120
1121 let snapshot = if let Some(snapshot) = snapshot {
1122 snapshot
1123 } else {
1124 action_log.update(cx, |action_log, cx| {
1125 action_log.buffer_read(buffer.clone(), cx);
1126 })?;
1127 project.update(cx, |project, cx| {
1128 let position = buffer
1129 .read(cx)
1130 .snapshot()
1131 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1132 project.set_agent_location(
1133 Some(AgentLocation {
1134 buffer: buffer.downgrade(),
1135 position,
1136 }),
1137 cx,
1138 );
1139 })?;
1140
1141 buffer.update(cx, |buffer, _| buffer.snapshot())?
1142 };
1143
1144 this.update(cx, |this, _| {
1145 let text = snapshot.text();
1146 this.shared_buffers.insert(buffer.clone(), snapshot);
1147 if request.line.is_none() && request.limit.is_none() {
1148 return Ok(text);
1149 }
1150 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1151 let Some(line) = request.line else {
1152 return Ok(text.lines().take(limit).collect::<String>());
1153 };
1154
1155 let count = text.lines().count();
1156 if count < line as usize {
1157 anyhow::bail!("There are only {} lines", count);
1158 }
1159 Ok(text
1160 .lines()
1161 .skip(line as usize + 1)
1162 .take(limit)
1163 .collect::<String>())
1164 })?
1165 })
1166 }
1167
1168 pub fn write_text_file(
1169 &self,
1170 path: PathBuf,
1171 content: String,
1172 cx: &mut Context<Self>,
1173 ) -> Task<Result<()>> {
1174 let project = self.project.clone();
1175 let action_log = self.action_log.clone();
1176 cx.spawn(async move |this, cx| {
1177 let load = project.update(cx, |project, cx| {
1178 let path = project
1179 .project_path_for_absolute_path(&path, cx)
1180 .context("invalid path")?;
1181 anyhow::Ok(project.open_buffer(path, cx))
1182 });
1183 let buffer = load??.await?;
1184 let snapshot = this.update(cx, |this, cx| {
1185 this.shared_buffers
1186 .get(&buffer)
1187 .cloned()
1188 .unwrap_or_else(|| buffer.read(cx).snapshot())
1189 })?;
1190 let edits = cx
1191 .background_executor()
1192 .spawn(async move {
1193 let old_text = snapshot.text();
1194 text_diff(old_text.as_str(), &content)
1195 .into_iter()
1196 .map(|(range, replacement)| {
1197 (
1198 snapshot.anchor_after(range.start)
1199 ..snapshot.anchor_before(range.end),
1200 replacement,
1201 )
1202 })
1203 .collect::<Vec<_>>()
1204 })
1205 .await;
1206 cx.update(|cx| {
1207 project.update(cx, |project, cx| {
1208 project.set_agent_location(
1209 Some(AgentLocation {
1210 buffer: buffer.downgrade(),
1211 position: edits
1212 .last()
1213 .map(|(range, _)| range.end)
1214 .unwrap_or(Anchor::MIN),
1215 }),
1216 cx,
1217 );
1218 });
1219
1220 action_log.update(cx, |action_log, cx| {
1221 action_log.buffer_read(buffer.clone(), cx);
1222 });
1223 buffer.update(cx, |buffer, cx| {
1224 buffer.edit(edits, None, cx);
1225 });
1226 action_log.update(cx, |action_log, cx| {
1227 action_log.buffer_edited(buffer.clone(), cx);
1228 });
1229 })?;
1230 project
1231 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1232 .await
1233 })
1234 }
1235
1236 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1237 self.child_status.take()
1238 }
1239
1240 pub fn to_markdown(&self, cx: &App) -> String {
1241 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1242 }
1243}
1244
1245#[derive(Clone)]
1246pub struct OldAcpClientDelegate {
1247 thread: WeakEntity<AcpThread>,
1248 cx: AsyncApp,
1249 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1250}
1251
1252impl OldAcpClientDelegate {
1253 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1254 Self { thread, cx }
1255 }
1256
1257 pub async fn clear_completed_plan_entries(&self) -> Result<()> {
1258 let cx = &mut self.cx.clone();
1259 cx.update(|cx| {
1260 self.thread
1261 .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
1262 })?
1263 .context("Failed to update thread")?;
1264
1265 Ok(())
1266 }
1267
1268 pub async fn request_existing_tool_call_confirmation(
1269 &self,
1270 tool_call_id: ToolCallId,
1271 confirmation: acp_old::ToolCallConfirmation,
1272 ) -> Result<acp_old::ToolCallConfirmationOutcome> {
1273 let cx = &mut self.cx.clone();
1274 let ToolCallRequest { outcome, .. } = cx
1275 .update(|cx| {
1276 self.thread.update(cx, |thread, cx| {
1277 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1278 })
1279 })?
1280 .context("Failed to update thread")??;
1281
1282 Ok(outcome.await?)
1283 }
1284
1285 pub async fn read_text_file_reusing_snapshot(
1286 &self,
1287 request: acp_old::ReadTextFileParams,
1288 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1289 let content = self
1290 .cx
1291 .update(|cx| {
1292 self.thread
1293 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1294 })?
1295 .context("Failed to update thread")?
1296 .await?;
1297 Ok(acp_old::ReadTextFileResponse { content })
1298 }
1299}
1300
1301impl acp_old::Client for OldAcpClientDelegate {
1302 async fn stream_assistant_message_chunk(
1303 &self,
1304 params: acp_old::StreamAssistantMessageChunkParams,
1305 ) -> Result<(), acp_old::Error> {
1306 let cx = &mut self.cx.clone();
1307
1308 cx.update(|cx| {
1309 self.thread
1310 .update(cx, |thread, cx| {
1311 thread.push_assistant_chunk(params.chunk, cx)
1312 })
1313 .ok();
1314 })?;
1315
1316 Ok(())
1317 }
1318
1319 async fn request_tool_call_confirmation(
1320 &self,
1321 request: acp_old::RequestToolCallConfirmationParams,
1322 ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
1323 let cx = &mut self.cx.clone();
1324 let ToolCallRequest { id, outcome } = cx
1325 .update(|cx| {
1326 self.thread
1327 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1328 })?
1329 .context("Failed to update thread")?;
1330
1331 Ok(acp_old::RequestToolCallConfirmationResponse {
1332 id,
1333 outcome: outcome.await.map_err(acp_old::Error::into_internal_error)?,
1334 })
1335 }
1336
1337 async fn push_tool_call(
1338 &self,
1339 request: acp_old::PushToolCallParams,
1340 ) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
1341 let cx = &mut self.cx.clone();
1342 let id = cx
1343 .update(|cx| {
1344 self.thread
1345 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1346 })?
1347 .context("Failed to update thread")?;
1348
1349 Ok(acp_old::PushToolCallResponse { id })
1350 }
1351
1352 async fn update_tool_call(
1353 &self,
1354 request: acp_old::UpdateToolCallParams,
1355 ) -> Result<(), acp_old::Error> {
1356 let cx = &mut self.cx.clone();
1357
1358 cx.update(|cx| {
1359 self.thread.update(cx, |thread, cx| {
1360 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1361 })
1362 })?
1363 .context("Failed to update thread")??;
1364
1365 Ok(())
1366 }
1367
1368 async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
1369 let cx = &mut self.cx.clone();
1370
1371 cx.update(|cx| {
1372 self.thread
1373 .update(cx, |thread, cx| thread.update_plan(request, cx))
1374 })?
1375 .context("Failed to update thread")?;
1376
1377 Ok(())
1378 }
1379
1380 async fn read_text_file(
1381 &self,
1382 request: acp_old::ReadTextFileParams,
1383 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1384 let content = self
1385 .cx
1386 .update(|cx| {
1387 self.thread
1388 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1389 })?
1390 .context("Failed to update thread")?
1391 .await?;
1392 Ok(acp_old::ReadTextFileResponse { content })
1393 }
1394
1395 async fn write_text_file(
1396 &self,
1397 request: acp_old::WriteTextFileParams,
1398 ) -> Result<(), acp_old::Error> {
1399 self.cx
1400 .update(|cx| {
1401 self.thread.update(cx, |thread, cx| {
1402 thread.write_text_file(request.path, request.content, cx)
1403 })
1404 })?
1405 .context("Failed to update thread")?
1406 .await?;
1407
1408 Ok(())
1409 }
1410}
1411
1412fn acp_icon_to_ui_icon(icon: acp_old::Icon) -> IconName {
1413 match icon {
1414 acp_old::Icon::FileSearch => IconName::ToolSearch,
1415 acp_old::Icon::Folder => IconName::ToolFolder,
1416 acp_old::Icon::Globe => IconName::ToolWeb,
1417 acp_old::Icon::Hammer => IconName::ToolHammer,
1418 acp_old::Icon::LightBulb => IconName::ToolBulb,
1419 acp_old::Icon::Pencil => IconName::ToolPencil,
1420 acp_old::Icon::Regex => IconName::ToolRegex,
1421 acp_old::Icon::Terminal => IconName::ToolTerminal,
1422 }
1423}
1424
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Installs the globals these tests rely on: a test settings store plus the
    /// project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Two consecutive `Thought` chunks from the agent end up in a single
    /// `<thinking>` block in the rendered markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// A user edit made after the agent has read the file is kept when the
    /// agent subsequently writes its own version of the file; the result is
    /// checked both in the open buffer and on the fake filesystem.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the fake agent once it has read the original file contents,
        // so the test knows when to make its own edit.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call canceled by the user is shown as `Canceled`. If the agent then
    /// reports it `Finished`, its status becomes `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // The fake agent keeps its turn open until `end_turn_tx` is dropped.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp_old::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.borrow_mut() = Some(tool_call_result.id);
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp_old::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp_old::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` contains a tool-call entry and returns the index of
    /// the first one.
    ///
    /// # Panics
    /// Panics if no tool call appears within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Only the first message is consumed. Later events can arrive before
                    // the receiver is polled and find the bounded channel full, which is
                    // harmless, so the send result is ignored instead of unwrapped.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Creates an `AcpThread` connected to a `FakeAcpServer` through two
    /// in-memory pipes: the thread writes to the server's stdin and reads the
    /// server's stdout.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
                OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
        });
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, agent)
    }

    /// Scriptable stand-in for an agent process, speaking the legacy protocol.
    pub struct FakeAcpServer {
        /// Connection used to send requests to Zed.
        connection: acp_old::ClientConnection,

        /// Keeps the connection's I/O future running for the server's lifetime.
        _io_task: Task<()>,
        /// Handler invoked for each `send_user_message` request. Without one, the
        /// request fails.
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp_old::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
            >,
        >,
    }

    /// Agent-side protocol implementation that delegates user messages to the
    /// owning `FakeAcpServer`'s handler.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp_old::Agent for FakeAgent {
        /// Echoes the requested protocol version and reports the client as
        /// already authenticated.
        async fn initialize(
            &self,
            params: acp_old::InitializeParams,
        ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
            Ok(acp_old::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        /// Runs the registered `on_user_message` handler, or fails if none is set
        /// or the server entity is gone.
        async fn send_user_message(
            &self,
            request: acp_old::SendUserMessageParams,
        ) -> Result<(), acp_old::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                Err(anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        /// Connects a `FakeAgent` to the given pipes and starts its I/O task.
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs (replacing any previous) handler for user messages.
        fn on_user_message<F>(
            &mut self,
            handler: impl for<'a> Fn(
                acp_old::SendUserMessageParams,
                Entity<FakeAcpServer>,
                AsyncApp,
            ) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp_old::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a request to Zed and converts protocol errors into `anyhow` errors.
        fn send_to_zed<T: acp_old::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}