1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::process::ExitStatus;
22use std::rc::Rc;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
52 }
53}
54
/// A file mention embedded in message text as a markdown link of the form
/// `[@<file name>](@file:<path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Link-target prefix that marks a URL as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a mention link target. Returns `None` when `url` does not begin with
    /// [`Self::PREFIX`]; otherwise everything after the prefix is taken as the path.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let full_path = self.path();
        // A path without a final component (e.g. `/`) yields an empty link label.
        let label = full_path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            label.display(),
            Self::PREFIX,
            full_path.display()
        )
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(chunk.into(), language_registry, cx),
114 }
115 }
116
117 fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::Message { block } => block.to_markdown(cx).to_string(),
120 Self::Thought { block } => {
121 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
128pub enum AgentThreadEntry {
129 UserMessage(UserMessage),
130 AssistantMessage(AssistantMessage),
131 ToolCall(ToolCall),
132}
133
134impl AgentThreadEntry {
135 fn to_markdown(&self, cx: &App) -> String {
136 match self {
137 Self::UserMessage(message) => message.to_markdown(cx),
138 Self::AssistantMessage(message) => message.to_markdown(cx),
139 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
140 }
141 }
142
143 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.diffs())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
152 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
153 Some(locations)
154 } else {
155 None
156 }
157 }
158}
159
160#[derive(Debug)]
161pub struct ToolCall {
162 pub id: acp::ToolCallId,
163 pub label: Entity<Markdown>,
164 pub kind: acp::ToolKind,
165 pub content: Vec<ToolCallContent>,
166 pub status: ToolCallStatus,
167 pub locations: Vec<acp::ToolCallLocation>,
168 pub raw_input: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 }
198 }
199
200 fn update(
201 &mut self,
202 fields: acp::ToolCallUpdateFields,
203 language_registry: Arc<LanguageRegistry>,
204 cx: &mut App,
205 ) {
206 let acp::ToolCallUpdateFields {
207 kind,
208 status,
209 title,
210 content,
211 locations,
212 raw_input,
213 } = fields;
214
215 if let Some(kind) = kind {
216 self.kind = kind;
217 }
218
219 if let Some(status) = status {
220 self.status = ToolCallStatus::Allowed { status };
221 }
222
223 if let Some(title) = title {
224 self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
225 }
226
227 if let Some(content) = content {
228 self.content = content
229 .into_iter()
230 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
231 .collect();
232 }
233
234 if let Some(locations) = locations {
235 self.locations = locations;
236 }
237
238 if let Some(raw_input) = raw_input {
239 self.raw_input = Some(raw_input);
240 }
241 }
242
243 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::ContentBlock { .. } => None,
246 ToolCallContent::Diff { diff } => Some(diff),
247 })
248 }
249
250 fn to_markdown(&self, cx: &App) -> String {
251 let mut markdown = format!(
252 "**Tool Call: {}**\nStatus: {}\n\n",
253 self.label.read(cx).source(),
254 self.status
255 );
256 for content in &self.content {
257 markdown.push_str(content.to_markdown(cx).as_str());
258 markdown.push_str("\n\n");
259 }
260 markdown
261 }
262}
263
264#[derive(Debug)]
265pub enum ToolCallStatus {
266 WaitingForConfirmation {
267 options: Vec<acp::PermissionOption>,
268 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
269 },
270 Allowed {
271 status: acp::ToolCallStatus,
272 },
273 Rejected,
274 Canceled,
275}
276
277impl Display for ToolCallStatus {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 write!(
280 f,
281 "{}",
282 match self {
283 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
284 ToolCallStatus::Allowed { status } => match status {
285 acp::ToolCallStatus::Pending => "Pending",
286 acp::ToolCallStatus::InProgress => "In Progress",
287 acp::ToolCallStatus::Completed => "Completed",
288 acp::ToolCallStatus::Failed => "Failed",
289 },
290 ToolCallStatus::Rejected => "Rejected",
291 ToolCallStatus::Canceled => "Canceled",
292 }
293 )
294 }
295}
296
297#[derive(Debug, PartialEq, Clone)]
298pub enum ContentBlock {
299 Empty,
300 Markdown { markdown: Entity<Markdown> },
301}
302
303impl ContentBlock {
304 pub fn new(
305 block: acp::ContentBlock,
306 language_registry: &Arc<LanguageRegistry>,
307 cx: &mut App,
308 ) -> Self {
309 let mut this = Self::Empty;
310 this.append(block, language_registry, cx);
311 this
312 }
313
314 pub fn new_combined(
315 blocks: impl IntoIterator<Item = acp::ContentBlock>,
316 language_registry: Arc<LanguageRegistry>,
317 cx: &mut App,
318 ) -> Self {
319 let mut this = Self::Empty;
320 for block in blocks {
321 this.append(block, &language_registry, cx);
322 }
323 this
324 }
325
326 pub fn append(
327 &mut self,
328 block: acp::ContentBlock,
329 language_registry: &Arc<LanguageRegistry>,
330 cx: &mut App,
331 ) {
332 let new_content = match block {
333 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
334 acp::ContentBlock::ResourceLink(resource_link) => {
335 if let Some(path) = resource_link.uri.strip_prefix("file://") {
336 format!("{}", MentionPath(path.as_ref()))
337 } else {
338 resource_link.uri.clone()
339 }
340 }
341 acp::ContentBlock::Image(_)
342 | acp::ContentBlock::Audio(_)
343 | acp::ContentBlock::Resource(_) => String::new(),
344 };
345
346 match self {
347 ContentBlock::Empty => {
348 *self = ContentBlock::Markdown {
349 markdown: cx.new(|cx| {
350 Markdown::new(
351 new_content.into(),
352 Some(language_registry.clone()),
353 None,
354 cx,
355 )
356 }),
357 };
358 }
359 ContentBlock::Markdown { markdown } => {
360 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
361 }
362 }
363 }
364
365 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
366 match self {
367 ContentBlock::Empty => "",
368 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
369 }
370 }
371
372 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
373 match self {
374 ContentBlock::Empty => None,
375 ContentBlock::Markdown { markdown } => Some(markdown),
376 }
377 }
378}
379
380#[derive(Debug)]
381pub enum ToolCallContent {
382 ContentBlock { content: ContentBlock },
383 Diff { diff: Diff },
384}
385
386impl ToolCallContent {
387 pub fn from_acp(
388 content: acp::ToolCallContent,
389 language_registry: Arc<LanguageRegistry>,
390 cx: &mut App,
391 ) -> Self {
392 match content {
393 acp::ToolCallContent::Content { content } => Self::ContentBlock {
394 content: ContentBlock::new(content, &language_registry, cx),
395 },
396 acp::ToolCallContent::Diff { diff } => Self::Diff {
397 diff: Diff::from_acp(diff, language_registry, cx),
398 },
399 }
400 }
401
402 pub fn to_markdown(&self, cx: &App) -> String {
403 match self {
404 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
405 Self::Diff { diff } => diff.to_markdown(cx),
406 }
407 }
408}
409
410#[derive(Debug)]
411pub struct Diff {
412 pub multibuffer: Entity<MultiBuffer>,
413 pub path: PathBuf,
414 pub new_buffer: Entity<Buffer>,
415 pub old_buffer: Entity<Buffer>,
416 _task: Task<Result<()>>,
417}
418
419impl Diff {
420 pub fn from_acp(
421 diff: acp::Diff,
422 language_registry: Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let acp::Diff {
426 path,
427 old_text,
428 new_text,
429 } = diff;
430
431 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
432
433 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
434 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
435 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
436 let old_buffer_snapshot = old_buffer.read(cx).snapshot();
437 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
438 let diff_task = buffer_diff.update(cx, |diff, cx| {
439 diff.set_base_text(
440 old_buffer_snapshot,
441 Some(language_registry.clone()),
442 new_buffer_snapshot,
443 cx,
444 )
445 });
446
447 let task = cx.spawn({
448 let multibuffer = multibuffer.clone();
449 let path = path.clone();
450 let new_buffer = new_buffer.clone();
451 async move |cx| {
452 diff_task.await?;
453
454 multibuffer
455 .update(cx, |multibuffer, cx| {
456 let hunk_ranges = {
457 let buffer = new_buffer.read(cx);
458 let diff = buffer_diff.read(cx);
459 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
460 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
461 .collect::<Vec<_>>()
462 };
463
464 multibuffer.set_excerpts_for_path(
465 PathKey::for_buffer(&new_buffer, cx),
466 new_buffer.clone(),
467 hunk_ranges,
468 editor::DEFAULT_MULTIBUFFER_CONTEXT,
469 cx,
470 );
471 multibuffer.add_diff(buffer_diff.clone(), cx);
472 })
473 .log_err();
474
475 if let Some(language) = language_registry
476 .language_for_file_path(&path)
477 .await
478 .log_err()
479 {
480 new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
481 }
482
483 anyhow::Ok(())
484 }
485 });
486
487 Self {
488 multibuffer,
489 path,
490 new_buffer,
491 old_buffer,
492 _task: task,
493 }
494 }
495
496 fn to_markdown(&self, cx: &App) -> String {
497 let buffer_text = self
498 .multibuffer
499 .read(cx)
500 .all_buffers()
501 .iter()
502 .map(|buffer| buffer.read(cx).text())
503 .join("\n");
504 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
505 }
506}
507
508#[derive(Debug, Default)]
509pub struct Plan {
510 pub entries: Vec<PlanEntry>,
511}
512
513#[derive(Debug)]
514pub struct PlanStats<'a> {
515 pub in_progress_entry: Option<&'a PlanEntry>,
516 pub pending: u32,
517 pub completed: u32,
518}
519
520impl Plan {
521 pub fn is_empty(&self) -> bool {
522 self.entries.is_empty()
523 }
524
525 pub fn stats(&self) -> PlanStats<'_> {
526 let mut stats = PlanStats {
527 in_progress_entry: None,
528 pending: 0,
529 completed: 0,
530 };
531
532 for entry in &self.entries {
533 match &entry.status {
534 acp::PlanEntryStatus::Pending => {
535 stats.pending += 1;
536 }
537 acp::PlanEntryStatus::InProgress => {
538 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
539 }
540 acp::PlanEntryStatus::Completed => {
541 stats.completed += 1;
542 }
543 }
544 }
545
546 stats
547 }
548}
549
550#[derive(Debug)]
551pub struct PlanEntry {
552 pub content: Entity<Markdown>,
553 pub priority: acp::PlanEntryPriority,
554 pub status: acp::PlanEntryStatus,
555}
556
557impl PlanEntry {
558 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
559 Self {
560 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
561 priority: entry.priority,
562 status: entry.status,
563 }
564 }
565}
566
567pub struct AcpThread {
568 title: SharedString,
569 entries: Vec<AgentThreadEntry>,
570 plan: Plan,
571 project: Entity<Project>,
572 action_log: Entity<ActionLog>,
573 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
574 send_task: Option<Task<()>>,
575 connection: Rc<dyn AgentConnection>,
576 session_id: acp::SessionId,
577}
578
579pub enum AcpThreadEvent {
580 NewEntry,
581 EntryUpdated(usize),
582 ToolAuthorizationRequired,
583 Stopped,
584 Error,
585 ServerExited(ExitStatus),
586}
587
588impl EventEmitter<AcpThreadEvent> for AcpThread {}
589
590#[derive(PartialEq, Eq)]
591pub enum ThreadStatus {
592 Idle,
593 WaitingForToolConfirmation,
594 Generating,
595}
596
597#[derive(Debug, Clone)]
598pub enum LoadError {
599 Unsupported {
600 error_message: SharedString,
601 upgrade_message: SharedString,
602 upgrade_command: String,
603 },
604 Exited(i32),
605 Other(SharedString),
606}
607
608impl Display for LoadError {
609 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
610 match self {
611 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
612 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
613 LoadError::Other(msg) => write!(f, "{}", msg),
614 }
615 }
616}
617
618impl Error for LoadError {}
619
620impl AcpThread {
621 pub fn new(
622 title: impl Into<SharedString>,
623 connection: Rc<dyn AgentConnection>,
624 project: Entity<Project>,
625 session_id: acp::SessionId,
626 cx: &mut Context<Self>,
627 ) -> Self {
628 let action_log = cx.new(|_| ActionLog::new(project.clone()));
629
630 Self {
631 action_log,
632 shared_buffers: Default::default(),
633 entries: Default::default(),
634 plan: Default::default(),
635 title: title.into(),
636 project,
637 send_task: None,
638 connection,
639 session_id,
640 }
641 }
642
643 pub fn action_log(&self) -> &Entity<ActionLog> {
644 &self.action_log
645 }
646
647 pub fn project(&self) -> &Entity<Project> {
648 &self.project
649 }
650
651 pub fn title(&self) -> SharedString {
652 self.title.clone()
653 }
654
655 pub fn entries(&self) -> &[AgentThreadEntry] {
656 &self.entries
657 }
658
659 pub fn session_id(&self) -> &acp::SessionId {
660 &self.session_id
661 }
662
663 pub fn status(&self) -> ThreadStatus {
664 if self.send_task.is_some() {
665 if self.waiting_for_tool_confirmation() {
666 ThreadStatus::WaitingForToolConfirmation
667 } else {
668 ThreadStatus::Generating
669 }
670 } else {
671 ThreadStatus::Idle
672 }
673 }
674
675 pub fn has_pending_edit_tool_calls(&self) -> bool {
676 for entry in self.entries.iter().rev() {
677 match entry {
678 AgentThreadEntry::UserMessage(_) => return false,
679 AgentThreadEntry::ToolCall(
680 call @ ToolCall {
681 status:
682 ToolCallStatus::Allowed {
683 status:
684 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
685 },
686 ..
687 },
688 ) if call.diffs().next().is_some() => {
689 return true;
690 }
691 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
692 }
693 }
694
695 false
696 }
697
698 pub fn used_tools_since_last_user_message(&self) -> bool {
699 for entry in self.entries.iter().rev() {
700 match entry {
701 AgentThreadEntry::UserMessage(..) => return false,
702 AgentThreadEntry::AssistantMessage(..) => continue,
703 AgentThreadEntry::ToolCall(..) => return true,
704 }
705 }
706
707 false
708 }
709
710 pub fn handle_session_update(
711 &mut self,
712 update: acp::SessionUpdate,
713 cx: &mut Context<Self>,
714 ) -> Result<()> {
715 match update {
716 acp::SessionUpdate::UserMessageChunk { content } => {
717 self.push_user_content_block(content, cx);
718 }
719 acp::SessionUpdate::AgentMessageChunk { content } => {
720 self.push_assistant_content_block(content, false, cx);
721 }
722 acp::SessionUpdate::AgentThoughtChunk { content } => {
723 self.push_assistant_content_block(content, true, cx);
724 }
725 acp::SessionUpdate::ToolCall(tool_call) => {
726 self.upsert_tool_call(tool_call, cx);
727 }
728 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
729 self.update_tool_call(tool_call_update, cx)?;
730 }
731 acp::SessionUpdate::Plan(plan) => {
732 self.update_plan(plan, cx);
733 }
734 }
735 Ok(())
736 }
737
738 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
739 let language_registry = self.project.read(cx).languages().clone();
740 let entries_len = self.entries.len();
741
742 if let Some(last_entry) = self.entries.last_mut()
743 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
744 {
745 content.append(chunk, &language_registry, cx);
746 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
747 } else {
748 let content = ContentBlock::new(chunk, &language_registry, cx);
749 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
750 }
751 }
752
753 pub fn push_assistant_content_block(
754 &mut self,
755 chunk: acp::ContentBlock,
756 is_thought: bool,
757 cx: &mut Context<Self>,
758 ) {
759 let language_registry = self.project.read(cx).languages().clone();
760 let entries_len = self.entries.len();
761 if let Some(last_entry) = self.entries.last_mut()
762 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
763 {
764 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
765 match (chunks.last_mut(), is_thought) {
766 (Some(AssistantMessageChunk::Message { block }), false)
767 | (Some(AssistantMessageChunk::Thought { block }), true) => {
768 block.append(chunk, &language_registry, cx)
769 }
770 _ => {
771 let block = ContentBlock::new(chunk, &language_registry, cx);
772 if is_thought {
773 chunks.push(AssistantMessageChunk::Thought { block })
774 } else {
775 chunks.push(AssistantMessageChunk::Message { block })
776 }
777 }
778 }
779 } else {
780 let block = ContentBlock::new(chunk, &language_registry, cx);
781 let chunk = if is_thought {
782 AssistantMessageChunk::Thought { block }
783 } else {
784 AssistantMessageChunk::Message { block }
785 };
786
787 self.push_entry(
788 AgentThreadEntry::AssistantMessage(AssistantMessage {
789 chunks: vec![chunk],
790 }),
791 cx,
792 );
793 }
794 }
795
796 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
797 self.entries.push(entry);
798 cx.emit(AcpThreadEvent::NewEntry);
799 }
800
801 pub fn update_tool_call(
802 &mut self,
803 update: acp::ToolCallUpdate,
804 cx: &mut Context<Self>,
805 ) -> Result<()> {
806 let languages = self.project.read(cx).languages().clone();
807
808 let (ix, current_call) = self
809 .tool_call_mut(&update.id)
810 .context("Tool call not found")?;
811 current_call.update(update.fields, languages, cx);
812
813 cx.emit(AcpThreadEvent::EntryUpdated(ix));
814
815 Ok(())
816 }
817
818 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
819 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
820 let status = ToolCallStatus::Allowed {
821 status: tool_call.status,
822 };
823 self.upsert_tool_call_inner(tool_call, status, cx)
824 }
825
826 pub fn upsert_tool_call_inner(
827 &mut self,
828 tool_call: acp::ToolCall,
829 status: ToolCallStatus,
830 cx: &mut Context<Self>,
831 ) {
832 let language_registry = self.project.read(cx).languages().clone();
833 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
834
835 let location = call.locations.last().cloned();
836
837 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
838 *current_call = call;
839
840 cx.emit(AcpThreadEvent::EntryUpdated(ix));
841 } else {
842 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
843 }
844
845 if let Some(location) = location {
846 self.set_project_location(location, cx)
847 }
848 }
849
850 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
851 // The tool call we are looking for is typically the last one, or very close to the end.
852 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
853 self.entries
854 .iter_mut()
855 .enumerate()
856 .rev()
857 .find_map(|(index, tool_call)| {
858 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
859 && &tool_call.id == id
860 {
861 Some((index, tool_call))
862 } else {
863 None
864 }
865 })
866 }
867
868 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
869 self.project.update(cx, |project, cx| {
870 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
871 return;
872 };
873 let buffer = project.open_buffer(path, cx);
874 cx.spawn(async move |project, cx| {
875 let buffer = buffer.await?;
876
877 project.update(cx, |project, cx| {
878 let position = if let Some(line) = location.line {
879 let snapshot = buffer.read(cx).snapshot();
880 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
881 snapshot.anchor_before(point)
882 } else {
883 Anchor::MIN
884 };
885
886 project.set_agent_location(
887 Some(AgentLocation {
888 buffer: buffer.downgrade(),
889 position,
890 }),
891 cx,
892 );
893 })
894 })
895 .detach_and_log_err(cx);
896 });
897 }
898
899 pub fn request_tool_call_permission(
900 &mut self,
901 tool_call: acp::ToolCall,
902 options: Vec<acp::PermissionOption>,
903 cx: &mut Context<Self>,
904 ) -> oneshot::Receiver<acp::PermissionOptionId> {
905 let (tx, rx) = oneshot::channel();
906
907 let status = ToolCallStatus::WaitingForConfirmation {
908 options,
909 respond_tx: tx,
910 };
911
912 self.upsert_tool_call_inner(tool_call, status, cx);
913 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
914 rx
915 }
916
917 pub fn authorize_tool_call(
918 &mut self,
919 id: acp::ToolCallId,
920 option_id: acp::PermissionOptionId,
921 option_kind: acp::PermissionOptionKind,
922 cx: &mut Context<Self>,
923 ) {
924 let Some((ix, call)) = self.tool_call_mut(&id) else {
925 return;
926 };
927
928 let new_status = match option_kind {
929 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
930 ToolCallStatus::Rejected
931 }
932 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
933 ToolCallStatus::Allowed {
934 status: acp::ToolCallStatus::InProgress,
935 }
936 }
937 };
938
939 let curr_status = mem::replace(&mut call.status, new_status);
940
941 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
942 respond_tx.send(option_id).log_err();
943 } else if cfg!(debug_assertions) {
944 panic!("tried to authorize an already authorized tool call");
945 }
946
947 cx.emit(AcpThreadEvent::EntryUpdated(ix));
948 }
949
950 /// Returns true if the last turn is awaiting tool authorization
951 pub fn waiting_for_tool_confirmation(&self) -> bool {
952 for entry in self.entries.iter().rev() {
953 match &entry {
954 AgentThreadEntry::ToolCall(call) => match call.status {
955 ToolCallStatus::WaitingForConfirmation { .. } => return true,
956 ToolCallStatus::Allowed { .. }
957 | ToolCallStatus::Rejected
958 | ToolCallStatus::Canceled => continue,
959 },
960 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
961 // Reached the beginning of the turn
962 return false;
963 }
964 }
965 }
966 false
967 }
968
969 pub fn plan(&self) -> &Plan {
970 &self.plan
971 }
972
973 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
974 self.plan = Plan {
975 entries: request
976 .entries
977 .into_iter()
978 .map(|entry| PlanEntry::from_acp(entry, cx))
979 .collect(),
980 };
981
982 cx.notify();
983 }
984
985 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
986 self.plan
987 .entries
988 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
989 cx.notify();
990 }
991
992 #[cfg(any(test, feature = "test-support"))]
993 pub fn send_raw(
994 &mut self,
995 message: &str,
996 cx: &mut Context<Self>,
997 ) -> BoxFuture<'static, Result<()>> {
998 self.send(
999 vec![acp::ContentBlock::Text(acp::TextContent {
1000 text: message.to_string(),
1001 annotations: None,
1002 })],
1003 cx,
1004 )
1005 }
1006
1007 pub fn send(
1008 &mut self,
1009 message: Vec<acp::ContentBlock>,
1010 cx: &mut Context<Self>,
1011 ) -> BoxFuture<'static, Result<()>> {
1012 let block = ContentBlock::new_combined(
1013 message.clone(),
1014 self.project.read(cx).languages().clone(),
1015 cx,
1016 );
1017 self.push_entry(
1018 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1019 cx,
1020 );
1021 self.clear_completed_plan_entries(cx);
1022
1023 let (tx, rx) = oneshot::channel();
1024 let cancel_task = self.cancel(cx);
1025
1026 self.send_task = Some(cx.spawn(async move |this, cx| {
1027 async {
1028 cancel_task.await;
1029
1030 let result = this
1031 .update(cx, |this, cx| {
1032 this.connection.prompt(
1033 acp::PromptRequest {
1034 prompt: message,
1035 session_id: this.session_id.clone(),
1036 },
1037 cx,
1038 )
1039 })?
1040 .await;
1041 tx.send(result).log_err();
1042 this.update(cx, |this, _cx| this.send_task.take())?;
1043 anyhow::Ok(())
1044 }
1045 .await
1046 .log_err();
1047 }));
1048
1049 cx.spawn(async move |this, cx| match rx.await {
1050 Ok(Err(e)) => {
1051 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1052 .log_err();
1053 Err(e)?
1054 }
1055 _ => {
1056 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1057 .log_err();
1058 Ok(())
1059 }
1060 })
1061 .boxed()
1062 }
1063
1064 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1065 let Some(send_task) = self.send_task.take() else {
1066 return Task::ready(());
1067 };
1068
1069 for entry in self.entries.iter_mut() {
1070 if let AgentThreadEntry::ToolCall(call) = entry {
1071 let cancel = matches!(
1072 call.status,
1073 ToolCallStatus::WaitingForConfirmation { .. }
1074 | ToolCallStatus::Allowed {
1075 status: acp::ToolCallStatus::InProgress
1076 }
1077 );
1078
1079 if cancel {
1080 call.status = ToolCallStatus::Canceled;
1081 }
1082 }
1083 }
1084
1085 self.connection.cancel(&self.session_id, cx);
1086
1087 // Wait for the send task to complete
1088 cx.foreground_executor().spawn(send_task)
1089 }
1090
1091 pub fn read_text_file(
1092 &self,
1093 path: PathBuf,
1094 line: Option<u32>,
1095 limit: Option<u32>,
1096 reuse_shared_snapshot: bool,
1097 cx: &mut Context<Self>,
1098 ) -> Task<Result<String>> {
1099 let project = self.project.clone();
1100 let action_log = self.action_log.clone();
1101 cx.spawn(async move |this, cx| {
1102 let load = project.update(cx, |project, cx| {
1103 let path = project
1104 .project_path_for_absolute_path(&path, cx)
1105 .context("invalid path")?;
1106 anyhow::Ok(project.open_buffer(path, cx))
1107 });
1108 let buffer = load??.await?;
1109
1110 let snapshot = if reuse_shared_snapshot {
1111 this.read_with(cx, |this, _| {
1112 this.shared_buffers.get(&buffer.clone()).cloned()
1113 })
1114 .log_err()
1115 .flatten()
1116 } else {
1117 None
1118 };
1119
1120 let snapshot = if let Some(snapshot) = snapshot {
1121 snapshot
1122 } else {
1123 action_log.update(cx, |action_log, cx| {
1124 action_log.buffer_read(buffer.clone(), cx);
1125 })?;
1126 project.update(cx, |project, cx| {
1127 let position = buffer
1128 .read(cx)
1129 .snapshot()
1130 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1131 project.set_agent_location(
1132 Some(AgentLocation {
1133 buffer: buffer.downgrade(),
1134 position,
1135 }),
1136 cx,
1137 );
1138 })?;
1139
1140 buffer.update(cx, |buffer, _| buffer.snapshot())?
1141 };
1142
1143 this.update(cx, |this, _| {
1144 let text = snapshot.text();
1145 this.shared_buffers.insert(buffer.clone(), snapshot);
1146 if line.is_none() && limit.is_none() {
1147 return Ok(text);
1148 }
1149 let limit = limit.unwrap_or(u32::MAX) as usize;
1150 let Some(line) = line else {
1151 return Ok(text.lines().take(limit).collect::<String>());
1152 };
1153
1154 let count = text.lines().count();
1155 if count < line as usize {
1156 anyhow::bail!("There are only {} lines", count);
1157 }
1158 Ok(text
1159 .lines()
1160 .skip(line as usize + 1)
1161 .take(limit)
1162 .collect::<String>())
1163 })?
1164 })
1165 }
1166
1167 pub fn write_text_file(
1168 &self,
1169 path: PathBuf,
1170 content: String,
1171 cx: &mut Context<Self>,
1172 ) -> Task<Result<()>> {
1173 let project = self.project.clone();
1174 let action_log = self.action_log.clone();
1175 cx.spawn(async move |this, cx| {
1176 let load = project.update(cx, |project, cx| {
1177 let path = project
1178 .project_path_for_absolute_path(&path, cx)
1179 .context("invalid path")?;
1180 anyhow::Ok(project.open_buffer(path, cx))
1181 });
1182 let buffer = load??.await?;
1183 let snapshot = this.update(cx, |this, cx| {
1184 this.shared_buffers
1185 .get(&buffer)
1186 .cloned()
1187 .unwrap_or_else(|| buffer.read(cx).snapshot())
1188 })?;
1189 let edits = cx
1190 .background_executor()
1191 .spawn(async move {
1192 let old_text = snapshot.text();
1193 text_diff(old_text.as_str(), &content)
1194 .into_iter()
1195 .map(|(range, replacement)| {
1196 (
1197 snapshot.anchor_after(range.start)
1198 ..snapshot.anchor_before(range.end),
1199 replacement,
1200 )
1201 })
1202 .collect::<Vec<_>>()
1203 })
1204 .await;
1205 cx.update(|cx| {
1206 project.update(cx, |project, cx| {
1207 project.set_agent_location(
1208 Some(AgentLocation {
1209 buffer: buffer.downgrade(),
1210 position: edits
1211 .last()
1212 .map(|(range, _)| range.end)
1213 .unwrap_or(Anchor::MIN),
1214 }),
1215 cx,
1216 );
1217 });
1218
1219 action_log.update(cx, |action_log, cx| {
1220 action_log.buffer_read(buffer.clone(), cx);
1221 });
1222 buffer.update(cx, |buffer, cx| {
1223 buffer.edit(edits, None, cx);
1224 });
1225 action_log.update(cx, |action_log, cx| {
1226 action_log.buffer_edited(buffer.clone(), cx);
1227 });
1228 })?;
1229 project
1230 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1231 .await
1232 })
1233 }
1234
1235 pub fn to_markdown(&self, cx: &App) -> String {
1236 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1237 }
1238
1239 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1240 cx.emit(AcpThreadEvent::ServerExited(status));
1241 }
1242}
1243
1244#[cfg(test)]
1245mod tests {
1246 use super::*;
1247 use anyhow::anyhow;
1248 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1249 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1250 use indoc::indoc;
1251 use project::FakeFs;
1252 use rand::Rng as _;
1253 use serde_json::json;
1254 use settings::SettingsStore;
1255 use smol::stream::StreamExt as _;
1256 use std::{cell::RefCell, rc::Rc, time::Duration};
1257
1258 use util::path;
1259
1260 fn init_test(cx: &mut TestAppContext) {
1261 env_logger::try_init().ok();
1262 cx.update(|cx| {
1263 let settings_store = SettingsStore::test(cx);
1264 cx.set_global(settings_store);
1265 Project::init_settings(cx);
1266 language::init(cx);
1267 });
1268 }
1269
1270 #[gpui::test]
1271 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1272 init_test(cx);
1273
1274 let fs = FakeFs::new(cx.executor());
1275 let project = Project::test(fs, [], cx).await;
1276 let connection = Rc::new(FakeAgentConnection::new());
1277 let thread = cx
1278 .spawn(async move |mut cx| {
1279 connection
1280 .new_thread(project, Path::new(path!("/test")), &mut cx)
1281 .await
1282 })
1283 .await
1284 .unwrap();
1285
1286 // Test creating a new user message
1287 thread.update(cx, |thread, cx| {
1288 thread.push_user_content_block(
1289 acp::ContentBlock::Text(acp::TextContent {
1290 annotations: None,
1291 text: "Hello, ".to_string(),
1292 }),
1293 cx,
1294 );
1295 });
1296
1297 thread.update(cx, |thread, cx| {
1298 assert_eq!(thread.entries.len(), 1);
1299 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1300 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1301 } else {
1302 panic!("Expected UserMessage");
1303 }
1304 });
1305
1306 // Test appending to existing user message
1307 thread.update(cx, |thread, cx| {
1308 thread.push_user_content_block(
1309 acp::ContentBlock::Text(acp::TextContent {
1310 annotations: None,
1311 text: "world!".to_string(),
1312 }),
1313 cx,
1314 );
1315 });
1316
1317 thread.update(cx, |thread, cx| {
1318 assert_eq!(thread.entries.len(), 1);
1319 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1320 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1321 } else {
1322 panic!("Expected UserMessage");
1323 }
1324 });
1325
1326 // Test creating new user message after assistant message
1327 thread.update(cx, |thread, cx| {
1328 thread.push_assistant_content_block(
1329 acp::ContentBlock::Text(acp::TextContent {
1330 annotations: None,
1331 text: "Assistant response".to_string(),
1332 }),
1333 false,
1334 cx,
1335 );
1336 });
1337
1338 thread.update(cx, |thread, cx| {
1339 thread.push_user_content_block(
1340 acp::ContentBlock::Text(acp::TextContent {
1341 annotations: None,
1342 text: "New user message".to_string(),
1343 }),
1344 cx,
1345 );
1346 });
1347
1348 thread.update(cx, |thread, cx| {
1349 assert_eq!(thread.entries.len(), 3);
1350 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1351 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1352 } else {
1353 panic!("Expected UserMessage at index 2");
1354 }
1355 });
1356 }
1357
1358 #[gpui::test]
1359 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1360 init_test(cx);
1361
1362 let fs = FakeFs::new(cx.executor());
1363 let project = Project::test(fs, [], cx).await;
1364 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1365 |_, thread, mut cx| {
1366 async move {
1367 thread.update(&mut cx, |thread, cx| {
1368 thread
1369 .handle_session_update(
1370 acp::SessionUpdate::AgentThoughtChunk {
1371 content: "Thinking ".into(),
1372 },
1373 cx,
1374 )
1375 .unwrap();
1376 thread
1377 .handle_session_update(
1378 acp::SessionUpdate::AgentThoughtChunk {
1379 content: "hard!".into(),
1380 },
1381 cx,
1382 )
1383 .unwrap();
1384 })?;
1385 Ok(acp::PromptResponse {
1386 stop_reason: acp::StopReason::EndTurn,
1387 })
1388 }
1389 .boxed_local()
1390 },
1391 ));
1392
1393 let thread = cx
1394 .spawn(async move |mut cx| {
1395 connection
1396 .new_thread(project, Path::new(path!("/test")), &mut cx)
1397 .await
1398 })
1399 .await
1400 .unwrap();
1401
1402 thread
1403 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1404 .await
1405 .unwrap();
1406
1407 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1408 assert_eq!(
1409 output,
1410 indoc! {r#"
1411 ## User
1412
1413 Hello from Zed!
1414
1415 ## Assistant
1416
1417 <thinking>
1418 Thinking hard!
1419 </thinking>
1420
1421 "#}
1422 );
1423 }
1424
    /// Checks that an agent write racing with a user edit keeps both changes.
    ///
    /// The fake agent reads `/tmp/foo`, signals the test, then writes the file
    /// back with `four\nfive\n` appended. Meanwhile the test inserts `zero\n` at
    /// the start of the already-open buffer. Both the buffer and the file on disk
    /// must end up containing the user's insertion and the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test once the agent has read the file, so the user edit
        // can be made after the agent's read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Wrapped so the `Fn` handler can move the one-shot sender out on first use.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Full replacement text, written without knowledge of the user's edit.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file as a buffer, standing in for the user's editor.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then make the user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        // Let the agent's write (and the buffer save it triggers) complete.
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1503
    /// A tool call canceled via `AcpThread::cancel` can still be completed by a
    /// later `ToolCallUpdate` from the agent.
    ///
    /// The fake agent starts an in-progress `Fetch` tool call. The test cancels
    /// the thread (the entry becomes `Canceled`) and then reports the call as
    /// `Completed`. It expects the entry to end up `Allowed { status: Completed, .. }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the agent's tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1612
