1mod connection;
2pub use connection::*;
3
4pub use acp_old::ToolCallId;
5use agent_client_protocol::{self as acp};
6use agentic_coding_protocol as acp_old;
7use anyhow::{Context as _, Result};
8use assistant_tool::ActionLog;
9use buffer_diff::BufferDiff;
10use editor::{Bias, MultiBuffer, PathKey};
11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
12use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
13use itertools::{Itertools, PeekingNext};
14use language::{
15 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
16 text_diff,
17};
18use markdown::Markdown;
19use project::{AgentLocation, Project};
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt::Formatter;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::{App, IconName};
30use util::ResultExt;
31
/// A prompt authored by the user, stored as a single content block.
#[derive(Debug)]
pub struct UserMessage {
    /// All chunks of the prompt, appended into one block.
    pub content: ContentBlock,
}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n{}\n", self.content.to_markdown(cx))
52 }
53}
54
/// A reference to a file inside message markdown. Its `Display` form is a link
/// whose target is `@file:<path>`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);
57
58impl<'a> MentionPath<'a> {
59 const PREFIX: &'static str = "@file:";
60
61 pub fn new(path: &'a Path) -> Self {
62 MentionPath(path)
63 }
64
65 pub fn try_parse(url: &'a str) -> Option<Self> {
66 let path = url.strip_prefix(Self::PREFIX)?;
67 Some(MentionPath(Path::new(path)))
68 }
69
70 pub fn path(&self) -> &Path {
71 self.0
72 }
73}
74
75impl Display for MentionPath<'_> {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(
78 f,
79 "[@{}]({}{})",
80 self.0.file_name().unwrap_or_default().display(),
81 Self::PREFIX,
82 self.0.display()
83 )
84 }
85}
86
/// One assistant turn, made of message and thought chunks in arrival order.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order. `AcpThread::push_assistant_chunk` appends to the
    /// last chunk when the new chunk is of the same kind.
    pub chunks: Vec<AssistantMessageChunk>,
}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
/// A run of assistant output of a single kind.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Reply text.
    Message { block: ContentBlock },
    /// Reasoning output; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(
114 acp::ContentBlock::Text(acp::TextContent {
115 text: chunk.to_owned().into(),
116 annotations: None,
117 }),
118 language_registry,
119 cx,
120 ),
121 }
122 }
123
124 fn to_markdown(&self, cx: &App) -> String {
125 match self {
126 Self::Message { block } => block.to_markdown(cx).to_string(),
127 Self::Thought { block } => {
128 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
129 }
130 }
131 }
132}
133
/// One item in an `AcpThread` transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A prompt sent by the user.
    UserMessage(UserMessage),
    /// Reply and thought chunks from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
