1mod connection;
2mod old_acp_support;
3pub use connection::*;
4pub use old_acp_support::*;
5
6use agent_client_protocol as acp;
7use anyhow::{Context as _, Result};
8use assistant_tool::ActionLog;
9use buffer_diff::BufferDiff;
10use editor::{Bias, MultiBuffer, PathKey};
11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
12use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
13use itertools::Itertools;
14use language::{
15 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
16 text_diff,
17};
18use markdown::Markdown;
19use project::{AgentLocation, Project};
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt::Formatter;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub content: ContentBlock,
36}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
53 }
54}
55
/// A borrowed file path that renders as a markdown mention link of the form
/// `[@<file name>](@file:<path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target; returns `None` when `url` does not start with
    /// the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// Returns the wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A path without a final component falls back to an empty label.
        let label = self.0.file_name().unwrap_or_default();
        let target = self.0.display();
        write!(f, "[@{}]({}{})", label.display(), Self::PREFIX, target)
    }
}
87
88#[derive(Debug, PartialEq)]
89pub struct AssistantMessage {
90 pub chunks: Vec<AssistantMessageChunk>,
91}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub enum AssistantMessageChunk {
107 Message { block: ContentBlock },
108 Thought { block: ContentBlock },
109}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
128#[derive(Debug)]
129pub enum AgentThreadEntry {
130 UserMessage(UserMessage),
131 AssistantMessage(AssistantMessage),
132 ToolCall(ToolCall),
133}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
145 if let AgentThreadEntry::ToolCall(call) = self {
146 itertools::Either::Left(call.diffs())
147 } else {
148 itertools::Either::Right(std::iter::empty())
149 }
150 }
151
152 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
153 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
154 Some(locations)
155 } else {
156 None
157 }
158 }
159}
160
161#[derive(Debug)]
162pub struct ToolCall {
163 pub id: acp::ToolCallId,
164 pub label: Entity<Markdown>,
165 pub kind: acp::ToolKind,
166 pub content: Vec<ToolCallContent>,
167 pub status: ToolCallStatus,
168 pub locations: Vec<acp::ToolCallLocation>,
169 pub raw_input: Option<serde_json::Value>,
170}
171
172impl ToolCall {
173 fn from_acp(
174 tool_call: acp::ToolCall,
175 status: ToolCallStatus,
176 language_registry: Arc<LanguageRegistry>,
177 cx: &mut App,
178 ) -> Self {
179 Self {
180 id: tool_call.id,
181 label: cx.new(|cx| {
182 Markdown::new(
183 tool_call.label.into(),
184 Some(language_registry.clone()),
185 None,
186 cx,
187 )
188 }),
189 kind: tool_call.kind,
190 content: tool_call
191 .content
192 .into_iter()
193 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
194 .collect(),
195 locations: tool_call.locations,
196 status,
197 raw_input: tool_call.raw_input,
198 }
199 }
200
201 fn update(
202 &mut self,
203 fields: acp::ToolCallUpdateFields,
204 language_registry: Arc<LanguageRegistry>,
205 cx: &mut App,
206 ) {
207 let acp::ToolCallUpdateFields {
208 kind,
209 status,
210 label,
211 content,
212 locations,
213 raw_input,
214 } = fields;
215
216 if let Some(kind) = kind {
217 self.kind = kind;
218 }
219
220 if let Some(status) = status {
221 self.status = ToolCallStatus::Allowed { status };
222 }
223
224 if let Some(label) = label {
225 self.label = cx.new(|cx| Markdown::new_text(label.into(), cx));
226 }
227
228 if let Some(content) = content {
229 self.content = content
230 .into_iter()
231 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
232 .collect();
233 }
234
235 if let Some(locations) = locations {
236 self.locations = locations;
237 }
238
239 if let Some(raw_input) = raw_input {
240 self.raw_input = Some(raw_input);
241 }
242 }
243
244 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
245 self.content.iter().filter_map(|content| match content {
246 ToolCallContent::ContentBlock { .. } => None,
247 ToolCallContent::Diff { diff } => Some(diff),
248 })
249 }
250
251 fn to_markdown(&self, cx: &App) -> String {
252 let mut markdown = format!(
253 "**Tool Call: {}**\nStatus: {}\n\n",
254 self.label.read(cx).source(),
255 self.status
256 );
257 for content in &self.content {
258 markdown.push_str(content.to_markdown(cx).as_str());
259 markdown.push_str("\n\n");
260 }
261 markdown
262 }
263}
264
265#[derive(Debug)]
266pub enum ToolCallStatus {
267 WaitingForConfirmation {
268 options: Vec<acp::PermissionOption>,
269 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
270 },
271 Allowed {
272 status: acp::ToolCallStatus,
273 },
274 Rejected,
275 Canceled,
276}
277
278impl Display for ToolCallStatus {
279 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280 write!(
281 f,
282 "{}",
283 match self {
284 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
285 ToolCallStatus::Allowed { status } => match status {
286 acp::ToolCallStatus::Pending => "Pending",
287 acp::ToolCallStatus::InProgress => "In Progress",
288 acp::ToolCallStatus::Completed => "Completed",
289 acp::ToolCallStatus::Failed => "Failed",
290 },
291 ToolCallStatus::Rejected => "Rejected",
292 ToolCallStatus::Canceled => "Canceled",
293 }
294 )
295 }
296}
297
298#[derive(Debug, PartialEq, Clone)]
299pub enum ContentBlock {
300 Empty,
301 Markdown { markdown: Entity<Markdown> },
302}
303
304impl ContentBlock {
305 pub fn new(
306 block: acp::ContentBlock,
307 language_registry: &Arc<LanguageRegistry>,
308 cx: &mut App,
309 ) -> Self {
310 let mut this = Self::Empty;
311 this.append(block, language_registry, cx);
312 this
313 }
314
315 pub fn new_combined(
316 blocks: impl IntoIterator<Item = acp::ContentBlock>,
317 language_registry: Arc<LanguageRegistry>,
318 cx: &mut App,
319 ) -> Self {
320 let mut this = Self::Empty;
321 for block in blocks {
322 this.append(block, &language_registry, cx);
323 }
324 this
325 }
326
327 pub fn append(
328 &mut self,
329 block: acp::ContentBlock,
330 language_registry: &Arc<LanguageRegistry>,
331 cx: &mut App,
332 ) {
333 let new_content = match block {
334 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
335 acp::ContentBlock::ResourceLink(resource_link) => {
336 if let Some(path) = resource_link.uri.strip_prefix("file://") {
337 format!("{}", MentionPath(path.as_ref()))
338 } else {
339 resource_link.uri.clone()
340 }
341 }
342 acp::ContentBlock::Image(_)
343 | acp::ContentBlock::Audio(_)
344 | acp::ContentBlock::Resource(_) => String::new(),
345 };
346
347 match self {
348 ContentBlock::Empty => {
349 *self = ContentBlock::Markdown {
350 markdown: cx.new(|cx| {
351 Markdown::new(
352 new_content.into(),
353 Some(language_registry.clone()),
354 None,
355 cx,
356 )
357 }),
358 };
359 }
360 ContentBlock::Markdown { markdown } => {
361 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
362 }
363 }
364 }
365
366 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
367 match self {
368 ContentBlock::Empty => "",
369 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
370 }
371 }
372
373 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
374 match self {
375 ContentBlock::Empty => None,
376 ContentBlock::Markdown { markdown } => Some(markdown),
377 }
378 }
379}
380
381#[derive(Debug)]
382pub enum ToolCallContent {
383 ContentBlock { content: ContentBlock },
384 Diff { diff: Diff },
385}
386
387impl ToolCallContent {
388 pub fn from_acp(
389 content: acp::ToolCallContent,
390 language_registry: Arc<LanguageRegistry>,
391 cx: &mut App,
392 ) -> Self {
393 match content {
394 acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
395 content: ContentBlock::new(content, &language_registry, cx),
396 },
397 acp::ToolCallContent::Diff { diff } => Self::Diff {
398 diff: Diff::from_acp(diff, language_registry, cx),
399 },
400 }
401 }
402
403 pub fn to_markdown(&self, cx: &App) -> String {
404 match self {
405 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
406 Self::Diff { diff } => diff.to_markdown(cx),
407 }
408 }
409}
410
/// A file diff reported by a tool call, held as editor buffers.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that receives excerpts around each diff hunk.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path of the file the diff applies to, as given by the agent.
    pub path: PathBuf,
    /// Buffer holding the new text.
    pub new_buffer: Entity<Buffer>,
    /// Buffer holding the old text (empty when the agent sent none).
    pub old_buffer: Entity<Buffer>,
    // Background task that computes the diff and fills `multibuffer`;
    // stored so it is kept alongside the diff.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds buffers for the old and new text of `diff` and spawns a task
    /// that computes the hunks, adds an excerpt per hunk to the multibuffer,
    /// and then assigns a language to the new buffer based on the file path.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is treated as an empty file.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Use the old text as the diff base for the new buffer.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                // The hunks are only read once the base text has been applied.
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Collect the new-buffer range of every hunk in the file.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // A failed language lookup is logged and the buffer is left
                // without a language.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path followed by a fenced block containing the text of
    /// every buffer in the multibuffer, joined by newlines.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