1613 #[gpui::test]
1614 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1615 init_test(cx);
1616 let fs = FakeFs::new(cx.background_executor.clone());
1617 fs.insert_tree(path!("/test"), json!({})).await;
1618 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1619
1620 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1621 move |_, thread, mut cx| {
1622 async move {
1623 thread
1624 .update(&mut cx, |thread, cx| {
1625 thread.handle_session_update(
1626 acp::SessionUpdate::ToolCall(acp::ToolCall {
1627 id: acp::ToolCallId("test".into()),
1628 title: "Label".into(),
1629 kind: acp::ToolKind::Edit,
1630 status: acp::ToolCallStatus::Completed,
1631 content: vec![acp::ToolCallContent::Diff {
1632 diff: acp::Diff {
1633 path: "/test/test.txt".into(),
1634 old_text: None,
1635 new_text: "foo".into(),
1636 },
1637 }],
1638 locations: vec![],
1639 raw_input: None,
1640 }),
1641 cx,
1642 )
1643 })
1644 .unwrap()
1645 .unwrap();
1646 Ok(acp::PromptResponse {
1647 stop_reason: acp::StopReason::EndTurn,
1648 })
1649 }
1650 .boxed_local()
1651 }
1652 }));
1653
1654 let thread = connection
1655 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1656 .await
1657 .unwrap();
1658 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1659 .await
1660 .unwrap();
1661
1662 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1663 }
1664
1665 async fn run_until_first_tool_call(
1666 thread: &Entity<AcpThread>,
1667 cx: &mut TestAppContext,
1668 ) -> usize {
1669 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1670
1671 let subscription = cx.update(|cx| {
1672 cx.subscribe(thread, move |thread, _, cx| {
1673 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1674 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1675 return tx.try_send(ix).unwrap();
1676 }
1677 }
1678 })
1679 });
1680
1681 select! {
1682 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1683 panic!("Timeout waiting for tool call")
1684 }
1685 ix = rx.next().fuse() => {
1686 drop(subscription);
1687 ix.unwrap()
1688 }
1689 }
1690 }
1691
1692 #[derive(Clone, Default)]
1693 struct FakeAgentConnection {
1694 auth_methods: Vec<acp::AuthMethod>,
1695 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1696 on_user_message: Option<
1697 Rc<
1698 dyn Fn(
1699 acp::PromptRequest,
1700 WeakEntity<AcpThread>,
1701 AsyncApp,
1702 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1703 + 'static,
1704 >,
1705 >,
1706 }
1707
1708 impl FakeAgentConnection {
1709 fn new() -> Self {
1710 Self {
1711 auth_methods: Vec::new(),
1712 on_user_message: None,
1713 sessions: Arc::default(),
1714 }
1715 }
1716
1717 #[expect(unused)]
1718 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1719 self.auth_methods = auth_methods;
1720 self
1721 }
1722
1723 fn on_user_message(
1724 mut self,
1725 handler: impl Fn(
1726 acp::PromptRequest,
1727 WeakEntity<AcpThread>,
1728 AsyncApp,
1729 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1730 + 'static,
1731 ) -> Self {
1732 self.on_user_message.replace(Rc::new(handler));
1733 self
1734 }
1735 }
1736
1737 impl AgentConnection for FakeAgentConnection {
1738 fn auth_methods(&self) -> &[acp::AuthMethod] {
1739 &self.auth_methods
1740 }
1741
1742 fn new_thread(
1743 self: Rc<Self>,
1744 project: Entity<Project>,
1745 _cwd: &Path,
1746 cx: &mut gpui::AsyncApp,
1747 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1748 let session_id = acp::SessionId(
1749 rand::thread_rng()
1750 .sample_iter(&rand::distributions::Alphanumeric)
1751 .take(7)
1752 .map(char::from)
1753 .collect::<String>()
1754 .into(),
1755 );
1756 let thread = cx
1757 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1758 .unwrap();
1759 self.sessions.lock().insert(session_id, thread.downgrade());
1760 Task::ready(Ok(thread))
1761 }
1762
1763 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1764 if self.auth_methods().iter().any(|m| m.id == method) {
1765 Task::ready(Ok(()))
1766 } else {
1767 Task::ready(Err(anyhow!("Invalid Auth Method")))
1768 }
1769 }
1770
1771 fn prompt(
1772 &self,
1773 params: acp::PromptRequest,
1774 cx: &mut App,
1775 ) -> Task<gpui::Result<acp::PromptResponse>> {
1776 let sessions = self.sessions.lock();
1777 let thread = sessions.get(¶ms.session_id).unwrap();
1778 if let Some(handler) = &self.on_user_message {
1779 let handler = handler.clone();
1780 let thread = thread.clone();
1781 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1782 } else {
1783 Task::ready(Ok(acp::PromptResponse {
1784 stop_reason: acp::StopReason::EndTurn,
1785 }))
1786 }
1787 }
1788
1789 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1790 let sessions = self.sessions.lock();
1791 let thread = sessions.get(&session_id).unwrap().clone();
1792
1793 cx.spawn(async move |cx| {
1794 thread
1795 .update(cx, |thread, cx| thread.cancel(cx))
1796 .unwrap()
1797 .await
1798 })
1799 .detach();
1800 }
1801 }
1802}