140
141impl AgentThreadEntry {
142 fn to_markdown(&self, cx: &App) -> String {
143 match self {
144 Self::UserMessage(message) => message.to_markdown(cx),
145 Self::AssistantMessage(message) => message.to_markdown(cx),
146 Self::ToolCall(too_call) => too_call.to_markdown(cx),
147 }
148 }
149
150 pub fn diff(&self) -> Option<&Diff> {
151 if let AgentThreadEntry::ToolCall(ToolCall {
152 content: Some(ToolCallContent::Diff { diff }),
153 ..
154 }) = self
155 {
156 Some(&diff)
157 } else {
158 None
159 }
160 }
161
162 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
163 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
164 Some(locations)
165 } else {
166 None
167 }
168 }
169}
170
/// A tool invocation as shown in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Identifier used to locate this entry when updates or authorizations arrive.
    pub id: acp::ToolCallId,
    /// Label rendered as markdown.
    pub label: Entity<Markdown>,
    /// Tool kind, copied from the ACP tool call.
    pub kind: acp::ToolKind,
    // todo! Should this be a vec?
    // Builders converting from ACP currently keep only the first converted item.
    pub content: Option<ToolCallContent>,
    /// Authorization and execution state.
    pub status: ToolCallStatus,
    /// Locations reported for the call. The last one is used to move the agent
    /// location in the project.
    pub locations: Vec<acp::ToolCallLocation>,
}
181
182impl ToolCall {
183 fn to_markdown(&self, cx: &App) -> String {
184 let mut markdown = format!(
185 "**Tool Call: {}**\nStatus: {}\n\n",
186 self.label.read(cx).source(),
187 self.status
188 );
189 if let Some(content) = &self.content {
190 markdown.push_str(content.to_markdown(cx).as_str());
191 markdown.push_str("\n\n");
192 }
193 markdown
194 }
195}
196
/// Authorization and execution state of a tool call.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// Awaiting a user decision among `possible_grants`. The chosen grant's id is
    /// sent through `respond_tx` when the call is authorized.
    WaitingForConfirmation {
        possible_grants: Vec<acp::Grant>,
        respond_tx: oneshot::Sender<acp::GrantId>,
    },
    /// Authorized; `status` tracks progress reported by the agent.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user declined the call.
    Rejected,
    /// The call was canceled, e.g. by `AcpThread::cancel`.
    Canceled,
}
209
210impl Display for ToolCallStatus {
211 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
212 write!(
213 f,
214 "{}",
215 match self {
216 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
217 ToolCallStatus::Allowed { status } => match status {
218 acp::ToolCallStatus::InProgress => "In Progress",
219 acp::ToolCallStatus::Completed => "Completed",
220 acp::ToolCallStatus::Failed => "Failed",
221 },
222 ToolCallStatus::Rejected => "Rejected",
223 ToolCallStatus::Canceled => "Canceled",
224 }
225 )
226 }
227}
228
/// A confirmation prompt for a tool call, converted from the legacy
/// `acp_old::ToolCallConfirmation` via `ToolCallConfirmation::from_acp`.
/// Each variant mirrors the legacy variant of the same name; descriptions are
/// rendered as markdown.
#[derive(Debug)]
pub enum ToolCallConfirmation {
    /// Confirmation for an edit.
    Edit {
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for running `command`. NOTE(review): `root_command` presumably
    /// names the base executable of `command` — TODO confirm.
    Execute {
        command: String,
        root_command: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for a tool exposed by the MCP server `server_name`.
    Mcp {
        server_name: String,
        tool_name: String,
        tool_display_name: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for fetching `urls`.
    Fetch {
        urls: Vec<SharedString>,
        description: Option<Entity<Markdown>>,
    },
    /// Any other confirmation; here the description is mandatory.
    Other {
        description: Entity<Markdown>,
    },
}
253
254impl ToolCallConfirmation {
255 pub fn from_acp(
256 confirmation: acp_old::ToolCallConfirmation,
257 language_registry: Arc<LanguageRegistry>,
258 cx: &mut App,
259 ) -> Self {
260 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
261 cx.new(|cx| {
262 Markdown::new(
263 description.into(),
264 Some(language_registry.clone()),
265 None,
266 cx,
267 )
268 })
269 };
270
271 match confirmation {
272 acp_old::ToolCallConfirmation::Edit { description } => Self::Edit {
273 description: description.map(|description| to_md(description, cx)),
274 },
275 acp_old::ToolCallConfirmation::Execute {
276 command,
277 root_command,
278 description,
279 } => Self::Execute {
280 command,
281 root_command,
282 description: description.map(|description| to_md(description, cx)),
283 },
284 acp_old::ToolCallConfirmation::Mcp {
285 server_name,
286 tool_name,
287 tool_display_name,
288 description,
289 } => Self::Mcp {
290 server_name,
291 tool_name,
292 tool_display_name,
293 description: description.map(|description| to_md(description, cx)),
294 },
295 acp_old::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
296 urls: urls.iter().map(|url| url.into()).collect(),
297 description: description.map(|description| to_md(description, cx)),
298 },
299 acp_old::ToolCallConfirmation::Other { description } => Self::Other {
300 description: to_md(description, cx),
301 },
302 }
303 }
304}
305
306#[derive(Debug, PartialEq, Clone)]
307enum ContentBlock {
308 Empty,
309 Markdown { markdown: Entity<Markdown> },
310}
311
312impl ContentBlock {
313 pub fn new(
314 block: acp::ContentBlock,
315 language_registry: &Arc<LanguageRegistry>,
316 cx: &mut App,
317 ) -> Self {
318 let mut this = Self::Empty;
319 this.append(block, language_registry, cx);
320 this
321 }
322
323 pub fn new_combined(
324 blocks: impl IntoIterator<Item = acp::ContentBlock>,
325 language_registry: &Arc<LanguageRegistry>,
326 cx: &mut App,
327 ) -> Self {
328 let mut this = Self::Empty;
329 for block in blocks {
330 this.append(block, language_registry, cx);
331 }
332 this
333 }
334
335 pub fn append(
336 &mut self,
337 block: acp::ContentBlock,
338 language_registry: &Arc<LanguageRegistry>,
339 cx: &mut App,
340 ) {
341 let new_content = match block {
342 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
343 acp::ContentBlock::ResourceLink(resource_link) => {
344 if let Some(path) = resource_link.uri.strip_prefix("file://") {
345 format!("{}", MentionPath(path.as_ref()))
346 } else {
347 resource_link.uri.clone()
348 }
349 }
350 acp::ContentBlock::Image(_)
351 | acp::ContentBlock::Audio(_)
352 | acp::ContentBlock::Resource(_) => String::new(),
353 };
354
355 match self {
356 ContentBlock::Empty => {
357 *self = ContentBlock::Markdown {
358 markdown: cx.new(|cx| {
359 Markdown::new(
360 new_content.into(),
361 Some(language_registry.clone()),
362 None,
363 cx,
364 )
365 }),
366 };
367 }
368 ContentBlock::Markdown { markdown } => {
369 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
370 }
371 }
372 }
373
374 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
375 match self {
376 ContentBlock::Empty => "",
377 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
378 }
379 }
380}
381
/// Content attached to a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown content. `from_acp_contents` merges consecutive ACP content blocks
    /// into one of these.
    ContentBlock { content: ContentBlock },
    /// A file diff.
    Diff { diff: Diff },
}
387
388impl ToolCallContent {
389 pub fn from_acp(
390 content: acp::ToolCallContent,
391 language_registry: Arc<LanguageRegistry>,
392 cx: &mut App,
393 ) -> Self {
394 match content {
395 acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock {
396 content: ContentBlock::new(content, &language_registry, cx),
397 },
398 acp::ToolCallContent::Diff { diff } => Self::Diff {
399 diff: Diff::from_acp(diff, language_registry, cx),
400 },
401 }
402 }
403
404 pub fn from_acp_contents(
405 content: Vec<acp::ToolCallContent>,
406 language_registry: Arc<LanguageRegistry>,
407 cx: &mut App,
408 ) -> Vec<Self> {
409 content
410 .into_iter()
411 .peekable()
412 .batching(|it| match it.next()? {
413 acp::ToolCallContent::ContentBlock { content } => {
414 let mut block = ContentBlock::new(content, &language_registry, cx);
415 while let Some(acp::ToolCallContent::ContentBlock { content }) =
416 it.peeking_next(|c| matches!(c, acp::ToolCallContent::ContentBlock { .. }))
417 {
418 block.append(content, &language_registry, cx);
419 }
420 Some(ToolCallContent::ContentBlock { content: block })
421 }
422 content @ acp::ToolCallContent::Diff { .. } => Some(ToolCallContent::from_acp(
423 content,
424 language_registry.clone(),
425 cx,
426 )),
427 })
428 .collect()
429 }
430
431 fn to_markdown(&self, cx: &App) -> String {
432 match self {
433 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
434 Self::Diff { diff } => diff.to_markdown(cx),
435 }
436 }
437}
438
/// A file diff reported by a tool call, displayed through a read-only multibuffer.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer of excerpts around the changed hunks of `new_buffer`.
    /// It is filled asynchronously by `_task`.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path of the diffed file. Also used to pick a language for `new_buffer`.
    pub path: PathBuf,
    /// Local buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the original text (empty when none was provided).
    pub old_buffer: Entity<Buffer>,
    /// Background task that computes the diff and populates `multibuffer`. It is
    /// stored so it lives as long as this `Diff`.
    _task: Task<Result<()>>,
}
447
impl Diff {
    /// Builds a diff entry from an ACP file diff.
    ///
    /// The old and new text are loaded into two local buffers synchronously.
    /// The spawned task stored in `_task` then:
    /// - computes the hunks,
    /// - fills `multibuffer` with excerpts around them,
    /// - assigns a language to `new_buffer` based on `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is treated as an empty original (presumably a newly
        // created file).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        // Diff the new text against the old text, which serves as the base.
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Excerpt the new buffer around every hunk (with
                // `editor::DEFAULT_MULTIBUFFER_CONTEXT` context) and attach the diff.
                // A failed update is logged rather than propagated.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Pick a language from the file path. Lookup failures are logged and
                // leave the buffer without a language.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the diff for transcript export as the path plus a fenced block.
    ///
    /// Note: only `new_buffer` is ever added to `multibuffer`, so the output is the
    /// full new text (empty until the task populates it), not an old-vs-new diff.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
536
/// The agent's most recently reported execution plan.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan steps in the order the agent reported them.
    pub entries: Vec<PlanEntry>,
}
541
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
548
549impl Plan {
550 pub fn is_empty(&self) -> bool {
551 self.entries.is_empty()
552 }
553
554 pub fn stats(&self) -> PlanStats<'_> {
555 let mut stats = PlanStats {
556 in_progress_entry: None,
557 pending: 0,
558 completed: 0,
559 };
560
561 for entry in &self.entries {
562 match &entry.status {
563 acp_old::PlanEntryStatus::Pending => {
564 stats.pending += 1;
565 }
566 acp_old::PlanEntryStatus::InProgress => {
567 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
568 }
569 acp_old::PlanEntryStatus::Completed => {
570 stats.completed += 1;
571 }
572 }
573 }
574
575 stats
576 }
577}
578
/// One step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description, rendered as plain-text markdown.
    pub content: Entity<Markdown>,
    /// Priority as reported by the agent.
    pub priority: acp_old::PlanEntryPriority,
    /// Progress as reported by the agent.
    pub status: acp_old::PlanEntryStatus,
}
585
586impl PlanEntry {
587 pub fn from_acp(entry: acp_old::PlanEntry, cx: &mut App) -> Self {
588 Self {
589 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
590 priority: entry.priority,
591 status: entry.status,
592 }
593 }
594}
595
/// A conversation with an ACP agent. Holds:
/// - the transcript,
/// - the agent's plan,
/// - the project state the agent reads and edits.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries in the order they were added.
    entries: Vec<AgentThreadEntry>,
    /// The most recent plan reported by the agent.
    plan: Plan,
    project: Entity<Project>,
    /// Records buffers the agent has read and edited.
    action_log: Entity<ActionLog>,
    /// For each buffer, the snapshot last returned by `read_text_file`.
    /// `write_text_file` uses it as the base text to diff against.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight prompt; `Some` while a `send` is outstanding.
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    /// Handed out once via `child_status()`. NOTE(review): presumably resolves when
    /// the agent process exits — TODO confirm.
    child_status: Option<Task<Result<()>>>,
}
607
/// Notifications emitted by `AcpThread` when its transcript changes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcpThreadEvent {
    /// A new entry was appended to the end of the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}