508
509#[derive(Debug, Default)]
510pub struct Plan {
511 pub entries: Vec<PlanEntry>,
512}
513
514#[derive(Debug)]
515pub struct PlanStats<'a> {
516 pub in_progress_entry: Option<&'a PlanEntry>,
517 pub pending: u32,
518 pub completed: u32,
519}
520
521impl Plan {
522 pub fn is_empty(&self) -> bool {
523 self.entries.is_empty()
524 }
525
526 pub fn stats(&self) -> PlanStats<'_> {
527 let mut stats = PlanStats {
528 in_progress_entry: None,
529 pending: 0,
530 completed: 0,
531 };
532
533 for entry in &self.entries {
534 match &entry.status {
535 acp::PlanEntryStatus::Pending => {
536 stats.pending += 1;
537 }
538 acp::PlanEntryStatus::InProgress => {
539 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
540 }
541 acp::PlanEntryStatus::Completed => {
542 stats.completed += 1;
543 }
544 }
545 }
546
547 stats
548 }
549}
550
551#[derive(Debug)]
552pub struct PlanEntry {
553 pub content: Entity<Markdown>,
554 pub priority: acp::PlanEntryPriority,
555 pub status: acp::PlanEntryStatus,
556}
557
558impl PlanEntry {
559 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
560 Self {
561 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
562 priority: entry.priority,
563 status: entry.status,
564 }
565 }
566}
567
/// A conversation with an agent over an ACP connection: the transcript, the
/// agent's plan, and the state needed to send prompts and service file
/// requests.
pub struct AcpThread {
    /// Display title; initialized from the connection's name.
    title: SharedString,
    /// Transcript entries in chronological order.
    entries: Vec<AgentThreadEntry>,
    /// Latest plan reported by the agent.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Snapshots recorded by `read_text_file`, keyed by buffer; reused as the
    /// diff base by `write_text_file`.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Set while a prompt started by `send` is in flight.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
579
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the transcript.
    NewEntry,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user to pick a permission option.
    ToolAuthorizationRequired,
    /// Emitted by `send` when the prompt finishes without an error result.
    Stopped,
    /// Emitted by `send` when the prompt returns an error.
    Error,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
589
/// Coarse activity state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is in flight.
    Idle,
    /// A send task is in flight and a tool call awaits confirmation.
    WaitingForToolConfirmation,
    /// A send task is in flight and nothing awaits confirmation.
    Generating,
}
596
597#[derive(Debug, Clone)]
598pub enum LoadError {
599 Unsupported {
600 error_message: SharedString,
601 upgrade_message: SharedString,
602 upgrade_command: String,
603 },
604 Exited(i32),
605 Other(SharedString),
606}
607
608impl Display for LoadError {
609 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
610 match self {
611 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
612 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
613 LoadError::Other(msg) => write!(f, "{}", msg),
614 }
615 }
616}
617
618impl Error for LoadError {}
619
620impl AcpThread {
621 pub fn new(
622 connection: Rc<dyn AgentConnection>,
623 project: Entity<Project>,
624 session_id: acp::SessionId,
625 cx: &mut Context<Self>,
626 ) -> Self {
627 let action_log = cx.new(|_| ActionLog::new(project.clone()));
628
629 Self {
630 action_log,
631 shared_buffers: Default::default(),
632 entries: Default::default(),
633 plan: Default::default(),
634 title: connection.name().into(),
635 project,
636 send_task: None,
637 connection,
638 session_id,
639 }
640 }
641
/// Returns the action log entity created for this thread.
pub fn action_log(&self) -> &Entity<ActionLog> {
    &self.action_log
}
645
/// Returns the project this thread operates on.
pub fn project(&self) -> &Entity<Project> {
    &self.project
}
649
/// Returns a clone of the thread's title.
pub fn title(&self) -> SharedString {
    self.title.clone()
}
653
/// Returns the transcript entries in the order they were pushed.
pub fn entries(&self) -> &[AgentThreadEntry] {
    &self.entries
}
657
658 pub fn status(&self) -> ThreadStatus {
659 if self.send_task.is_some() {
660 if self.waiting_for_tool_confirmation() {
661 ThreadStatus::WaitingForToolConfirmation
662 } else {
663 ThreadStatus::Generating
664 }
665 } else {
666 ThreadStatus::Idle
667 }
668 }
669
670 pub fn has_pending_edit_tool_calls(&self) -> bool {
671 for entry in self.entries.iter().rev() {
672 match entry {
673 AgentThreadEntry::UserMessage(_) => return false,
674 AgentThreadEntry::ToolCall(
675 call @ ToolCall {
676 status:
677 ToolCallStatus::Allowed {
678 status:
679 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
680 },
681 ..
682 },
683 ) if call.diffs().next().is_some() => {
684 return true;
685 }
686 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
687 }
688 }
689
690 false
691 }
692
693 pub fn used_tools_since_last_user_message(&self) -> bool {
694 for entry in self.entries.iter().rev() {
695 match entry {
696 AgentThreadEntry::UserMessage(..) => return false,
697 AgentThreadEntry::AssistantMessage(..) => continue,
698 AgentThreadEntry::ToolCall(..) => return true,
699 }
700 }
701
702 false
703 }
704
705 pub fn handle_session_update(
706 &mut self,
707 update: acp::SessionUpdate,
708 cx: &mut Context<Self>,
709 ) -> Result<()> {
710 match update {
711 acp::SessionUpdate::UserMessage(content_block) => {
712 self.push_user_content_block(content_block, cx);
713 }
714 acp::SessionUpdate::AgentMessageChunk(content_block) => {
715 self.push_assistant_content_block(content_block, false, cx);
716 }
717 acp::SessionUpdate::AgentThoughtChunk(content_block) => {
718 self.push_assistant_content_block(content_block, true, cx);
719 }
720 acp::SessionUpdate::ToolCall(tool_call) => {
721 self.upsert_tool_call(tool_call, cx);
722 }
723 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
724 self.update_tool_call(tool_call_update, cx)?;
725 }
726 acp::SessionUpdate::Plan(plan) => {
727 self.update_plan(plan, cx);
728 }
729 }
730 Ok(())
731 }
732
733 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
734 let language_registry = self.project.read(cx).languages().clone();
735 let entries_len = self.entries.len();
736
737 if let Some(last_entry) = self.entries.last_mut()
738 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
739 {
740 content.append(chunk, &language_registry, cx);
741 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
742 } else {
743 let content = ContentBlock::new(chunk, &language_registry, cx);
744 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
745 }
746 }
747
748 pub fn push_assistant_content_block(
749 &mut self,
750 chunk: acp::ContentBlock,
751 is_thought: bool,
752 cx: &mut Context<Self>,
753 ) {
754 let language_registry = self.project.read(cx).languages().clone();
755 let entries_len = self.entries.len();
756 if let Some(last_entry) = self.entries.last_mut()
757 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
758 {
759 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
760 match (chunks.last_mut(), is_thought) {
761 (Some(AssistantMessageChunk::Message { block }), false)
762 | (Some(AssistantMessageChunk::Thought { block }), true) => {
763 block.append(chunk, &language_registry, cx)
764 }
765 _ => {
766 let block = ContentBlock::new(chunk, &language_registry, cx);
767 if is_thought {
768 chunks.push(AssistantMessageChunk::Thought { block })
769 } else {
770 chunks.push(AssistantMessageChunk::Message { block })
771 }
772 }
773 }
774 } else {
775 let block = ContentBlock::new(chunk, &language_registry, cx);
776 let chunk = if is_thought {
777 AssistantMessageChunk::Thought { block }
778 } else {
779 AssistantMessageChunk::Message { block }
780 };
781
782 self.push_entry(
783 AgentThreadEntry::AssistantMessage(AssistantMessage {
784 chunks: vec![chunk],
785 }),
786 cx,
787 );
788 }
789 }
790
/// Appends `entry` to the transcript and emits `NewEntry`.
fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
    self.entries.push(entry);
    cx.emit(AcpThreadEvent::NewEntry);
}
795
796 pub fn update_tool_call(
797 &mut self,
798 update: acp::ToolCallUpdate,
799 cx: &mut Context<Self>,
800 ) -> Result<()> {
801 let languages = self.project.read(cx).languages().clone();
802
803 let (ix, current_call) = self
804 .tool_call_mut(&update.id)
805 .context("Tool call not found")?;
806 current_call.update(update.fields, languages, cx);
807
808 cx.emit(AcpThreadEvent::EntryUpdated(ix));
809
810 Ok(())
811 }
812
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
///
/// The call is stored as `Allowed`, carrying the status reported in `tool_call`.
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
    let status = ToolCallStatus::Allowed {
        status: tool_call.status,
    };
    self.upsert_tool_call_inner(tool_call, status, cx)
}
820
821 pub fn upsert_tool_call_inner(
822 &mut self,
823 tool_call: acp::ToolCall,
824 status: ToolCallStatus,
825 cx: &mut Context<Self>,
826 ) {
827 let language_registry = self.project.read(cx).languages().clone();
828 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
829
830 let location = call.locations.last().cloned();
831
832 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
833 *current_call = call;
834
835 cx.emit(AcpThreadEvent::EntryUpdated(ix));
836 } else {
837 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
838 }
839
840 if let Some(location) = location {
841 self.set_project_location(location, cx)
842 }
843 }
844
845 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
846 // The tool call we are looking for is typically the last one, or very close to the end.
847 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
848 self.entries
849 .iter_mut()
850 .enumerate()
851 .rev()
852 .find_map(|(index, tool_call)| {
853 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
854 && &tool_call.id == id
855 {
856 Some((index, tool_call))
857 } else {
858 None
859 }
860 })
861 }
862
863 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
864 self.project.update(cx, |project, cx| {
865 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
866 return;
867 };
868 let buffer = project.open_buffer(path, cx);
869 cx.spawn(async move |project, cx| {
870 let buffer = buffer.await?;
871
872 project.update(cx, |project, cx| {
873 let position = if let Some(line) = location.line {
874 let snapshot = buffer.read(cx).snapshot();
875 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
876 snapshot.anchor_before(point)
877 } else {
878 Anchor::MIN
879 };
880
881 project.set_agent_location(
882 Some(AgentLocation {
883 buffer: buffer.downgrade(),
884 position,
885 }),
886 cx,
887 );
888 })
889 })
890 .detach_and_log_err(cx);
891 });
892 }
893
894 pub fn request_tool_call_permission(
895 &mut self,
896 tool_call: acp::ToolCall,
897 options: Vec<acp::PermissionOption>,
898 cx: &mut Context<Self>,
899 ) -> oneshot::Receiver<acp::PermissionOptionId> {
900 let (tx, rx) = oneshot::channel();
901
902 let status = ToolCallStatus::WaitingForConfirmation {
903 options,
904 respond_tx: tx,
905 };
906
907 self.upsert_tool_call_inner(tool_call, status, cx);
908 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
909 rx
910 }
911
912 pub fn authorize_tool_call(
913 &mut self,
914 id: acp::ToolCallId,
915 option_id: acp::PermissionOptionId,
916 option_kind: acp::PermissionOptionKind,
917 cx: &mut Context<Self>,
918 ) {
919 let Some((ix, call)) = self.tool_call_mut(&id) else {
920 return;
921 };
922
923 let new_status = match option_kind {
924 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
925 ToolCallStatus::Rejected
926 }
927 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
928 ToolCallStatus::Allowed {
929 status: acp::ToolCallStatus::InProgress,
930 }
931 }
932 };
933
934 let curr_status = mem::replace(&mut call.status, new_status);
935
936 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
937 respond_tx.send(option_id).log_err();
938 } else if cfg!(debug_assertions) {
939 panic!("tried to authorize an already authorized tool call");
940 }
941
942 cx.emit(AcpThreadEvent::EntryUpdated(ix));
943 }
944
945 /// Returns true if the last turn is awaiting tool authorization
946 pub fn waiting_for_tool_confirmation(&self) -> bool {
947 for entry in self.entries.iter().rev() {
948 match &entry {
949 AgentThreadEntry::ToolCall(call) => match call.status {
950 ToolCallStatus::WaitingForConfirmation { .. } => return true,
951 ToolCallStatus::Allowed { .. }
952 | ToolCallStatus::Rejected
953 | ToolCallStatus::Canceled => continue,
954 },
955 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
956 // Reached the beginning of the turn
957 return false;
958 }
959 }
960 }
961 false
962 }
963
/// Returns the current plan.
pub fn plan(&self) -> &Plan {
    &self.plan
}
967
968 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
969 self.plan = Plan {
970 entries: request
971 .entries
972 .into_iter()
973 .map(|entry| PlanEntry::from_acp(entry, cx))
974 .collect(),
975 };
976
977 cx.notify();
978 }
979
/// Drops every plan entry whose status is `Completed` and notifies observers.
fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
    self.plan
        .entries
        .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
    cx.notify();
}
986
/// Delegates authentication to the underlying connection and returns its future.
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
    self.connection.authenticate(cx)
}
990
/// Test helper: sends `message` as a single un-annotated text block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_owned(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1005
/// Sends `message` to the agent as a new user turn.
///
/// The message is pushed onto the transcript as a user entry, completed
/// plan entries are cleared, and any in-flight send is canceled and awaited
/// before the prompt is issued on the connection.
///
/// The returned future resolves once the prompt result is available: it
/// emits `Error` and yields the error when the prompt failed; in every other
/// case (including the result channel being dropped) it emits `Stopped` and
/// yields `Ok(())`.
pub fn send(
    &mut self,
    message: Vec<acp::ContentBlock>,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let block = ContentBlock::new_combined(
        message.clone(),
        self.project.read(cx).languages().clone(),
        cx,
    );
    self.push_entry(
        AgentThreadEntry::UserMessage(UserMessage { content: block }),
        cx,
    );
    self.clear_completed_plan_entries(cx);

    // Carries the prompt result from the send task to the returned future.
    let (tx, rx) = oneshot::channel();
    // Takes the previous send task (if any) so it can be awaited below.
    let cancel_task = self.cancel(cx);

    self.send_task = Some(cx.spawn(async move |this, cx| {
        async {
            // Let the previous turn finish winding down before prompting.
            cancel_task.await;

            let result = this
                .update(cx, |this, cx| {
                    this.connection.prompt(
                        acp::PromptArguments {
                            prompt: message,
                            session_id: this.session_id.clone(),
                        },
                        cx,
                    )
                })?
                .await;
            tx.send(result).log_err();
            // The turn is over: clear the stored task so `status` reports idle.
            this.update(cx, |this, _cx| this.send_task.take())?;
            anyhow::Ok(())
        }
        .await
        .log_err();
    }));

    cx.spawn(async move |this, cx| match rx.await {
        Ok(Err(e)) => {
            this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                .log_err();
            Err(e)?
        }
        _ => {
            this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                .log_err();
            Ok(())
        }
    })
    .boxed()
}
1062
1063 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1064 let Some(send_task) = self.send_task.take() else {
1065 return Task::ready(());
1066 };
1067
1068 for entry in self.entries.iter_mut() {
1069 if let AgentThreadEntry::ToolCall(call) = entry {
1070 let cancel = matches!(
1071 call.status,
1072 ToolCallStatus::WaitingForConfirmation { .. }
1073 | ToolCallStatus::Allowed {
1074 status: acp::ToolCallStatus::InProgress
1075 }
1076 );
1077
1078 if cancel {
1079 call.status = ToolCallStatus::Canceled;
1080 }
1081 }
1082 }
1083
1084 self.connection.cancel(&self.session_id, cx);
1085
1086 // Wait for the send task to complete
1087 cx.foreground_executor().spawn(send_task)
1088 }
1089
1090 pub fn read_text_file(
1091 &self,
1092 path: PathBuf,
1093 line: Option<u32>,
1094 limit: Option<u32>,
1095 reuse_shared_snapshot: bool,
1096 cx: &mut Context<Self>,
1097 ) -> Task<Result<String>> {
1098 let project = self.project.clone();
1099 let action_log = self.action_log.clone();
1100 cx.spawn(async move |this, cx| {
1101 let load = project.update(cx, |project, cx| {
1102 let path = project
1103 .project_path_for_absolute_path(&path, cx)
1104 .context("invalid path")?;
1105 anyhow::Ok(project.open_buffer(path, cx))
1106 });
1107 let buffer = load??.await?;
1108
1109 let snapshot = if reuse_shared_snapshot {
1110 this.read_with(cx, |this, _| {
1111 this.shared_buffers.get(&buffer.clone()).cloned()
1112 })
1113 .log_err()
1114 .flatten()
1115 } else {
1116 None
1117 };
1118
1119 let snapshot = if let Some(snapshot) = snapshot {
1120 snapshot
1121 } else {
1122 action_log.update(cx, |action_log, cx| {
1123 action_log.buffer_read(buffer.clone(), cx);
1124 })?;
1125 project.update(cx, |project, cx| {
1126 let position = buffer
1127 .read(cx)
1128 .snapshot()
1129 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1130 project.set_agent_location(
1131 Some(AgentLocation {
1132 buffer: buffer.downgrade(),
1133 position,
1134 }),
1135 cx,
1136 );
1137 })?;
1138
1139 buffer.update(cx, |buffer, _| buffer.snapshot())?
1140 };
1141
1142 this.update(cx, |this, _| {
1143 let text = snapshot.text();
1144 this.shared_buffers.insert(buffer.clone(), snapshot);
1145 if line.is_none() && limit.is_none() {
1146 return Ok(text);
1147 }
1148 let limit = limit.unwrap_or(u32::MAX) as usize;
1149 let Some(line) = line else {
1150 return Ok(text.lines().take(limit).collect::<String>());
1151 };
1152
1153 let count = text.lines().count();
1154 if count < line as usize {
1155 anyhow::bail!("There are only {} lines", count);
1156 }
1157 Ok(text
1158 .lines()
1159 .skip(line as usize + 1)
1160 .take(limit)
1161 .collect::<String>())
1162 })?
1163 })
1164 }
1165
/// Replaces the contents of a project file with `content` on behalf of the
/// agent and saves the buffer.
///
/// The edits are computed as a text diff between `content` and a base
/// snapshot: the snapshot recorded in `shared_buffers` for this buffer if
/// there is one, otherwise the buffer's current snapshot.
///
/// # Errors
///
/// Fails when the path is not part of the project, the buffer cannot be
/// opened, the thread has been dropped, or saving fails.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Prefer the snapshot recorded for this buffer as the diff base.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        // Diffing runs on the background executor.
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;
        cx.update(|cx| {
            // Move the agent location to the end of the last edit, or to the
            // start of the buffer when nothing changed.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            });

            // Record a read before applying the edits, then the edit itself.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });
            buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
        })?;
        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
