1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::process::ExitStatus;
22use std::rc::Rc;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
52 }
53}
54
/// A file mention, rendered as the markdown link `[@<file name>](@file:<full path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Marker that prefixes the link target of a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a mention link target such as `@file:/a/b.rs`.
    ///
    /// Returns `None` when `url` does not start with the `@file:` prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let full_path = self.0;
        // A path with no final component (e.g. `/`) renders with an empty link label.
        let label = full_path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            label.display(),
            Self::PREFIX,
            full_path.display()
        )
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(chunk.into(), language_registry, cx),
114 }
115 }
116
117 fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::Message { block } => block.to_markdown(cx).to_string(),
120 Self::Thought { block } => {
121 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
128pub enum AgentThreadEntry {
129 UserMessage(UserMessage),
130 AssistantMessage(AssistantMessage),
131 ToolCall(ToolCall),
132}
133
134impl AgentThreadEntry {
135 fn to_markdown(&self, cx: &App) -> String {
136 match self {
137 Self::UserMessage(message) => message.to_markdown(cx),
138 Self::AssistantMessage(message) => message.to_markdown(cx),
139 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
140 }
141 }
142
143 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.diffs())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
152 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
153 Some(locations)
154 } else {
155 None
156 }
157 }
158}
159
160#[derive(Debug)]
161pub struct ToolCall {
162 pub id: acp::ToolCallId,
163 pub label: Entity<Markdown>,
164 pub kind: acp::ToolKind,
165 pub content: Vec<ToolCallContent>,
166 pub status: ToolCallStatus,
167 pub locations: Vec<acp::ToolCallLocation>,
168 pub raw_input: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 }
198 }
199
200 fn update(
201 &mut self,
202 fields: acp::ToolCallUpdateFields,
203 language_registry: Arc<LanguageRegistry>,
204 cx: &mut App,
205 ) {
206 let acp::ToolCallUpdateFields {
207 kind,
208 status,
209 title,
210 content,
211 locations,
212 raw_input,
213 } = fields;
214
215 if let Some(kind) = kind {
216 self.kind = kind;
217 }
218
219 if let Some(status) = status {
220 self.status = ToolCallStatus::Allowed { status };
221 }
222
223 if let Some(title) = title {
224 self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
225 }
226
227 if let Some(content) = content {
228 self.content = content
229 .into_iter()
230 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
231 .collect();
232 }
233
234 if let Some(locations) = locations {
235 self.locations = locations;
236 }
237
238 if let Some(raw_input) = raw_input {
239 self.raw_input = Some(raw_input);
240 }
241 }
242
243 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::ContentBlock { .. } => None,
246 ToolCallContent::Diff { diff } => Some(diff),
247 })
248 }
249
250 fn to_markdown(&self, cx: &App) -> String {
251 let mut markdown = format!(
252 "**Tool Call: {}**\nStatus: {}\n\n",
253 self.label.read(cx).source(),
254 self.status
255 );
256 for content in &self.content {
257 markdown.push_str(content.to_markdown(cx).as_str());
258 markdown.push_str("\n\n");
259 }
260 markdown
261 }
262}
263
264#[derive(Debug)]
265pub enum ToolCallStatus {
266 WaitingForConfirmation {
267 options: Vec<acp::PermissionOption>,
268 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
269 },
270 Allowed {
271 status: acp::ToolCallStatus,
272 },
273 Rejected,
274 Canceled,
275}
276
277impl Display for ToolCallStatus {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 write!(
280 f,
281 "{}",
282 match self {
283 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
284 ToolCallStatus::Allowed { status } => match status {
285 acp::ToolCallStatus::Pending => "Pending",
286 acp::ToolCallStatus::InProgress => "In Progress",
287 acp::ToolCallStatus::Completed => "Completed",
288 acp::ToolCallStatus::Failed => "Failed",
289 },
290 ToolCallStatus::Rejected => "Rejected",
291 ToolCallStatus::Canceled => "Canceled",
292 }
293 )
294 }
295}
296
297#[derive(Debug, PartialEq, Clone)]
298pub enum ContentBlock {
299 Empty,
300 Markdown { markdown: Entity<Markdown> },
301}
302
303impl ContentBlock {
304 pub fn new(
305 block: acp::ContentBlock,
306 language_registry: &Arc<LanguageRegistry>,
307 cx: &mut App,
308 ) -> Self {
309 let mut this = Self::Empty;
310 this.append(block, language_registry, cx);
311 this
312 }
313
314 pub fn new_combined(
315 blocks: impl IntoIterator<Item = acp::ContentBlock>,
316 language_registry: Arc<LanguageRegistry>,
317 cx: &mut App,
318 ) -> Self {
319 let mut this = Self::Empty;
320 for block in blocks {
321 this.append(block, &language_registry, cx);
322 }
323 this
324 }
325
326 pub fn append(
327 &mut self,
328 block: acp::ContentBlock,
329 language_registry: &Arc<LanguageRegistry>,
330 cx: &mut App,
331 ) {
332 let new_content = match block {
333 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
334 acp::ContentBlock::ResourceLink(resource_link) => {
335 if let Some(path) = resource_link.uri.strip_prefix("file://") {
336 format!("{}", MentionPath(path.as_ref()))
337 } else {
338 resource_link.uri.clone()
339 }
340 }
341 acp::ContentBlock::Image(_)
342 | acp::ContentBlock::Audio(_)
343 | acp::ContentBlock::Resource(_) => String::new(),
344 };
345
346 match self {
347 ContentBlock::Empty => {
348 *self = ContentBlock::Markdown {
349 markdown: cx.new(|cx| {
350 Markdown::new(
351 new_content.into(),
352 Some(language_registry.clone()),
353 None,
354 cx,
355 )
356 }),
357 };
358 }
359 ContentBlock::Markdown { markdown } => {
360 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
361 }
362 }
363 }
364
365 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
366 match self {
367 ContentBlock::Empty => "",
368 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
369 }
370 }
371
372 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
373 match self {
374 ContentBlock::Empty => None,
375 ContentBlock::Markdown { markdown } => Some(markdown),
376 }
377 }
378}
379
380#[derive(Debug)]
381pub enum ToolCallContent {
382 ContentBlock { content: ContentBlock },
383 Diff { diff: Diff },
384}
385
386impl ToolCallContent {
387 pub fn from_acp(
388 content: acp::ToolCallContent,
389 language_registry: Arc<LanguageRegistry>,
390 cx: &mut App,
391 ) -> Self {
392 match content {
393 acp::ToolCallContent::Content { content } => Self::ContentBlock {
394 content: ContentBlock::new(content, &language_registry, cx),
395 },
396 acp::ToolCallContent::Diff { diff } => Self::Diff {
397 diff: Diff::from_acp(diff, language_registry, cx),
398 },
399 }
400 }
401
402 pub fn to_markdown(&self, cx: &App) -> String {
403 match self {
404 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
405 Self::Diff { diff } => diff.to_markdown(cx),
406 }
407 }
408}
409
/// A file diff reported by a tool call, presented as a read-only multibuffer.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer; the spawned task fills it with excerpts around the hunks.
    pub multibuffer: Entity<MultiBuffer>,
    pub path: PathBuf,
    /// Local buffer holding the new file text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the old file text (empty when the agent sent none).
    pub old_buffer: Entity<Buffer>,
    // Keeps the population task alive for as long as the diff exists.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff view from an ACP diff.
    ///
    /// Returns immediately. Hunk computation, excerpt insertion and language detection for
    /// `path` happen in a spawned task, so the multibuffer starts out empty.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is diffed as an empty base.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Use the old text as the diff base against the new buffer; the returned task
        // completes once hunks are available.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Collect every hunk range in the new buffer as points.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        // Show only the changed regions (plus default context) of the new
                        // buffer, then attach the diff so hunks render.
                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection is best effort: a lookup failure is only logged.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path and the full text of each buffer currently in the multibuffer as a
    /// fenced block.
    ///
    /// NOTE(review): before the spawned task has run, the multibuffer presumably holds no
    /// buffers, so the fenced block would be empty. Confirm against `MultiBuffer::all_buffers`.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