612
// Lets `AcpThread` emit `AcpThreadEvent`s through `cx.emit` so observers can react
// to transcript changes.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
614
/// High-level state of an `AcpThread`, as returned by `AcpThread::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and a tool call in the current turn awaits confirmation.
    WaitingForToolConfirmation,
    /// A prompt is in flight and nothing is awaiting confirmation.
    Generating,
}
621
/// An error reported while loading an agent. User-facing text comes from its
/// `Display` impl.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent cannot be used as installed. Only `error_message` is displayed.
    /// NOTE(review): `upgrade_command` is presumably a command the user can run to
    /// upgrade — TODO confirm.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The server exited with the given status code.
    Exited(i32),
    /// Any other failure, described by the message.
    Other(SharedString),
}
632
633impl Display for LoadError {
634 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
635 match self {
636 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
637 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
638 LoadError::Other(msg) => write!(f, "{}", msg),
639 }
640 }
641}
642
643impl Error for LoadError {}
644
645impl AcpThread {
646 pub fn new(
647 connection: impl AgentConnection + 'static,
648 title: SharedString,
649 child_status: Option<Task<Result<()>>>,
650 project: Entity<Project>,
651 cx: &mut Context<Self>,
652 ) -> Self {
653 let action_log = cx.new(|_| ActionLog::new(project.clone()));
654
655 Self {
656 action_log,
657 shared_buffers: Default::default(),
658 entries: Default::default(),
659 plan: Default::default(),
660 title,
661 project,
662 send_task: None,
663 connection: Arc::new(connection),
664 child_status,
665 }
666 }
667
668 /// Send a request to the agent and wait for a response.
669 pub fn request<R: acp_old::AgentRequest + 'static>(
670 &self,
671 params: R,
672 ) -> impl use<R> + Future<Output = Result<R::Response>> {
673 let params = params.into_any();
674 let result = self.connection.request_any(params);
675 async move {
676 let result = result.await?;
677 Ok(R::response_from_any(result)?)
678 }
679 }
680
681 pub fn action_log(&self) -> &Entity<ActionLog> {
682 &self.action_log
683 }
684
685 pub fn project(&self) -> &Entity<Project> {
686 &self.project
687 }
688
689 pub fn title(&self) -> SharedString {
690 self.title.clone()
691 }
692
693 pub fn entries(&self) -> &[AgentThreadEntry] {
694 &self.entries
695 }
696
697 pub fn status(&self) -> ThreadStatus {
698 if self.send_task.is_some() {
699 if self.waiting_for_tool_confirmation() {
700 ThreadStatus::WaitingForToolConfirmation
701 } else {
702 ThreadStatus::Generating
703 }
704 } else {
705 ThreadStatus::Idle
706 }
707 }
708
709 pub fn has_pending_edit_tool_calls(&self) -> bool {
710 for entry in self.entries.iter().rev() {
711 match entry {
712 AgentThreadEntry::UserMessage(_) => return false,
713 AgentThreadEntry::ToolCall(ToolCall {
714 status:
715 ToolCallStatus::Allowed {
716 status: acp::ToolCallStatus::InProgress,
717 ..
718 },
719 content: Some(ToolCallContent::Diff { .. }),
720 ..
721 }) => return true,
722 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
723 }
724 }
725
726 false
727 }
728
729 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
730 self.entries.push(entry);
731 cx.emit(AcpThreadEvent::NewEntry);
732 }
733
734 pub fn push_assistant_chunk(
735 &mut self,
736 chunk: acp::ContentBlock,
737 is_thought: bool,
738 cx: &mut Context<Self>,
739 ) {
740 let language_registry = self.project.read(cx).languages().clone();
741 let entries_len = self.entries.len();
742 if let Some(last_entry) = self.entries.last_mut()
743 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
744 {
745 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
746 match (chunks.last_mut(), is_thought) {
747 (Some(AssistantMessageChunk::Message { block }), false)
748 | (Some(AssistantMessageChunk::Thought { block }), true) => {
749 block.append(chunk, &language_registry, cx)
750 }
751 _ => {
752 let block = ContentBlock::new(chunk, &language_registry, cx);
753 if is_thought {
754 chunks.push(AssistantMessageChunk::Thought { block })
755 } else {
756 chunks.push(AssistantMessageChunk::Message { block })
757 }
758 }
759 }
760 } else {
761 let block = ContentBlock::new(chunk, &language_registry, cx);
762 let chunk = if is_thought {
763 AssistantMessageChunk::Thought { block }
764 } else {
765 AssistantMessageChunk::Message { block }
766 };
767
768 self.push_entry(
769 AgentThreadEntry::AssistantMessage(AssistantMessage {
770 chunks: vec![chunk],
771 }),
772 cx,
773 );
774 }
775 }
776
777 pub fn update_tool_call(
778 &mut self,
779 tool_call: acp::ToolCall,
780 cx: &mut Context<Self>,
781 ) -> Result<()> {
782 let language_registry = self.project.read(cx).languages().clone();
783
784 let new_tool_call = ToolCall {
785 id: tool_call.id,
786 label: cx.new(|cx| {
787 Markdown::new(
788 tool_call.label.into(),
789 Some(language_registry.clone()),
790 None,
791 cx,
792 )
793 }),
794 kind: tool_call.kind,
795 // todo! Do we assume there is either a coalesced content OR diff?
796 content: ToolCallContent::from_acp_contents(tool_call.content, language_registry, cx)
797 .into_iter()
798 .next(),
799 locations: tool_call.locations,
800 status: ToolCallStatus::Allowed {
801 status: tool_call.status,
802 },
803 };
804
805 if let Some((ix, current_call)) = self.tool_call_mut(tool_call.id) {
806 match &mut current_call.status {
807 ToolCallStatus::Allowed { status } => {
808 *status = tool_call.status;
809 }
810 ToolCallStatus::WaitingForConfirmation { .. } => {
811 anyhow::bail!("Tool call hasn't been authorized yet")
812 }
813 ToolCallStatus::Rejected => {
814 anyhow::bail!("Tool call was rejected and therefore can't be updated")
815 }
816 ToolCallStatus::Canceled => {
817 current_call.status = ToolCallStatus::Allowed { status: new_status };
818 }
819 }
820
821 *current_call = new_tool_call;
822
823 let location = current_call.locations.last().cloned();
824 if let Some(location) = location {
825 self.set_project_location(location, cx)
826 }
827
828 cx.emit(AcpThreadEvent::EntryUpdated(ix));
829 } else {
830 let language_registry = self.project.read(cx).languages().clone();
831 let call = ToolCall {
832 id: tool_call.id,
833 label: cx.new(|cx| {
834 Markdown::new(
835 tool_call.label.into(),
836 Some(language_registry.clone()),
837 None,
838 cx,
839 )
840 }),
841 kind: tool_call.kind,
842 // todo! Do we assume there is either a coalesced content OR diff?
843 content: ToolCallContent::from_acp_contents(
844 tool_call.content,
845 language_registry,
846 cx,
847 )
848 .into_iter()
849 .next(),
850 locations: tool_call.locations,
851 status: ToolCallStatus::Allowed {
852 status: tool_call.status,
853 },
854 };
855
856 let location = call.locations.last().cloned();
857 if let Some(location) = location {
858 self.set_project_location(location, cx)
859 }
860
861 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
862 }
863
864 Ok(())
865 }
866
867 fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
868 self.entries
869 .iter_mut()
870 .enumerate()
871 .rev()
872 .find_map(|(index, tool_call)| {
873 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
874 && tool_call.id == id
875 {
876 Some((index, tool_call))
877 } else {
878 None
879 }
880 })
881 }
882
883 pub fn request_tool_call_permission(
884 &mut self,
885 tool_call: acp::ToolCall,
886 possible_grants: Vec<acp::Grant>,
887 cx: &mut Context<Self>,
888 ) -> ToolCallRequest {
889 let (tx, rx) = oneshot::channel();
890
891 let status = ToolCallStatus::WaitingForConfirmation {
892 possible_grants,
893 respond_tx: tx,
894 };
895
896 let id = self.insert_tool_call(tool_call, status, cx);
897 ToolCallRequest { id, outcome: rx }
898 }
899
900 pub fn request_tool_call_confirmation(
901 &mut self,
902 tool_call_id: ToolCallId,
903 confirmation: acp_old::ToolCallConfirmation,
904 cx: &mut Context<Self>,
905 ) -> Result<ToolCallRequest> {
906 let project = self.project.read(cx).languages().clone();
907 let Some((idx, call)) = self.tool_call_mut(tool_call_id) else {
908 anyhow::bail!("Tool call not found");
909 };
910
911 let (tx, rx) = oneshot::channel();
912
913 call.status = ToolCallStatus::WaitingForConfirmation {
914 confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
915 respond_tx: tx,
916 };
917
918 cx.emit(AcpThreadEvent::EntryUpdated(idx));
919
920 Ok(ToolCallRequest {
921 id: tool_call_id,
922 outcome: rx,
923 })
924 }
925
926 fn insert_tool_call(
927 &mut self,
928 tool_call: acp_old::PushToolCallParams,
929 status: ToolCallStatus,
930 cx: &mut Context<Self>,
931 ) -> acp_old::ToolCallId {
932 let language_registry = self.project.read(cx).languages().clone();
933 let id = acp_old::ToolCallId(self.entries.len() as u64);
934 let call = ToolCall {
935 id,
936 label: cx.new(|cx| {
937 Markdown::new(
938 tool_call.label.into(),
939 Some(language_registry.clone()),
940 None,
941 cx,
942 )
943 }),
944 icon: acp_icon_to_ui_icon(tool_call.icon),
945 content: tool_call
946 .content
947 .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
948 locations: tool_call.locations,
949 status,
950 };
951
952 let location = call.locations.last().cloned();
953 if let Some(location) = location {
954 self.set_project_location(location, cx)
955 }
956
957 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
958
959 id
960 }
961
962 pub fn authorize_tool_call(
963 &mut self,
964 id: acp::ToolCallId,
965 grant: acp::Grant,
966 cx: &mut Context<Self>,
967 ) {
968 let Some((ix, call)) = self.tool_call_mut(id) else {
969 return;
970 };
971
972 let new_status = if grant.is_allowed {
973 ToolCallStatus::Allowed {
974 status: acp::ToolCallStatus::InProgress,
975 }
976 } else {
977 ToolCallStatus::Rejected
978 };
979
980 let curr_status = mem::replace(&mut call.status, new_status);
981
982 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
983 respond_tx.send(grant.id).log_err();
984 } else if cfg!(debug_assertions) {
985 panic!("tried to authorize an already authorized tool call");
986 }
987
988 cx.emit(AcpThreadEvent::EntryUpdated(ix));
989 }
990
991 pub fn plan(&self) -> &Plan {
992 &self.plan
993 }
994
995 pub fn update_plan(&mut self, request: acp_old::UpdatePlanParams, cx: &mut Context<Self>) {
996 self.plan = Plan {
997 entries: request
998 .entries
999 .into_iter()
1000 .map(|entry| PlanEntry::from_acp(entry, cx))
1001 .collect(),
1002 };
1003
1004 cx.notify();
1005 }
1006
1007 pub fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1008 self.plan
1009 .entries
1010 .retain(|entry| !matches!(entry.status, acp_old::PlanEntryStatus::Completed));
1011 cx.notify();
1012 }
1013
1014 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
1015 self.project.update(cx, |project, cx| {
1016 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
1017 return;
1018 };
1019 let buffer = project.open_buffer(path, cx);
1020 cx.spawn(async move |project, cx| {
1021 let buffer = buffer.await?;
1022
1023 project.update(cx, |project, cx| {
1024 let position = if let Some(line) = location.line {
1025 let snapshot = buffer.read(cx).snapshot();
1026 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
1027 snapshot.anchor_before(point)
1028 } else {
1029 Anchor::MIN
1030 };
1031
1032 project.set_agent_location(
1033 Some(AgentLocation {
1034 buffer: buffer.downgrade(),
1035 position,
1036 }),
1037 cx,
1038 );
1039 })
1040 })
1041 .detach_and_log_err(cx);
1042 });
1043 }
1044
1045 /// Returns true if the last turn is awaiting tool authorization
1046 pub fn waiting_for_tool_confirmation(&self) -> bool {
1047 for entry in self.entries.iter().rev() {
1048 match &entry {
1049 AgentThreadEntry::ToolCall(call) => match call.status {
1050 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1051 ToolCallStatus::Allowed { .. }
1052 | ToolCallStatus::Rejected
1053 | ToolCallStatus::Canceled => continue,
1054 },
1055 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1056 // Reached the beginning of the turn
1057 return false;
1058 }
1059 }
1060 }
1061 false
1062 }
1063
1064 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp_old::InitializeResponse>> {
1065 self.request(acp_old::InitializeParams {
1066 protocol_version: acp_old::ProtocolVersion::latest(),
1067 })
1068 }
1069
1070 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
1071 self.request(acp_old::AuthenticateParams)
1072 }
1073
/// Test helper: sends `message` as a single plain-text content block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<(), acp_old::Error>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_string(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1088
1089 pub fn send(
1090 &mut self,
1091 message: Vec<acp::ContentBlock>,
1092 cx: &mut Context<Self>,
1093 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
1094 let block =
1095 ContentBlock::new_combined(message.clone(), self.project.read(cx).languages(), cx);
1096 self.push_entry(
1097 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1098 cx,
1099 );
1100
1101 let (tx, rx) = oneshot::channel();
1102 let cancel = self.cancel(cx);
1103
1104 self.send_task = Some(cx.spawn(async move |this, cx| {
1105 async {
1106 cancel.await.log_err();
1107
1108 let result = this.update(cx, |this, _| this.request(message))?.await;
1109 tx.send(result).log_err();
1110 this.update(cx, |this, _cx| this.send_task.take())?;
1111 anyhow::Ok(())
1112 }
1113 .await
1114 .log_err();
1115 }));
1116
1117 async move {
1118 match rx.await {
1119 Ok(Err(e)) => Err(e)?,
1120 _ => Ok(()),
1121 }
1122 }
1123 .boxed()
1124 }
1125
1126 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp_old::Error>> {
1127 if self.send_task.take().is_some() {
1128 let request = self.request(acp_old::CancelSendMessageParams);
1129 cx.spawn(async move |this, cx| {
1130 request.await?;
1131 this.update(cx, |this, _cx| {
1132 for entry in this.entries.iter_mut() {
1133 if let AgentThreadEntry::ToolCall(call) = entry {
1134 let cancel = matches!(
1135 call.status,
1136 ToolCallStatus::WaitingForConfirmation { .. }
1137 | ToolCallStatus::Allowed {
1138 status: acp_old::ToolCallStatus::Running
1139 }
1140 );
1141
1142 if cancel {
1143 let curr_status =
1144 mem::replace(&mut call.status, ToolCallStatus::Canceled);
1145
1146 if let ToolCallStatus::WaitingForConfirmation {
1147 respond_tx, ..
1148 } = curr_status
1149 {
1150 respond_tx
1151 .send(acp_old::ToolCallConfirmationOutcome::Cancel)
1152 .ok();
1153 }
1154 }
1155 }
1156 }
1157 })?;
1158 Ok(())
1159 })
1160 } else {
1161 Task::ready(Ok(()))
1162 }
1163 }
1164
1165 pub fn read_text_file(
1166 &self,
1167 request: acp_old::ReadTextFileParams,
1168 reuse_shared_snapshot: bool,
1169 cx: &mut Context<Self>,
1170 ) -> Task<Result<String>> {
1171 let project = self.project.clone();
1172 let action_log = self.action_log.clone();
1173 cx.spawn(async move |this, cx| {
1174 let load = project.update(cx, |project, cx| {
1175 let path = project
1176 .project_path_for_absolute_path(&request.path, cx)
1177 .context("invalid path")?;
1178 anyhow::Ok(project.open_buffer(path, cx))
1179 });
1180 let buffer = load??.await?;
1181
1182 let snapshot = if reuse_shared_snapshot {
1183 this.read_with(cx, |this, _| {
1184 this.shared_buffers.get(&buffer.clone()).cloned()
1185 })
1186 .log_err()
1187 .flatten()
1188 } else {
1189 None
1190 };
1191
1192 let snapshot = if let Some(snapshot) = snapshot {
1193 snapshot
1194 } else {
1195 action_log.update(cx, |action_log, cx| {
1196 action_log.buffer_read(buffer.clone(), cx);
1197 })?;
1198 project.update(cx, |project, cx| {
1199 let position = buffer
1200 .read(cx)
1201 .snapshot()
1202 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1203 project.set_agent_location(
1204 Some(AgentLocation {
1205 buffer: buffer.downgrade(),
1206 position,
1207 }),
1208 cx,
1209 );
1210 })?;
1211
1212 buffer.update(cx, |buffer, _| buffer.snapshot())?
1213 };
1214
1215 this.update(cx, |this, _| {
1216 let text = snapshot.text();
1217 this.shared_buffers.insert(buffer.clone(), snapshot);
1218 if request.line.is_none() && request.limit.is_none() {
1219 return Ok(text);
1220 }
1221 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1222 let Some(line) = request.line else {
1223 return Ok(text.lines().take(limit).collect::<String>());
1224 };
1225
1226 let count = text.lines().count();
1227 if count < line as usize {
1228 anyhow::bail!("There are only {} lines", count);
1229 }
1230 Ok(text
1231 .lines()
1232 .skip(line as usize + 1)
1233 .take(limit)
1234 .collect::<String>())
1235 })?
1236 })
1237 }
1238
/// Writes `content` to the project file at `path` by editing its buffer in place,
/// then saves the buffer.
///
/// Edits come from `text_diff`, computed on the background executor. The base is
/// the snapshot last shared with the agent via `read_text_file`, or the buffer's
/// current snapshot if none was shared. The diff ranges are turned into anchors
/// on that base snapshot. Before the edit:
/// - the agent location moves to the end of the last edit (buffer start if there
///   are no edits);
/// - the buffer is recorded as read and then as edited in the action log.
///
/// # Errors
/// Fails if the path is outside the project, the buffer can't be opened, the
/// thread is gone, or saving fails.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Diff against what the agent last saw, falling back to the live contents.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;
        cx.update(|cx| {
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            });

            // Register the pre-edit state with the action log, apply the edits,
            // then record them as an agent edit.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });
            buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
        })?;
        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