1233
1234 pub fn to_markdown(&self, cx: &App) -> String {
1235 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1236 }
1237}
1238
1239#[cfg(test)]
1240mod tests {
1241 use super::*;
1242 use agentic_coding_protocol as acp_old;
1243 use anyhow::anyhow;
1244 use async_pipe::{PipeReader, PipeWriter};
1245 use futures::{
1246 channel::mpsc,
1247 future::{LocalBoxFuture, try_join_all},
1248 select,
1249 };
1250 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1251 use indoc::indoc;
1252 use project::FakeFs;
1253 use rand::Rng as _;
1254 use serde_json::json;
1255 use settings::SettingsStore;
1256 use smol::{future::BoxedLocal, stream::StreamExt as _};
1257 use std::{cell::RefCell, rc::Rc, time::Duration};
1258
1259 use util::path;
1260
1261 fn init_test(cx: &mut TestAppContext) {
1262 env_logger::try_init().ok();
1263 cx.update(|cx| {
1264 let settings_store = SettingsStore::test(cx);
1265 cx.set_global(settings_store);
1266 Project::init_settings(cx);
1267 language::init(cx);
1268 });
1269 }
1270
1271 #[gpui::test]
1272 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1273 init_test(cx);
1274
1275 let fs = FakeFs::new(cx.executor());
1276 let project = Project::test(fs, [], cx).await;
1277 let (thread, _fake_server) = fake_acp_thread(project, cx);
1278
1279 // Test creating a new user message
1280 thread.update(cx, |thread, cx| {
1281 thread.push_user_content_block(
1282 acp::ContentBlock::Text(acp::TextContent {
1283 annotations: None,
1284 text: "Hello, ".to_string(),
1285 }),
1286 cx,
1287 );
1288 });
1289
1290 thread.update(cx, |thread, cx| {
1291 assert_eq!(thread.entries.len(), 1);
1292 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1293 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1294 } else {
1295 panic!("Expected UserMessage");
1296 }
1297 });
1298
1299 // Test appending to existing user message
1300 thread.update(cx, |thread, cx| {
1301 thread.push_user_content_block(
1302 acp::ContentBlock::Text(acp::TextContent {
1303 annotations: None,
1304 text: "world!".to_string(),
1305 }),
1306 cx,
1307 );
1308 });
1309
1310 thread.update(cx, |thread, cx| {
1311 assert_eq!(thread.entries.len(), 1);
1312 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1313 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1314 } else {
1315 panic!("Expected UserMessage");
1316 }
1317 });
1318
1319 // Test creating new user message after assistant message
1320 thread.update(cx, |thread, cx| {
1321 thread.push_assistant_content_block(
1322 acp::ContentBlock::Text(acp::TextContent {
1323 annotations: None,
1324 text: "Assistant response".to_string(),
1325 }),
1326 false,
1327 cx,
1328 );
1329 });
1330
1331 thread.update(cx, |thread, cx| {
1332 thread.push_user_content_block(
1333 acp::ContentBlock::Text(acp::TextContent {
1334 annotations: None,
1335 text: "New user message".to_string(),
1336 }),
1337 cx,
1338 );
1339 });
1340
1341 thread.update(cx, |thread, cx| {
1342 assert_eq!(thread.entries.len(), 3);
1343 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1344 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1345 } else {
1346 panic!("Expected UserMessage at index 2");
1347 }
1348 });
1349 }
1350
1351 #[gpui::test]
1352 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1353 init_test(cx);
1354
1355 let fs = FakeFs::new(cx.executor());
1356 let project = Project::test(fs, [], cx).await;
1357 let (thread, fake_server) = fake_acp_thread(project, cx);
1358
1359 fake_server.update(cx, |fake_server, _| {
1360 fake_server.on_user_message(move |_, server, mut cx| async move {
1361 server
1362 .update(&mut cx, |server, _| {
1363 server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
1364 chunk: acp_old::AssistantMessageChunk::Thought {
1365 thought: "Thinking ".into(),
1366 },
1367 })
1368 })?
1369 .await
1370 .unwrap();
1371 server
1372 .update(&mut cx, |server, _| {
1373 server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
1374 chunk: acp_old::AssistantMessageChunk::Thought {
1375 thought: "hard!".into(),
1376 },
1377 })
1378 })?
1379 .await
1380 .unwrap();
1381
1382 Ok(())
1383 })
1384 });
1385
1386 thread
1387 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1388 .await
1389 .unwrap();
1390
1391 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1392 assert_eq!(
1393 output,
1394 indoc! {r#"
1395 ## User
1396
1397 Hello from Zed!
1398
1399 ## Assistant
1400
1401 <thinking>
1402 Thinking hard!
1403 </thinking>
1404
1405 "#}
1406 );
1407 }
1408
/// Checks that a file write requested by the agent is combined with an edit the
/// user made to the same buffer after the agent had read the file.
///
/// The fake agent reads `/tmp/foo`, signals the test, and then writes the file
/// back with two extra lines. Between the read and the write the test inserts
/// `zero\n` at the start of the buffer. Both the buffer and the file on the fake
/// file system must end up containing the user's line and the agent's lines.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    // Fired by the fake agent as soon as its read of the file has returned.
    // The sender sits in an `Rc<RefCell<Option<_>>>` so the `Fn` handler can
    // take it out exactly once.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

    fake_server.update(cx, |fake_server, _| {
        fake_server.on_user_message(move |_, server, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                // Step 1: the agent reads the untouched file.
                let content = server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::ReadTextFileParams {
                            path: path!("/tmp/foo").into(),
                            line: None,
                            limit: None,
                        })
                    })?
                    .await
                    .unwrap();
                assert_eq!(content.content, "one\ntwo\nthree\n");
                // Step 2: let the test know the read happened so it can edit the buffer.
                read_file_tx.take().unwrap().send(()).unwrap();
                // Step 3: the agent writes content derived from what it read earlier.
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::WriteTextFileParams {
                            path: path!("/tmp/foo").into(),
                            content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                        })
                    })?
                    .await
                    .unwrap();
                Ok(())
            }
        })
    });

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    // Wait for the agent's read before simulating the user's edit.
    read_file_rx.await.ok();
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    // Let the agent's pending write (and the save that follows it) run to completion.
    cx.run_until_parked();
    // The in-memory buffer holds the user's line plus the agent's additions...
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    // ...and so does the file saved to the fake file system.
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
1482
1483 #[gpui::test]
1484 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1485 init_test(cx);
1486
1487 let fs = FakeFs::new(cx.executor());
1488 let project = Project::test(fs, [], cx).await;
1489 let (thread, fake_server) = fake_acp_thread(project, cx);
1490
1491 let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
1492
1493 let tool_call_id = Rc::new(RefCell::new(None));
1494 let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
1495 fake_server.update(cx, |fake_server, _| {
1496 let tool_call_id = tool_call_id.clone();
1497 fake_server.on_user_message(move |_, server, mut cx| {
1498 let end_turn_rx = end_turn_rx.clone();
1499 let tool_call_id = tool_call_id.clone();
1500 async move {
1501 let tool_call_result = server
1502 .update(&mut cx, |server, _| {
1503 server.send_to_zed(acp_old::PushToolCallParams {
1504 label: "Fetch".to_string(),
1505 icon: acp_old::Icon::Globe,
1506 content: None,
1507 locations: vec![],
1508 })
1509 })?
1510 .await
1511 .unwrap();
1512 *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
1513 end_turn_rx.take().unwrap().await.ok();
1514
1515 Ok(())
1516 }
1517 })
1518 });
1519
1520 let request = thread.update(cx, |thread, cx| {
1521 thread.send_raw("Fetch https://example.com", cx)
1522 });
1523
1524 run_until_first_tool_call(&thread, cx).await;
1525
1526 thread.read_with(cx, |thread, _| {
1527 assert!(matches!(
1528 thread.entries[1],
1529 AgentThreadEntry::ToolCall(ToolCall {
1530 status: ToolCallStatus::Allowed {
1531 status: acp::ToolCallStatus::InProgress,
1532 ..
1533 },
1534 ..
1535 })
1536 ));
1537 });
1538
1539 cx.run_until_parked();
1540
1541 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1542
1543 thread.read_with(cx, |thread, _| {
1544 assert!(matches!(
1545 &thread.entries[1],
1546 AgentThreadEntry::ToolCall(ToolCall {
1547 status: ToolCallStatus::Canceled,
1548 ..
1549 })
1550 ));
1551 });
1552
1553 fake_server
1554 .update(cx, |fake_server, _| {
1555 fake_server.send_to_zed(acp_old::UpdateToolCallParams {
1556 tool_call_id: tool_call_id.borrow().unwrap(),
1557 status: acp_old::ToolCallStatus::Finished,
1558 content: None,
1559 })
1560 })
1561 .await
1562 .unwrap();
1563
1564 drop(end_turn_tx);
1565 assert!(request.await.unwrap_err().to_string().contains("canceled"));
1566
1567 thread.read_with(cx, |thread, _| {
1568 assert!(matches!(
1569 thread.entries[1],
1570 AgentThreadEntry::ToolCall(ToolCall {
1571 status: ToolCallStatus::Allowed {
1572 status: acp::ToolCallStatus::Completed,
1573 ..
1574 },
1575 ..
1576 })
1577 ));
1578 });
1579 }
1580
1581 #[gpui::test]
1582 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1583 init_test(cx);
1584 let fs = FakeFs::new(cx.background_executor.clone());
1585 fs.insert_tree(path!("/test"), json!({})).await;
1586 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1587
1588 let connection = Rc::new(StubAgentConnection::new(vec![
1589 acp::SessionUpdate::ToolCall(acp::ToolCall {
1590 id: acp::ToolCallId("test".into()),
1591 label: "Label".into(),
1592 kind: acp::ToolKind::Edit,
1593 status: acp::ToolCallStatus::Completed,
1594 content: vec![acp::ToolCallContent::Diff {
1595 diff: acp::Diff {
1596 path: "/test/test.txt".into(),
1597 old_text: None,
1598 new_text: "foo".into(),
1599 },
1600 }],
1601 locations: vec![],
1602 raw_input: None,
1603 }),
1604 ]));
1605
1606 let thread = connection
1607 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1608 .await
1609 .unwrap();
1610 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1611 .await
1612 .unwrap();
1613
1614 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1615 }
1616
1617 async fn run_until_first_tool_call(
1618 thread: &Entity<AcpThread>,
1619 cx: &mut TestAppContext,
1620 ) -> usize {
1621 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1622
1623 let subscription = cx.update(|cx| {
1624 cx.subscribe(thread, move |thread, _, cx| {
1625 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1626 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1627 return tx.try_send(ix).unwrap();
1628 }
1629 }
1630 })
1631 });
1632
1633 select! {
1634 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1635 panic!("Timeout waiting for tool call")
1636 }
1637 ix = rx.next().fuse() => {
1638 drop(subscription);
1639 ix.unwrap()
1640 }
1641 }
1642 }
1643
/// Test double for `AgentConnection` that replays a fixed list of session
/// updates on every prompt.
#[derive(Clone, Default)]
struct StubAgentConnection {
    /// Threads created through `new_thread`, keyed by their session id. Held
    /// weakly, so this map does not keep a thread alive.
    sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
    /// Permission options per tool call id. When an update is a tool call with
    /// an entry here, `prompt` requests permission before applying the update.
    permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
    /// The updates applied to the session's thread by `prompt`.
    updates: Vec<acp::SessionUpdate>,
}
1650
1651 impl StubAgentConnection {
1652 fn new(updates: Vec<acp::SessionUpdate>) -> Self {
1653 Self {
1654 updates,
1655 permission_requests: HashMap::default(),
1656 sessions: Arc::default(),
1657 }
1658 }
1659 }
1660
1661 impl AgentConnection for StubAgentConnection {
1662 fn name(&self) -> &'static str {
1663 "StubAgentConnection"
1664 }
1665
1666 fn new_thread(
1667 self: Rc<Self>,
1668 project: Entity<Project>,
1669 _cwd: &Path,
1670 cx: &mut gpui::AsyncApp,
1671 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1672 let session_id = acp::SessionId(
1673 rand::thread_rng()
1674 .sample_iter(&rand::distributions::Alphanumeric)
1675 .take(7)
1676 .map(char::from)
1677 .collect::<String>()
1678 .into(),
1679 );
1680 let thread = cx
1681 .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
1682 .unwrap();
1683 self.sessions.lock().insert(session_id, thread.downgrade());
1684 Task::ready(Ok(thread))
1685 }
1686
1687 fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
1688 unimplemented!()
1689 }
1690
1691 fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
1692 let sessions = self.sessions.lock();
1693 let thread = sessions.get(¶ms.session_id).unwrap();
1694 let mut tasks = vec![];
1695 for update in &self.updates {
1696 let thread = thread.clone();
1697 let update = update.clone();
1698 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
1699 && let Some(options) = self.permission_requests.get(&tool_call.id)
1700 {
1701 Some((tool_call.clone(), options.clone()))
1702 } else {
1703 None
1704 };
1705 let task = cx.spawn(async move |cx| {
1706 if let Some((tool_call, options)) = permission_request {
1707 let permission = thread.update(cx, |thread, cx| {
1708 thread.request_tool_call_permission(
1709 tool_call.clone(),
1710 options.clone(),
1711 cx,
1712 )
1713 })?;
1714 permission.await?;
1715 }
1716 thread.update(cx, |thread, cx| {
1717 thread.handle_session_update(update.clone(), cx).unwrap();
1718 })?;
1719 anyhow::Ok(())
1720 });
1721 tasks.push(task);
1722 }
1723 cx.spawn(async move |_| {
1724 try_join_all(tasks).await?;
1725 Ok(())
1726 })
1727 }
1728
1729 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
1730 unimplemented!()
1731 }
1732 }
1733
1734 pub fn fake_acp_thread(
1735 project: Entity<Project>,
1736 cx: &mut TestAppContext,
1737 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1738 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1739 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1740
1741 let thread = cx.new(|cx| {
1742 let foreground_executor = cx.foreground_executor().clone();
1743 let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade()));
1744
1745 let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
1746 OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()),
1747 stdin_tx,
1748 stdout_rx,
1749 move |fut| {
1750 foreground_executor.spawn(fut).detach();
1751 },
1752 );
1753
1754 let io_task = cx.background_spawn({
1755 async move {
1756 io_fut.await.log_err();
1757 Ok(())
1758 }
1759 });
1760 let connection = OldAcpAgentConnection {
1761 name: "test",
1762 connection,
1763 child_status: io_task,
1764 current_thread: thread_rc,
1765 };
1766
1767 AcpThread::new(
1768 Rc::new(connection),
1769 project,
1770 acp::SessionId("test".into()),
1771 cx,
1772 )
1773 });
1774 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1775 (thread, agent)
1776 }
1777
/// In-process stand-in for an agent speaking the old ACP protocol, used to
/// drive an `AcpThread` from tests.
pub struct FakeAcpServer {
    /// Connection used by `send_to_zed` to issue requests to the client.
    connection: acp_old::ClientConnection,