507
508#[derive(Debug, Default)]
509pub struct Plan {
510 pub entries: Vec<PlanEntry>,
511}
512
513#[derive(Debug)]
514pub struct PlanStats<'a> {
515 pub in_progress_entry: Option<&'a PlanEntry>,
516 pub pending: u32,
517 pub completed: u32,
518}
519
520impl Plan {
521 pub fn is_empty(&self) -> bool {
522 self.entries.is_empty()
523 }
524
525 pub fn stats(&self) -> PlanStats<'_> {
526 let mut stats = PlanStats {
527 in_progress_entry: None,
528 pending: 0,
529 completed: 0,
530 };
531
532 for entry in &self.entries {
533 match &entry.status {
534 acp::PlanEntryStatus::Pending => {
535 stats.pending += 1;
536 }
537 acp::PlanEntryStatus::InProgress => {
538 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
539 }
540 acp::PlanEntryStatus::Completed => {
541 stats.completed += 1;
542 }
543 }
544 }
545
546 stats
547 }
548}
549
550#[derive(Debug)]
551pub struct PlanEntry {
552 pub content: Entity<Markdown>,
553 pub priority: acp::PlanEntryPriority,
554 pub status: acp::PlanEntryStatus,
555}
556
557impl PlanEntry {
558 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
559 Self {
560 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
561 priority: entry.priority,
562 status: entry.status,
563 }
564 }
565}
566
567pub struct AcpThread {
568 title: SharedString,
569 entries: Vec<AgentThreadEntry>,
570 plan: Plan,
571 project: Entity<Project>,
572 action_log: Entity<ActionLog>,
573 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
574 send_task: Option<Task<()>>,
575 connection: Rc<dyn AgentConnection>,
576 session_id: acp::SessionId,
577}
578
579pub enum AcpThreadEvent {
580 NewEntry,
581 EntryUpdated(usize),
582 ToolAuthorizationRequired,
583 Stopped,
584 Error,
585 ServerExited(ExitStatus),
586}
587
588impl EventEmitter<AcpThreadEvent> for AcpThread {}
589
590#[derive(PartialEq, Eq)]
591pub enum ThreadStatus {
592 Idle,
593 WaitingForToolConfirmation,
594 Generating,
595}
596
597#[derive(Debug, Clone)]
598pub enum LoadError {
599 Unsupported {
600 error_message: SharedString,
601 upgrade_message: SharedString,
602 upgrade_command: String,
603 },
604 Exited(i32),
605 Other(SharedString),
606}
607
608impl Display for LoadError {
609 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
610 match self {
611 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
612 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
613 LoadError::Other(msg) => write!(f, "{}", msg),
614 }
615 }
616}
617
618impl Error for LoadError {}
619
620impl AcpThread {
621 pub fn new(
622 title: impl Into<SharedString>,
623 connection: Rc<dyn AgentConnection>,
624 project: Entity<Project>,
625 session_id: acp::SessionId,
626 cx: &mut Context<Self>,
627 ) -> Self {
628 let action_log = cx.new(|_| ActionLog::new(project.clone()));
629
630 Self {
631 action_log,
632 shared_buffers: Default::default(),
633 entries: Default::default(),
634 plan: Default::default(),
635 title: title.into(),
636 project,
637 send_task: None,
638 connection,
639 session_id,
640 }
641 }
642
643 pub fn action_log(&self) -> &Entity<ActionLog> {
644 &self.action_log
645 }
646
647 pub fn project(&self) -> &Entity<Project> {
648 &self.project
649 }
650
651 pub fn title(&self) -> SharedString {
652 self.title.clone()
653 }
654
655 pub fn entries(&self) -> &[AgentThreadEntry] {
656 &self.entries
657 }
658
659 pub fn status(&self) -> ThreadStatus {
660 if self.send_task.is_some() {
661 if self.waiting_for_tool_confirmation() {
662 ThreadStatus::WaitingForToolConfirmation
663 } else {
664 ThreadStatus::Generating
665 }
666 } else {
667 ThreadStatus::Idle
668 }
669 }
670
671 pub fn has_pending_edit_tool_calls(&self) -> bool {
672 for entry in self.entries.iter().rev() {
673 match entry {
674 AgentThreadEntry::UserMessage(_) => return false,
675 AgentThreadEntry::ToolCall(
676 call @ ToolCall {
677 status:
678 ToolCallStatus::Allowed {
679 status:
680 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
681 },
682 ..
683 },
684 ) if call.diffs().next().is_some() => {
685 return true;
686 }
687 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
688 }
689 }
690
691 false
692 }
693
694 pub fn used_tools_since_last_user_message(&self) -> bool {
695 for entry in self.entries.iter().rev() {
696 match entry {
697 AgentThreadEntry::UserMessage(..) => return false,
698 AgentThreadEntry::AssistantMessage(..) => continue,
699 AgentThreadEntry::ToolCall(..) => return true,
700 }
701 }
702
703 false
704 }
705
706 pub fn handle_session_update(
707 &mut self,
708 update: acp::SessionUpdate,
709 cx: &mut Context<Self>,
710 ) -> Result<()> {
711 match update {
712 acp::SessionUpdate::UserMessageChunk { content } => {
713 self.push_user_content_block(content, cx);
714 }
715 acp::SessionUpdate::AgentMessageChunk { content } => {
716 self.push_assistant_content_block(content, false, cx);
717 }
718 acp::SessionUpdate::AgentThoughtChunk { content } => {
719 self.push_assistant_content_block(content, true, cx);
720 }
721 acp::SessionUpdate::ToolCall(tool_call) => {
722 self.upsert_tool_call(tool_call, cx);
723 }
724 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
725 self.update_tool_call(tool_call_update, cx)?;
726 }
727 acp::SessionUpdate::Plan(plan) => {
728 self.update_plan(plan, cx);
729 }
730 }
731 Ok(())
732 }
733
734 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
735 let language_registry = self.project.read(cx).languages().clone();
736 let entries_len = self.entries.len();
737
738 if let Some(last_entry) = self.entries.last_mut()
739 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
740 {
741 content.append(chunk, &language_registry, cx);
742 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
743 } else {
744 let content = ContentBlock::new(chunk, &language_registry, cx);
745 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
746 }
747 }
748
749 pub fn push_assistant_content_block(
750 &mut self,
751 chunk: acp::ContentBlock,
752 is_thought: bool,
753 cx: &mut Context<Self>,
754 ) {
755 let language_registry = self.project.read(cx).languages().clone();
756 let entries_len = self.entries.len();
757 if let Some(last_entry) = self.entries.last_mut()
758 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
759 {
760 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
761 match (chunks.last_mut(), is_thought) {
762 (Some(AssistantMessageChunk::Message { block }), false)
763 | (Some(AssistantMessageChunk::Thought { block }), true) => {
764 block.append(chunk, &language_registry, cx)
765 }
766 _ => {
767 let block = ContentBlock::new(chunk, &language_registry, cx);
768 if is_thought {
769 chunks.push(AssistantMessageChunk::Thought { block })
770 } else {
771 chunks.push(AssistantMessageChunk::Message { block })
772 }
773 }
774 }
775 } else {
776 let block = ContentBlock::new(chunk, &language_registry, cx);
777 let chunk = if is_thought {
778 AssistantMessageChunk::Thought { block }
779 } else {
780 AssistantMessageChunk::Message { block }
781 };
782
783 self.push_entry(
784 AgentThreadEntry::AssistantMessage(AssistantMessage {
785 chunks: vec![chunk],
786 }),
787 cx,
788 );
789 }
790 }
791
792 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
793 self.entries.push(entry);
794 cx.emit(AcpThreadEvent::NewEntry);
795 }
796
797 pub fn update_tool_call(
798 &mut self,
799 update: acp::ToolCallUpdate,
800 cx: &mut Context<Self>,
801 ) -> Result<()> {
802 let languages = self.project.read(cx).languages().clone();
803
804 let (ix, current_call) = self
805 .tool_call_mut(&update.id)
806 .context("Tool call not found")?;
807 current_call.update(update.fields, languages, cx);
808
809 cx.emit(AcpThreadEvent::EntryUpdated(ix));
810
811 Ok(())
812 }
813
814 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
815 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
816 let status = ToolCallStatus::Allowed {
817 status: tool_call.status,
818 };
819 self.upsert_tool_call_inner(tool_call, status, cx)
820 }
821
822 pub fn upsert_tool_call_inner(
823 &mut self,
824 tool_call: acp::ToolCall,
825 status: ToolCallStatus,
826 cx: &mut Context<Self>,
827 ) {
828 let language_registry = self.project.read(cx).languages().clone();
829 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
830
831 let location = call.locations.last().cloned();
832
833 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
834 *current_call = call;
835
836 cx.emit(AcpThreadEvent::EntryUpdated(ix));
837 } else {
838 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
839 }
840
841 if let Some(location) = location {
842 self.set_project_location(location, cx)
843 }
844 }
845
846 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
847 // The tool call we are looking for is typically the last one, or very close to the end.
848 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
849 self.entries
850 .iter_mut()
851 .enumerate()
852 .rev()
853 .find_map(|(index, tool_call)| {
854 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
855 && &tool_call.id == id
856 {
857 Some((index, tool_call))
858 } else {
859 None
860 }
861 })
862 }
863
864 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
865 self.project.update(cx, |project, cx| {
866 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
867 return;
868 };
869 let buffer = project.open_buffer(path, cx);
870 cx.spawn(async move |project, cx| {
871 let buffer = buffer.await?;
872
873 project.update(cx, |project, cx| {
874 let position = if let Some(line) = location.line {
875 let snapshot = buffer.read(cx).snapshot();
876 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
877 snapshot.anchor_before(point)
878 } else {
879 Anchor::MIN
880 };
881
882 project.set_agent_location(
883 Some(AgentLocation {
884 buffer: buffer.downgrade(),
885 position,
886 }),
887 cx,
888 );
889 })
890 })
891 .detach_and_log_err(cx);
892 });
893 }
894
895 pub fn request_tool_call_permission(
896 &mut self,
897 tool_call: acp::ToolCall,
898 options: Vec<acp::PermissionOption>,
899 cx: &mut Context<Self>,
900 ) -> oneshot::Receiver<acp::PermissionOptionId> {
901 let (tx, rx) = oneshot::channel();
902
903 let status = ToolCallStatus::WaitingForConfirmation {
904 options,
905 respond_tx: tx,
906 };
907
908 self.upsert_tool_call_inner(tool_call, status, cx);
909 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
910 rx
911 }
912
913 pub fn authorize_tool_call(
914 &mut self,
915 id: acp::ToolCallId,
916 option_id: acp::PermissionOptionId,
917 option_kind: acp::PermissionOptionKind,
918 cx: &mut Context<Self>,
919 ) {
920 let Some((ix, call)) = self.tool_call_mut(&id) else {
921 return;
922 };
923
924 let new_status = match option_kind {
925 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
926 ToolCallStatus::Rejected
927 }
928 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
929 ToolCallStatus::Allowed {
930 status: acp::ToolCallStatus::InProgress,
931 }
932 }
933 };
934
935 let curr_status = mem::replace(&mut call.status, new_status);
936
937 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
938 respond_tx.send(option_id).log_err();
939 } else if cfg!(debug_assertions) {
940 panic!("tried to authorize an already authorized tool call");
941 }
942
943 cx.emit(AcpThreadEvent::EntryUpdated(ix));
944 }
945
946 /// Returns true if the last turn is awaiting tool authorization
947 pub fn waiting_for_tool_confirmation(&self) -> bool {
948 for entry in self.entries.iter().rev() {
949 match &entry {
950 AgentThreadEntry::ToolCall(call) => match call.status {
951 ToolCallStatus::WaitingForConfirmation { .. } => return true,
952 ToolCallStatus::Allowed { .. }
953 | ToolCallStatus::Rejected
954 | ToolCallStatus::Canceled => continue,
955 },
956 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
957 // Reached the beginning of the turn
958 return false;
959 }
960 }
961 }
962 false
963 }
964
965 pub fn plan(&self) -> &Plan {
966 &self.plan
967 }
968
969 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
970 self.plan = Plan {
971 entries: request
972 .entries
973 .into_iter()
974 .map(|entry| PlanEntry::from_acp(entry, cx))
975 .collect(),
976 };
977
978 cx.notify();
979 }
980
981 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
982 self.plan
983 .entries
984 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
985 cx.notify();
986 }
987
988 #[cfg(any(test, feature = "test-support"))]
989 pub fn send_raw(
990 &mut self,
991 message: &str,
992 cx: &mut Context<Self>,
993 ) -> BoxFuture<'static, Result<()>> {
994 self.send(
995 vec![acp::ContentBlock::Text(acp::TextContent {
996 text: message.to_string(),
997 annotations: None,
998 })],
999 cx,
1000 )
1001 }
1002
1003 pub fn send(
1004 &mut self,
1005 message: Vec<acp::ContentBlock>,
1006 cx: &mut Context<Self>,
1007 ) -> BoxFuture<'static, Result<()>> {
1008 let block = ContentBlock::new_combined(
1009 message.clone(),
1010 self.project.read(cx).languages().clone(),
1011 cx,
1012 );
1013 self.push_entry(
1014 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1015 cx,
1016 );
1017 self.clear_completed_plan_entries(cx);
1018
1019 let (tx, rx) = oneshot::channel();
1020 let cancel_task = self.cancel(cx);
1021
1022 self.send_task = Some(cx.spawn(async move |this, cx| {
1023 async {
1024 cancel_task.await;
1025
1026 let result = this
1027 .update(cx, |this, cx| {
1028 this.connection.prompt(
1029 acp::PromptRequest {
1030 prompt: message,
1031 session_id: this.session_id.clone(),
1032 },
1033 cx,
1034 )
1035 })?
1036 .await;
1037 tx.send(result).log_err();
1038 this.update(cx, |this, _cx| this.send_task.take())?;
1039 anyhow::Ok(())
1040 }
1041 .await
1042 .log_err();
1043 }));
1044
1045 cx.spawn(async move |this, cx| match rx.await {
1046 Ok(Err(e)) => {
1047 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1048 .log_err();
1049 Err(e)?
1050 }
1051 _ => {
1052 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1053 .log_err();
1054 Ok(())
1055 }
1056 })
1057 .boxed()
1058 }
1059
1060 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1061 let Some(send_task) = self.send_task.take() else {
1062 return Task::ready(());
1063 };
1064
1065 for entry in self.entries.iter_mut() {
1066 if let AgentThreadEntry::ToolCall(call) = entry {
1067 let cancel = matches!(
1068 call.status,
1069 ToolCallStatus::WaitingForConfirmation { .. }
1070 | ToolCallStatus::Allowed {
1071 status: acp::ToolCallStatus::InProgress
1072 }
1073 );
1074
1075 if cancel {
1076 call.status = ToolCallStatus::Canceled;
1077 }
1078 }
1079 }
1080
1081 self.connection.cancel(&self.session_id, cx);
1082
1083 // Wait for the send task to complete
1084 cx.foreground_executor().spawn(send_task)
1085 }
1086
1087 pub fn read_text_file(
1088 &self,
1089 path: PathBuf,
1090 line: Option<u32>,
1091 limit: Option<u32>,
1092 reuse_shared_snapshot: bool,
1093 cx: &mut Context<Self>,
1094 ) -> Task<Result<String>> {
1095 let project = self.project.clone();
1096 let action_log = self.action_log.clone();
1097 cx.spawn(async move |this, cx| {
1098 let load = project.update(cx, |project, cx| {
1099 let path = project
1100 .project_path_for_absolute_path(&path, cx)
1101 .context("invalid path")?;
1102 anyhow::Ok(project.open_buffer(path, cx))
1103 });
1104 let buffer = load??.await?;
1105
1106 let snapshot = if reuse_shared_snapshot {
1107 this.read_with(cx, |this, _| {
1108 this.shared_buffers.get(&buffer.clone()).cloned()
1109 })
1110 .log_err()
1111 .flatten()
1112 } else {
1113 None
1114 };
1115
1116 let snapshot = if let Some(snapshot) = snapshot {
1117 snapshot
1118 } else {
1119 action_log.update(cx, |action_log, cx| {
1120 action_log.buffer_read(buffer.clone(), cx);
1121 })?;
1122 project.update(cx, |project, cx| {
1123 let position = buffer
1124 .read(cx)
1125 .snapshot()
1126 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1127 project.set_agent_location(
1128 Some(AgentLocation {
1129 buffer: buffer.downgrade(),
1130 position,
1131 }),
1132 cx,
1133 );
1134 })?;
1135
1136 buffer.update(cx, |buffer, _| buffer.snapshot())?
1137 };
1138
1139 this.update(cx, |this, _| {
1140 let text = snapshot.text();
1141 this.shared_buffers.insert(buffer.clone(), snapshot);
1142 if line.is_none() && limit.is_none() {
1143 return Ok(text);
1144 }
1145 let limit = limit.unwrap_or(u32::MAX) as usize;
1146 let Some(line) = line else {
1147 return Ok(text.lines().take(limit).collect::<String>());
1148 };
1149
1150 let count = text.lines().count();
1151 if count < line as usize {
1152 anyhow::bail!("There are only {} lines", count);
1153 }
1154 Ok(text
1155 .lines()
1156 .skip(line as usize + 1)
1157 .take(limit)
1158 .collect::<String>())
1159 })?
1160 })
1161 }
1162
1163 pub fn write_text_file(
1164 &self,
1165 path: PathBuf,
1166 content: String,
1167 cx: &mut Context<Self>,
1168 ) -> Task<Result<()>> {
1169 let project = self.project.clone();
1170 let action_log = self.action_log.clone();
1171 cx.spawn(async move |this, cx| {
1172 let load = project.update(cx, |project, cx| {
1173 let path = project
1174 .project_path_for_absolute_path(&path, cx)
1175 .context("invalid path")?;
1176 anyhow::Ok(project.open_buffer(path, cx))
1177 });
1178 let buffer = load??.await?;
1179 let snapshot = this.update(cx, |this, cx| {
1180 this.shared_buffers
1181 .get(&buffer)
1182 .cloned()
1183 .unwrap_or_else(|| buffer.read(cx).snapshot())
1184 })?;
1185 let edits = cx
1186 .background_executor()
1187 .spawn(async move {
1188 let old_text = snapshot.text();
1189 text_diff(old_text.as_str(), &content)
1190 .into_iter()
1191 .map(|(range, replacement)| {
1192 (
1193 snapshot.anchor_after(range.start)
1194 ..snapshot.anchor_before(range.end),
1195 replacement,
1196 )
1197 })
1198 .collect::<Vec<_>>()
1199 })
1200 .await;
1201 cx.update(|cx| {
1202 project.update(cx, |project, cx| {
1203 project.set_agent_location(
1204 Some(AgentLocation {
1205 buffer: buffer.downgrade(),
1206 position: edits
1207 .last()
1208 .map(|(range, _)| range.end)
1209 .unwrap_or(Anchor::MIN),
1210 }),
1211 cx,
1212 );
1213 });
1214
1215 action_log.update(cx, |action_log, cx| {
1216 action_log.buffer_read(buffer.clone(), cx);
1217 });
1218 buffer.update(cx, |buffer, cx| {
1219 buffer.edit(edits, None, cx);
1220 });
1221 action_log.update(cx, |action_log, cx| {
1222 action_log.buffer_edited(buffer.clone(), cx);
1223 });
1224 })?;
1225 project
1226 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1227 .await
1228 })
1229 }
1230
1231 pub fn to_markdown(&self, cx: &App) -> String {
1232 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1233 }
1234
1235 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1236 cx.emit(AcpThreadEvent::ServerExited(status));
1237 }
1238}
1239
1240#[cfg(test)]
1241mod tests {
1242 use super::*;
1243 use anyhow::anyhow;
1244 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1245 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1246 use indoc::indoc;
1247 use project::FakeFs;
1248 use rand::Rng as _;
1249 use serde_json::json;
1250 use settings::SettingsStore;
1251 use smol::stream::StreamExt as _;
1252 use std::{cell::RefCell, rc::Rc, time::Duration};
1253
1254 use util::path;
1255
1256 fn init_test(cx: &mut TestAppContext) {
1257 env_logger::try_init().ok();
1258 cx.update(|cx| {
1259 let settings_store = SettingsStore::test(cx);
1260 cx.set_global(settings_store);
1261 Project::init_settings(cx);
1262 language::init(cx);
1263 });
1264 }
1265
1266 #[gpui::test]
1267 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1268 init_test(cx);
1269
1270 let fs = FakeFs::new(cx.executor());
1271 let project = Project::test(fs, [], cx).await;
1272 let connection = Rc::new(FakeAgentConnection::new());
1273 let thread = cx
1274 .spawn(async move |mut cx| {
1275 connection
1276 .new_thread(project, Path::new(path!("/test")), &mut cx)
1277 .await
1278 })
1279 .await
1280 .unwrap();
1281
1282 // Test creating a new user message
1283 thread.update(cx, |thread, cx| {
1284 thread.push_user_content_block(
1285 acp::ContentBlock::Text(acp::TextContent {
1286 annotations: None,
1287 text: "Hello, ".to_string(),
1288 }),
1289 cx,
1290 );
1291 });
1292
1293 thread.update(cx, |thread, cx| {
1294 assert_eq!(thread.entries.len(), 1);
1295 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1296 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1297 } else {
1298 panic!("Expected UserMessage");
1299 }
1300 });
1301
1302 // Test appending to existing user message
1303 thread.update(cx, |thread, cx| {
1304 thread.push_user_content_block(
1305 acp::ContentBlock::Text(acp::TextContent {
1306 annotations: None,
1307 text: "world!".to_string(),
1308 }),
1309 cx,
1310 );
1311 });
1312
1313 thread.update(cx, |thread, cx| {
1314 assert_eq!(thread.entries.len(), 1);
1315 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1316 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1317 } else {
1318 panic!("Expected UserMessage");
1319 }
1320 });
1321
1322 // Test creating new user message after assistant message
1323 thread.update(cx, |thread, cx| {
1324 thread.push_assistant_content_block(
1325 acp::ContentBlock::Text(acp::TextContent {
1326 annotations: None,
1327 text: "Assistant response".to_string(),
1328 }),
1329 false,
1330 cx,
1331 );
1332 });
1333
1334 thread.update(cx, |thread, cx| {
1335 thread.push_user_content_block(
1336 acp::ContentBlock::Text(acp::TextContent {
1337 annotations: None,
1338 text: "New user message".to_string(),
1339 }),
1340 cx,
1341 );
1342 });
1343
1344 thread.update(cx, |thread, cx| {
1345 assert_eq!(thread.entries.len(), 3);
1346 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1347 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1348 } else {
1349 panic!("Expected UserMessage at index 2");
1350 }
1351 });
1352 }
1353
1354 #[gpui::test]
1355 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1356 init_test(cx);
1357
1358 let fs = FakeFs::new(cx.executor());
1359 let project = Project::test(fs, [], cx).await;
1360 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1361 |_, thread, mut cx| {
1362 async move {
1363 thread.update(&mut cx, |thread, cx| {
1364 thread
1365 .handle_session_update(
1366 acp::SessionUpdate::AgentThoughtChunk {
1367 content: "Thinking ".into(),
1368 },
1369 cx,
1370 )
1371 .unwrap();
1372 thread
1373 .handle_session_update(
1374 acp::SessionUpdate::AgentThoughtChunk {
1375 content: "hard!".into(),
1376 },
1377 cx,
1378 )
1379 .unwrap();
1380 })
1381 }
1382 .boxed_local()
1383 },
1384 ));
1385
1386 let thread = cx
1387 .spawn(async move |mut cx| {
1388 connection
1389 .new_thread(project, Path::new(path!("/test")), &mut cx)
1390 .await
1391 })
1392 .await
1393 .unwrap();
1394
1395 thread
1396 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1397 .await
1398 .unwrap();
1399
1400 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1401 assert_eq!(
1402 output,
1403 indoc! {r#"
1404 ## User
1405
1406 Hello from Zed!
1407
1408 ## Assistant
1409
1410 <thinking>
1411 Thinking hard!
1412 </thinking>
1413
1414 "#}
1415 );
1416 }
1417
    /// The agent reads `/tmp/foo`, the user then edits the open buffer, and only
    /// afterwards does the agent write its new contents. The test expects the agent's
    /// write to be merged with the user's edit (not to overwrite it), both in the
    /// buffer and in the saved file.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent has finished reading the file, so the
        // user edit below lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is an `Fn`, so the one-shot sender is moved out of a shared cell.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // The agent writes two extra lines based on what it read, unaware
                    // of the user's edit made in the meantime.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            },
        ));

        // Open the file as a buffer up front, as the user's editor would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // User edit made after the agent's read but before its write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's "zero" line and the agent's "four"/"five" lines must survive.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // The merged text must also have been saved to disk.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1494