1306
1307 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1308 self.child_status.take()
1309 }
1310
1311 pub fn to_markdown(&self, cx: &App) -> String {
1312 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1313 }
1314}
1315
1316#[derive(Clone)]
1317pub struct OldAcpClientDelegate {
1318 thread: WeakEntity<AcpThread>,
1319 cx: AsyncApp,
1320 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1321}
1322
1323impl OldAcpClientDelegate {
1324 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1325 Self { thread, cx }
1326 }
1327
1328 pub async fn clear_completed_plan_entries(&self) -> Result<()> {
1329 let cx = &mut self.cx.clone();
1330 cx.update(|cx| {
1331 self.thread
1332 .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
1333 })?
1334 .context("Failed to update thread")?;
1335
1336 Ok(())
1337 }
1338
1339 pub async fn request_existing_tool_call_confirmation(
1340 &self,
1341 tool_call_id: ToolCallId,
1342 confirmation: acp_old::ToolCallConfirmation,
1343 ) -> Result<acp_old::ToolCallConfirmationOutcome> {
1344 let cx = &mut self.cx.clone();
1345 let ToolCallRequest { outcome, .. } = cx
1346 .update(|cx| {
1347 self.thread.update(cx, |thread, cx| {
1348 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1349 })
1350 })?
1351 .context("Failed to update thread")??;
1352
1353 Ok(outcome.await?)
1354 }
1355
1356 pub async fn read_text_file_reusing_snapshot(
1357 &self,
1358 request: acp_old::ReadTextFileParams,
1359 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1360 let content = self
1361 .cx
1362 .update(|cx| {
1363 self.thread
1364 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1365 })?
1366 .context("Failed to update thread")?
1367 .await?;
1368 Ok(acp_old::ReadTextFileResponse { content })
1369 }
1370}
1371
1372impl acp_old::Client for OldAcpClientDelegate {
1373 async fn stream_assistant_message_chunk(
1374 &self,
1375 params: acp_old::StreamAssistantMessageChunkParams,
1376 ) -> Result<(), acp_old::Error> {
1377 let cx = &mut self.cx.clone();
1378
1379 cx.update(|cx| {
1380 self.thread
1381 .update(cx, |thread, cx| {
1382 thread.push_assistant_chunk(params.chunk, cx)
1383 })
1384 .ok();
1385 })?;
1386
1387 Ok(())
1388 }
1389
1390 async fn request_tool_call_confirmation(
1391 &self,
1392 request: acp_old::RequestToolCallConfirmationParams,
1393 ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
1394 let cx = &mut self.cx.clone();
1395 let ToolCallRequest { id, outcome } = cx
1396 .update(|cx| {
1397 self.thread
1398 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1399 })?
1400 .context("Failed to update thread")?;
1401
1402 Ok(acp_old::RequestToolCallConfirmationResponse {
1403 id,
1404 outcome: outcome.await.map_err(acp_old::Error::into_internal_error)?,
1405 })
1406 }
1407
1408 async fn push_tool_call(
1409 &self,
1410 request: acp_old::PushToolCallParams,
1411 ) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
1412 let cx = &mut self.cx.clone();
1413 let id = cx
1414 .update(|cx| {
1415 self.thread
1416 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1417 })?
1418 .context("Failed to update thread")?;
1419
1420 Ok(acp_old::PushToolCallResponse { id })
1421 }
1422
1423 async fn update_tool_call(
1424 &self,
1425 request: acp_old::UpdateToolCallParams,
1426 ) -> Result<(), acp_old::Error> {
1427 let cx = &mut self.cx.clone();
1428
1429 cx.update(|cx| {
1430 self.thread.update(cx, |thread, cx| {
1431 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1432 })
1433 })?
1434 .context("Failed to update thread")??;
1435
1436 Ok(())
1437 }
1438
1439 async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
1440 let cx = &mut self.cx.clone();
1441
1442 cx.update(|cx| {
1443 self.thread
1444 .update(cx, |thread, cx| thread.update_plan(request, cx))
1445 })?
1446 .context("Failed to update thread")?;
1447
1448 Ok(())
1449 }
1450
1451 async fn read_text_file(
1452 &self,
1453 request: acp_old::ReadTextFileParams,
1454 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1455 let content = self
1456 .cx
1457 .update(|cx| {
1458 self.thread
1459 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1460 })?
1461 .context("Failed to update thread")?
1462 .await?;
1463 Ok(acp_old::ReadTextFileResponse { content })
1464 }
1465
1466 async fn write_text_file(
1467 &self,
1468 request: acp_old::WriteTextFileParams,
1469 ) -> Result<(), acp_old::Error> {
1470 self.cx
1471 .update(|cx| {
1472 self.thread.update(cx, |thread, cx| {
1473 thread.write_text_file(request.path, request.content, cx)
1474 })
1475 })?
1476 .context("Failed to update thread")?
1477 .await?;
1478
1479 Ok(())
1480 }
1481}
1482
1483fn acp_icon_to_ui_icon(icon: acp_old::Icon) -> IconName {
1484 match icon {
1485 acp_old::Icon::FileSearch => IconName::ToolSearch,
1486 acp_old::Icon::Folder => IconName::ToolFolder,
1487 acp_old::Icon::Globe => IconName::ToolWeb,
1488 acp_old::Icon::Hammer => IconName::ToolHammer,
1489 acp_old::Icon::LightBulb => IconName::ToolBulb,
1490 acp_old::Icon::Pencil => IconName::ToolPencil,
1491 acp_old::Icon::Regex => IconName::ToolRegex,
1492 acp_old::Icon::Terminal => IconName::ToolTerminal,
1493 }
1494}
1495
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Installs the globals every test needs: logging, a test settings store, project settings
    /// and language support.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Two consecutive `Thought` chunks from the agent end up in a single `<thinking>` section.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// An agent write that is based on an earlier read must not clobber a user edit made
    /// between the read and the write. Both the buffer and the saved file contain both changes.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    // Signal the test body so it can make its edit before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call that is marked `Canceled` locally (via `cancel`) and then reported as
    /// `Finished` by the agent ends up `Allowed { Finished }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp_old::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.borrow_mut() = Some(tool_call_result.id);
                    // Keep the turn open until the test body drops `end_turn_tx`.
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp_old::ToolCallStatus::Running,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp_old::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp_old::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp_old::ToolCallStatus::Finished,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` emits an event while containing a tool call entry. Returns the
    /// index of the first such entry.
    ///
    /// # Panics
    ///
    /// Panics if no tool call appears within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Further events may arrive before `rx` is polled and fill the bounded
                    // channel. Only the first index matters, so a full channel is not an error.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Builds an `AcpThread` connected, through in-memory pipes, to a `FakeAcpServer` that
    /// plays the agent side of the legacy protocol.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
                OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
        });
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, agent)
    }

    /// Test double for an agent speaking the legacy protocol.
    ///
    /// Incoming user messages are routed to the closure installed with `on_user_message`.
    pub struct FakeAcpServer {
        connection: acp_old::ClientConnection,

        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp_old::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
            >,
        >,
    }

    /// `acp_old::Agent` implementation that forwards requests to its `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp_old::Agent for FakeAgent {
        async fn initialize(
            &self,
            params: acp_old::InitializeParams,
        ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
            Ok(acp_old::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        /// Runs the installed `on_user_message` handler. Errors if none is installed or the
        /// server entity is gone.
        async fn send_user_message(
            &self,
            request: acp_old::SendUserMessageParams,
        ) -> Result<(), acp_old::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), cx).await
            } else {
                Err(anyhow::anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        /// Connects a new fake server to the given pipes and spawns its IO loop.
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs (replacing any previous one) the handler run for each user message.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp_old::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp_old::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends `message` to the Zed side and resolves with its response, converting protocol
        /// errors into `anyhow` errors.
        fn send_to_zed<T: acp_old::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}