    /// Background task that drives the connection's I/O future; kept only so
    /// the task stays owned by the server.
    _io_task: Task<()>,
    /// Handler called for each incoming user message. `None` until one is
    /// installed with `on_user_message`.
    on_user_message: Option<
        Rc<
            dyn Fn(
                acp_old::SendUserMessageParams,
                Entity<FakeAcpServer>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
        >,
    >,
}
1792
/// `acp_old::Agent` implementation backing `FakeAcpServer`: it forwards user
/// messages to the handler registered on the server and supports cancellation.
#[derive(Clone)]
struct FakeAgent {
    /// The server entity whose `on_user_message` handler is invoked.
    server: Entity<FakeAcpServer>,
    /// Async app handle used to read the server and handed to the handler.
    cx: AsyncApp,
    /// Sender that aborts the `send_user_message` call currently in flight.
    /// Replaced at the start of each such call and taken when a cancel arrives.
    cancel_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
}
1799
1800 impl acp_old::Agent for FakeAgent {
1801 async fn initialize(
1802 &self,
1803 params: acp_old::InitializeParams,
1804 ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
1805 Ok(acp_old::InitializeResponse {
1806 protocol_version: params.protocol_version,
1807 is_authenticated: true,
1808 })
1809 }
1810
1811 async fn authenticate(&self) -> Result<(), acp_old::Error> {
1812 Ok(())
1813 }
1814
1815 async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
1816 if let Some(cancel_tx) = self.cancel_tx.take() {
1817 cancel_tx.send(()).log_err();
1818 }
1819 Ok(())
1820 }
1821
1822 async fn send_user_message(
1823 &self,
1824 request: acp_old::SendUserMessageParams,
1825 ) -> Result<(), acp_old::Error> {
1826 let (cancel_tx, cancel_rx) = oneshot::channel();
1827 self.cancel_tx.replace(Some(cancel_tx));
1828
1829 let mut cx = self.cx.clone();
1830 let handler = self
1831 .server
1832 .update(&mut cx, |server, _| server.on_user_message.clone())
1833 .ok()
1834 .flatten();
1835 if let Some(handler) = handler {
1836 select! {
1837 _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()),
1838 _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()),
1839 }
1840 } else {
1841 Err(anyhow::anyhow!("No handler for on_user_message").into())
1842 }
1843 }
1844 }
1845
1846 impl FakeAcpServer {
1847 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1848 let agent = FakeAgent {
1849 server: cx.entity(),
1850 cx: cx.to_async(),
1851 cancel_tx: Default::default(),
1852 };
1853 let foreground_executor = cx.foreground_executor().clone();
1854
1855 let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
1856 agent.clone(),
1857 stdout,
1858 stdin,
1859 move |fut| {
1860 foreground_executor.spawn(fut).detach();
1861 },
1862 );
1863 FakeAcpServer {
1864 connection: connection,
1865 on_user_message: None,
1866 _io_task: cx.background_spawn(async move {
1867 io_fut.await.log_err();
1868 }),
1869 }
1870 }
1871
1872 fn on_user_message<F>(
1873 &mut self,
1874 handler: impl for<'a> Fn(
1875 acp_old::SendUserMessageParams,
1876 Entity<FakeAcpServer>,
1877 AsyncApp,
1878 ) -> F
1879 + 'static,
1880 ) where
1881 F: Future<Output = Result<(), acp_old::Error>> + 'static,
1882 {
1883 self.on_user_message
1884 .replace(Rc::new(move |request, server, cx| {
1885 handler(request, server, cx).boxed_local()
1886 }));
1887 }
1888
1889 fn send_to_zed<T: acp_old::ClientRequest + 'static>(
1890 &self,
1891 message: T,
1892 ) -> BoxedLocal<Result<T::Response>> {
1893 self.connection
1894 .request(message)
1895 .map(|f| f.map_err(|err| anyhow!(err)))
1896 .boxed_local()
1897 }
1898 }
1899}