1495 #[gpui::test]
1496 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1497 init_test(cx);
1498
1499 let fs = FakeFs::new(cx.executor());
1500 let project = Project::test(fs, [], cx).await;
1501 let id = acp::ToolCallId("test".into());
1502
1503 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1504 let id = id.clone();
1505 move |_, thread, mut cx| {
1506 let id = id.clone();
1507 async move {
1508 thread
1509 .update(&mut cx, |thread, cx| {
1510 thread.handle_session_update(
1511 acp::SessionUpdate::ToolCall(acp::ToolCall {
1512 id: id.clone(),
1513 title: "Label".into(),
1514 kind: acp::ToolKind::Fetch,
1515 status: acp::ToolCallStatus::InProgress,
1516 content: vec![],
1517 locations: vec![],
1518 raw_input: None,
1519 }),
1520 cx,
1521 )
1522 })
1523 .unwrap()
1524 .unwrap();
1525 Ok(())
1526 }
1527 .boxed_local()
1528 }
1529 }));
1530
1531 let thread = cx
1532 .spawn(async move |mut cx| {
1533 connection
1534 .new_thread(project, Path::new(path!("/test")), &mut cx)
1535 .await
1536 })
1537 .await
1538 .unwrap();
1539
1540 let request = thread.update(cx, |thread, cx| {
1541 thread.send_raw("Fetch https://example.com", cx)
1542 });
1543
1544 run_until_first_tool_call(&thread, cx).await;
1545
1546 thread.read_with(cx, |thread, _| {
1547 assert!(matches!(
1548 thread.entries[1],
1549 AgentThreadEntry::ToolCall(ToolCall {
1550 status: ToolCallStatus::Allowed {
1551 status: acp::ToolCallStatus::InProgress,
1552 ..
1553 },
1554 ..
1555 })
1556 ));
1557 });
1558
1559 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1560
1561 thread.read_with(cx, |thread, _| {
1562 assert!(matches!(
1563 &thread.entries[1],
1564 AgentThreadEntry::ToolCall(ToolCall {
1565 status: ToolCallStatus::Canceled,
1566 ..
1567 })
1568 ));
1569 });
1570
1571 thread
1572 .update(cx, |thread, cx| {
1573 thread.handle_session_update(
1574 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1575 id,
1576 fields: acp::ToolCallUpdateFields {
1577 status: Some(acp::ToolCallStatus::Completed),
1578 ..Default::default()
1579 },
1580 }),
1581 cx,
1582 )
1583 })
1584 .unwrap();
1585
1586 request.await.unwrap();
1587
1588 thread.read_with(cx, |thread, _| {
1589 assert!(matches!(
1590 thread.entries[1],
1591 AgentThreadEntry::ToolCall(ToolCall {
1592 status: ToolCallStatus::Allowed {
1593 status: acp::ToolCallStatus::Completed,
1594 ..
1595 },
1596 ..
1597 })
1598 ));
1599 });
1600 }
1601
1602 #[gpui::test]
1603 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1604 init_test(cx);
1605 let fs = FakeFs::new(cx.background_executor.clone());
1606 fs.insert_tree(path!("/test"), json!({})).await;
1607 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1608
1609 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1610 move |_, thread, mut cx| {
1611 async move {
1612 thread
1613 .update(&mut cx, |thread, cx| {
1614 thread.handle_session_update(
1615 acp::SessionUpdate::ToolCall(acp::ToolCall {
1616 id: acp::ToolCallId("test".into()),
1617 title: "Label".into(),
1618 kind: acp::ToolKind::Edit,
1619 status: acp::ToolCallStatus::Completed,
1620 content: vec![acp::ToolCallContent::Diff {
1621 diff: acp::Diff {
1622 path: "/test/test.txt".into(),
1623 old_text: None,
1624 new_text: "foo".into(),
1625 },
1626 }],
1627 locations: vec![],
1628 raw_input: None,
1629 }),
1630 cx,
1631 )
1632 })
1633 .unwrap()
1634 .unwrap();
1635 Ok(())
1636 }
1637 .boxed_local()
1638 }
1639 }));
1640
1641 let thread = connection
1642 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1643 .await
1644 .unwrap();
1645 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1646 .await
1647 .unwrap();
1648
1649 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1650 }
1651
1652 async fn run_until_first_tool_call(
1653 thread: &Entity<AcpThread>,
1654 cx: &mut TestAppContext,
1655 ) -> usize {
1656 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1657
1658 let subscription = cx.update(|cx| {
1659 cx.subscribe(thread, move |thread, _, cx| {
1660 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1661 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1662 return tx.try_send(ix).unwrap();
1663 }
1664 }
1665 })
1666 });
1667
1668 select! {
1669 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1670 panic!("Timeout waiting for tool call")
1671 }
1672 ix = rx.next().fuse() => {
1673 drop(subscription);
1674 ix.unwrap()
1675 }
1676 }
1677 }
1678
1679 #[derive(Clone, Default)]
1680 struct FakeAgentConnection {
1681 auth_methods: Vec<acp::AuthMethod>,
1682 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1683 on_user_message: Option<
1684 Rc<
1685 dyn Fn(
1686 acp::PromptRequest,
1687 WeakEntity<AcpThread>,
1688 AsyncApp,
1689 ) -> LocalBoxFuture<'static, Result<()>>
1690 + 'static,
1691 >,
1692 >,
1693 }
1694
1695 impl FakeAgentConnection {
1696 fn new() -> Self {
1697 Self {
1698 auth_methods: Vec::new(),
1699 on_user_message: None,
1700 sessions: Arc::default(),
1701 }
1702 }
1703
1704 #[expect(unused)]
1705 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1706 self.auth_methods = auth_methods;
1707 self
1708 }
1709
1710 fn on_user_message(
1711 mut self,
1712 handler: impl Fn(
1713 acp::PromptRequest,
1714 WeakEntity<AcpThread>,
1715 AsyncApp,
1716 ) -> LocalBoxFuture<'static, Result<()>>
1717 + 'static,
1718 ) -> Self {
1719 self.on_user_message.replace(Rc::new(handler));
1720 self
1721 }
1722 }
1723
1724 impl AgentConnection for FakeAgentConnection {
1725 fn auth_methods(&self) -> &[acp::AuthMethod] {
1726 &self.auth_methods
1727 }
1728
1729 fn new_thread(
1730 self: Rc<Self>,
1731 project: Entity<Project>,
1732 _cwd: &Path,
1733 cx: &mut gpui::AsyncApp,
1734 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1735 let session_id = acp::SessionId(
1736 rand::thread_rng()
1737 .sample_iter(&rand::distributions::Alphanumeric)
1738 .take(7)
1739 .map(char::from)
1740 .collect::<String>()
1741 .into(),
1742 );
1743 let thread = cx
1744 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1745 .unwrap();
1746 self.sessions.lock().insert(session_id, thread.downgrade());
1747 Task::ready(Ok(thread))
1748 }
1749
1750 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1751 if self.auth_methods().iter().any(|m| m.id == method) {
1752 Task::ready(Ok(()))
1753 } else {
1754 Task::ready(Err(anyhow!("Invalid Auth Method")))
1755 }
1756 }
1757
1758 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
1759 let sessions = self.sessions.lock();
1760 let thread = sessions.get(¶ms.session_id).unwrap();
1761 if let Some(handler) = &self.on_user_message {
1762 let handler = handler.clone();
1763 let thread = thread.clone();
1764 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1765 } else {
1766 Task::ready(Ok(()))
1767 }
1768 }
1769
1770 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1771 let sessions = self.sessions.lock();
1772 let thread = sessions.get(&session_id).unwrap().clone();
1773
1774 cx.spawn(async move |cx| {
1775 thread
1776 .update(cx, |thread, cx| thread.cancel(cx))
1777 .unwrap()
1778 .await
1779 })
1780 .detach();
1781 }
1782 }
1783}