1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::rc::Rc;
22use std::{
23 fmt::Display,
24 mem,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use ui::App;
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub content: ContentBlock,
34}
35
36impl UserMessage {
37 pub fn from_acp(
38 message: impl IntoIterator<Item = acp::ContentBlock>,
39 language_registry: Arc<LanguageRegistry>,
40 cx: &mut App,
41 ) -> Self {
42 let mut content = ContentBlock::Empty;
43 for chunk in message {
44 content.append(chunk, &language_registry, cx)
45 }
46 Self { content: content }
47 }
48
49 fn to_markdown(&self, cx: &App) -> String {
50 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
51 }
52}
53
/// A borrowed file path that renders as a markdown mention link (`[@name](@file:/path)`)
/// and can be parsed back from the link target.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Marker placed in front of the path in a mention link target.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// Returns the wrapped path.
    pub fn path(&self) -> &Path {
        let MentionPath(inner) = self;
        inner
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let full_path = self.path();
        // Paths without a final component (e.g. `/`) get an empty label.
        let label = full_path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            label.to_string_lossy(),
            Self::PREFIX,
            full_path.display()
        )
    }
}
85
86#[derive(Debug, PartialEq)]
87pub struct AssistantMessage {
88 pub chunks: Vec<AssistantMessageChunk>,
89}
90
91impl AssistantMessage {
92 pub fn to_markdown(&self, cx: &App) -> String {
93 format!(
94 "## Assistant\n\n{}\n\n",
95 self.chunks
96 .iter()
97 .map(|chunk| chunk.to_markdown(cx))
98 .join("\n\n")
99 )
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub enum AssistantMessageChunk {
105 Message { block: ContentBlock },
106 Thought { block: ContentBlock },
107}
108
109impl AssistantMessageChunk {
110 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 Self::Message {
112 block: ContentBlock::new(chunk.into(), language_registry, cx),
113 }
114 }
115
116 fn to_markdown(&self, cx: &App) -> String {
117 match self {
118 Self::Message { block } => block.to_markdown(cx).to_string(),
119 Self::Thought { block } => {
120 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
121 }
122 }
123 }
124}
125
126#[derive(Debug)]
127pub enum AgentThreadEntry {
128 UserMessage(UserMessage),
129 AssistantMessage(AssistantMessage),
130 ToolCall(ToolCall),
131}
132
133impl AgentThreadEntry {
134 fn to_markdown(&self, cx: &App) -> String {
135 match self {
136 Self::UserMessage(message) => message.to_markdown(cx),
137 Self::AssistantMessage(message) => message.to_markdown(cx),
138 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
151 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
152 Some(locations)
153 } else {
154 None
155 }
156 }
157}
158
159#[derive(Debug)]
160pub struct ToolCall {
161 pub id: acp::ToolCallId,
162 pub label: Entity<Markdown>,
163 pub kind: acp::ToolKind,
164 pub content: Vec<ToolCallContent>,
165 pub status: ToolCallStatus,
166 pub locations: Vec<acp::ToolCallLocation>,
167 pub raw_input: Option<serde_json::Value>,
168}
169
170impl ToolCall {
171 fn from_acp(
172 tool_call: acp::ToolCall,
173 status: ToolCallStatus,
174 language_registry: Arc<LanguageRegistry>,
175 cx: &mut App,
176 ) -> Self {
177 Self {
178 id: tool_call.id,
179 label: cx.new(|cx| {
180 Markdown::new(
181 tool_call.title.into(),
182 Some(language_registry.clone()),
183 None,
184 cx,
185 )
186 }),
187 kind: tool_call.kind,
188 content: tool_call
189 .content
190 .into_iter()
191 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
192 .collect(),
193 locations: tool_call.locations,
194 status,
195 raw_input: tool_call.raw_input,
196 }
197 }
198
199 fn update(
200 &mut self,
201 fields: acp::ToolCallUpdateFields,
202 language_registry: Arc<LanguageRegistry>,
203 cx: &mut App,
204 ) {
205 let acp::ToolCallUpdateFields {
206 kind,
207 status,
208 title,
209 content,
210 locations,
211 raw_input,
212 } = fields;
213
214 if let Some(kind) = kind {
215 self.kind = kind;
216 }
217
218 if let Some(status) = status {
219 self.status = ToolCallStatus::Allowed { status };
220 }
221
222 if let Some(title) = title {
223 self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
224 }
225
226 if let Some(content) = content {
227 self.content = content
228 .into_iter()
229 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
230 .collect();
231 }
232
233 if let Some(locations) = locations {
234 self.locations = locations;
235 }
236
237 if let Some(raw_input) = raw_input {
238 self.raw_input = Some(raw_input);
239 }
240 }
241
242 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
243 self.content.iter().filter_map(|content| match content {
244 ToolCallContent::ContentBlock { .. } => None,
245 ToolCallContent::Diff { diff } => Some(diff),
246 })
247 }
248
249 fn to_markdown(&self, cx: &App) -> String {
250 let mut markdown = format!(
251 "**Tool Call: {}**\nStatus: {}\n\n",
252 self.label.read(cx).source(),
253 self.status
254 );
255 for content in &self.content {
256 markdown.push_str(content.to_markdown(cx).as_str());
257 markdown.push_str("\n\n");
258 }
259 markdown
260 }
261}
262
263#[derive(Debug)]
264pub enum ToolCallStatus {
265 WaitingForConfirmation {
266 options: Vec<acp::PermissionOption>,
267 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
268 },
269 Allowed {
270 status: acp::ToolCallStatus,
271 },
272 Rejected,
273 Canceled,
274}
275
276impl Display for ToolCallStatus {
277 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
278 write!(
279 f,
280 "{}",
281 match self {
282 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
283 ToolCallStatus::Allowed { status } => match status {
284 acp::ToolCallStatus::Pending => "Pending",
285 acp::ToolCallStatus::InProgress => "In Progress",
286 acp::ToolCallStatus::Completed => "Completed",
287 acp::ToolCallStatus::Failed => "Failed",
288 },
289 ToolCallStatus::Rejected => "Rejected",
290 ToolCallStatus::Canceled => "Canceled",
291 }
292 )
293 }
294}
295
296#[derive(Debug, PartialEq, Clone)]
297pub enum ContentBlock {
298 Empty,
299 Markdown { markdown: Entity<Markdown> },
300}
301
302impl ContentBlock {
303 pub fn new(
304 block: acp::ContentBlock,
305 language_registry: &Arc<LanguageRegistry>,
306 cx: &mut App,
307 ) -> Self {
308 let mut this = Self::Empty;
309 this.append(block, language_registry, cx);
310 this
311 }
312
313 pub fn new_combined(
314 blocks: impl IntoIterator<Item = acp::ContentBlock>,
315 language_registry: Arc<LanguageRegistry>,
316 cx: &mut App,
317 ) -> Self {
318 let mut this = Self::Empty;
319 for block in blocks {
320 this.append(block, &language_registry, cx);
321 }
322 this
323 }
324
325 pub fn append(
326 &mut self,
327 block: acp::ContentBlock,
328 language_registry: &Arc<LanguageRegistry>,
329 cx: &mut App,
330 ) {
331 let new_content = match block {
332 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
333 acp::ContentBlock::ResourceLink(resource_link) => {
334 if let Some(path) = resource_link.uri.strip_prefix("file://") {
335 format!("{}", MentionPath(path.as_ref()))
336 } else {
337 resource_link.uri.clone()
338 }
339 }
340 acp::ContentBlock::Image(_)
341 | acp::ContentBlock::Audio(_)
342 | acp::ContentBlock::Resource(_) => String::new(),
343 };
344
345 match self {
346 ContentBlock::Empty => {
347 *self = ContentBlock::Markdown {
348 markdown: cx.new(|cx| {
349 Markdown::new(
350 new_content.into(),
351 Some(language_registry.clone()),
352 None,
353 cx,
354 )
355 }),
356 };
357 }
358 ContentBlock::Markdown { markdown } => {
359 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
360 }
361 }
362 }
363
364 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
365 match self {
366 ContentBlock::Empty => "",
367 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
368 }
369 }
370
371 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
372 match self {
373 ContentBlock::Empty => None,
374 ContentBlock::Markdown { markdown } => Some(markdown),
375 }
376 }
377}
378
379#[derive(Debug)]
380pub enum ToolCallContent {
381 ContentBlock { content: ContentBlock },
382 Diff { diff: Diff },
383}
384
385impl ToolCallContent {
386 pub fn from_acp(
387 content: acp::ToolCallContent,
388 language_registry: Arc<LanguageRegistry>,
389 cx: &mut App,
390 ) -> Self {
391 match content {
392 acp::ToolCallContent::Content { content } => Self::ContentBlock {
393 content: ContentBlock::new(content, &language_registry, cx),
394 },
395 acp::ToolCallContent::Diff { diff } => Self::Diff {
396 diff: Diff::from_acp(diff, language_registry, cx),
397 },
398 }
399 }
400
401 pub fn to_markdown(&self, cx: &App) -> String {
402 match self {
403 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
404 Self::Diff { diff } => diff.to_markdown(cx),
405 }
406 }
407}
408
/// A file edit reported by a tool call, shown as a read-only multibuffer with excerpts of
/// the new text around each hunk that differs from the old text.
#[derive(Debug)]
pub struct Diff {
    /// Read-only view populated asynchronously with excerpts of `new_buffer`.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path of the edited file, as reported by the agent.
    pub path: PathBuf,
    /// Buffer holding the proposed new contents.
    pub new_buffer: Entity<Buffer>,
    /// Buffer holding the old contents (empty when the agent sent no old text).
    pub old_buffer: Entity<Buffer>,
    /// Setup task (hunk computation, excerpts, language); held so it lives as long as the diff.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff from its ACP representation.
    ///
    /// The buffers are created synchronously. The multibuffer stays empty until the
    /// spawned task has computed the hunks and inserted excerpts for them. After that, the
    /// task assigns a language to `new_buffer` based on `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is treated as an empty file (i.e. the whole file is new).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Diff the new text against the old text used as the base.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the changed regions (plus context) of the new buffer.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection failure is logged and otherwise ignored.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path followed by the full text of every buffer in the multibuffer (the
    /// new contents, not a unified diff) inside a fenced code block. The block is empty if
    /// the setup task has not populated the multibuffer yet.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
506
507#[derive(Debug, Default)]
508pub struct Plan {
509 pub entries: Vec<PlanEntry>,
510}
511
512#[derive(Debug)]
513pub struct PlanStats<'a> {
514 pub in_progress_entry: Option<&'a PlanEntry>,
515 pub pending: u32,
516 pub completed: u32,
517}
518
519impl Plan {
520 pub fn is_empty(&self) -> bool {
521 self.entries.is_empty()
522 }
523
524 pub fn stats(&self) -> PlanStats<'_> {
525 let mut stats = PlanStats {
526 in_progress_entry: None,
527 pending: 0,
528 completed: 0,
529 };
530
531 for entry in &self.entries {
532 match &entry.status {
533 acp::PlanEntryStatus::Pending => {
534 stats.pending += 1;
535 }
536 acp::PlanEntryStatus::InProgress => {
537 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
538 }
539 acp::PlanEntryStatus::Completed => {
540 stats.completed += 1;
541 }
542 }
543 }
544
545 stats
546 }
547}
548
549#[derive(Debug)]
550pub struct PlanEntry {
551 pub content: Entity<Markdown>,
552 pub priority: acp::PlanEntryPriority,
553 pub status: acp::PlanEntryStatus,
554}
555
556impl PlanEntry {
557 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
558 Self {
559 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
560 priority: entry.priority,
561 status: entry.status,
562 }
563 }
564}
565
566pub struct AcpThread {
567 title: SharedString,
568 entries: Vec<AgentThreadEntry>,
569 plan: Plan,
570 project: Entity<Project>,
571 action_log: Entity<ActionLog>,
572 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
573 send_task: Option<Task<()>>,
574 connection: Rc<dyn AgentConnection>,
575 session_id: acp::SessionId,
576}
577
578pub enum AcpThreadEvent {
579 NewEntry,
580 EntryUpdated(usize),
581 ToolAuthorizationRequired,
582 Stopped,
583 Error,
584}
585
586impl EventEmitter<AcpThreadEvent> for AcpThread {}
587
588#[derive(PartialEq, Eq)]
589pub enum ThreadStatus {
590 Idle,
591 WaitingForToolConfirmation,
592 Generating,
593}
594
595#[derive(Debug, Clone)]
596pub enum LoadError {
597 Unsupported {
598 error_message: SharedString,
599 upgrade_message: SharedString,
600 upgrade_command: String,
601 },
602 Exited(i32),
603 Other(SharedString),
604}
605
606impl Display for LoadError {
607 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
608 match self {
609 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
610 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
611 LoadError::Other(msg) => write!(f, "{}", msg),
612 }
613 }
614}
615
616impl Error for LoadError {}
617
618impl AcpThread {
619 pub fn new(
620 title: impl Into<SharedString>,
621 connection: Rc<dyn AgentConnection>,
622 project: Entity<Project>,
623 session_id: acp::SessionId,
624 cx: &mut Context<Self>,
625 ) -> Self {
626 let action_log = cx.new(|_| ActionLog::new(project.clone()));
627
628 Self {
629 action_log,
630 shared_buffers: Default::default(),
631 entries: Default::default(),
632 plan: Default::default(),
633 title: title.into(),
634 project,
635 send_task: None,
636 connection,
637 session_id,
638 }
639 }
640
641 pub fn action_log(&self) -> &Entity<ActionLog> {
642 &self.action_log
643 }
644
645 pub fn project(&self) -> &Entity<Project> {
646 &self.project
647 }
648
649 pub fn title(&self) -> SharedString {
650 self.title.clone()
651 }
652
653 pub fn entries(&self) -> &[AgentThreadEntry] {
654 &self.entries
655 }
656
657 pub fn status(&self) -> ThreadStatus {
658 if self.send_task.is_some() {
659 if self.waiting_for_tool_confirmation() {
660 ThreadStatus::WaitingForToolConfirmation
661 } else {
662 ThreadStatus::Generating
663 }
664 } else {
665 ThreadStatus::Idle
666 }
667 }
668
669 pub fn has_pending_edit_tool_calls(&self) -> bool {
670 for entry in self.entries.iter().rev() {
671 match entry {
672 AgentThreadEntry::UserMessage(_) => return false,
673 AgentThreadEntry::ToolCall(
674 call @ ToolCall {
675 status:
676 ToolCallStatus::Allowed {
677 status:
678 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
679 },
680 ..
681 },
682 ) if call.diffs().next().is_some() => {
683 return true;
684 }
685 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
686 }
687 }
688
689 false
690 }
691
692 pub fn used_tools_since_last_user_message(&self) -> bool {
693 for entry in self.entries.iter().rev() {
694 match entry {
695 AgentThreadEntry::UserMessage(..) => return false,
696 AgentThreadEntry::AssistantMessage(..) => continue,
697 AgentThreadEntry::ToolCall(..) => return true,
698 }
699 }
700
701 false
702 }
703
704 pub fn handle_session_update(
705 &mut self,
706 update: acp::SessionUpdate,
707 cx: &mut Context<Self>,
708 ) -> Result<()> {
709 match update {
710 acp::SessionUpdate::UserMessageChunk { content } => {
711 self.push_user_content_block(content, cx);
712 }
713 acp::SessionUpdate::AgentMessageChunk { content } => {
714 self.push_assistant_content_block(content, false, cx);
715 }
716 acp::SessionUpdate::AgentThoughtChunk { content } => {
717 self.push_assistant_content_block(content, true, cx);
718 }
719 acp::SessionUpdate::ToolCall(tool_call) => {
720 self.upsert_tool_call(tool_call, cx);
721 }
722 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
723 self.update_tool_call(tool_call_update, cx)?;
724 }
725 acp::SessionUpdate::Plan(plan) => {
726 self.update_plan(plan, cx);
727 }
728 }
729 Ok(())
730 }
731
732 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
733 let language_registry = self.project.read(cx).languages().clone();
734 let entries_len = self.entries.len();
735
736 if let Some(last_entry) = self.entries.last_mut()
737 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
738 {
739 content.append(chunk, &language_registry, cx);
740 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
741 } else {
742 let content = ContentBlock::new(chunk, &language_registry, cx);
743 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
744 }
745 }
746
747 pub fn push_assistant_content_block(
748 &mut self,
749 chunk: acp::ContentBlock,
750 is_thought: bool,
751 cx: &mut Context<Self>,
752 ) {
753 let language_registry = self.project.read(cx).languages().clone();
754 let entries_len = self.entries.len();
755 if let Some(last_entry) = self.entries.last_mut()
756 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
757 {
758 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
759 match (chunks.last_mut(), is_thought) {
760 (Some(AssistantMessageChunk::Message { block }), false)
761 | (Some(AssistantMessageChunk::Thought { block }), true) => {
762 block.append(chunk, &language_registry, cx)
763 }
764 _ => {
765 let block = ContentBlock::new(chunk, &language_registry, cx);
766 if is_thought {
767 chunks.push(AssistantMessageChunk::Thought { block })
768 } else {
769 chunks.push(AssistantMessageChunk::Message { block })
770 }
771 }
772 }
773 } else {
774 let block = ContentBlock::new(chunk, &language_registry, cx);
775 let chunk = if is_thought {
776 AssistantMessageChunk::Thought { block }
777 } else {
778 AssistantMessageChunk::Message { block }
779 };
780
781 self.push_entry(
782 AgentThreadEntry::AssistantMessage(AssistantMessage {
783 chunks: vec![chunk],
784 }),
785 cx,
786 );
787 }
788 }
789
790 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
791 self.entries.push(entry);
792 cx.emit(AcpThreadEvent::NewEntry);
793 }
794
795 pub fn update_tool_call(
796 &mut self,
797 update: acp::ToolCallUpdate,
798 cx: &mut Context<Self>,
799 ) -> Result<()> {
800 let languages = self.project.read(cx).languages().clone();
801
802 let (ix, current_call) = self
803 .tool_call_mut(&update.id)
804 .context("Tool call not found")?;
805 current_call.update(update.fields, languages, cx);
806
807 cx.emit(AcpThreadEvent::EntryUpdated(ix));
808
809 Ok(())
810 }
811
812 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
813 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
814 let status = ToolCallStatus::Allowed {
815 status: tool_call.status,
816 };
817 self.upsert_tool_call_inner(tool_call, status, cx)
818 }
819
820 pub fn upsert_tool_call_inner(
821 &mut self,
822 tool_call: acp::ToolCall,
823 status: ToolCallStatus,
824 cx: &mut Context<Self>,
825 ) {
826 let language_registry = self.project.read(cx).languages().clone();
827 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
828
829 let location = call.locations.last().cloned();
830
831 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
832 *current_call = call;
833
834 cx.emit(AcpThreadEvent::EntryUpdated(ix));
835 } else {
836 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
837 }
838
839 if let Some(location) = location {
840 self.set_project_location(location, cx)
841 }
842 }
843
844 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
845 // The tool call we are looking for is typically the last one, or very close to the end.
846 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
847 self.entries
848 .iter_mut()
849 .enumerate()
850 .rev()
851 .find_map(|(index, tool_call)| {
852 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
853 && &tool_call.id == id
854 {
855 Some((index, tool_call))
856 } else {
857 None
858 }
859 })
860 }
861
862 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
863 self.project.update(cx, |project, cx| {
864 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
865 return;
866 };
867 let buffer = project.open_buffer(path, cx);
868 cx.spawn(async move |project, cx| {
869 let buffer = buffer.await?;
870
871 project.update(cx, |project, cx| {
872 let position = if let Some(line) = location.line {
873 let snapshot = buffer.read(cx).snapshot();
874 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
875 snapshot.anchor_before(point)
876 } else {
877 Anchor::MIN
878 };
879
880 project.set_agent_location(
881 Some(AgentLocation {
882 buffer: buffer.downgrade(),
883 position,
884 }),
885 cx,
886 );
887 })
888 })
889 .detach_and_log_err(cx);
890 });
891 }
892
893 pub fn request_tool_call_permission(
894 &mut self,
895 tool_call: acp::ToolCall,
896 options: Vec<acp::PermissionOption>,
897 cx: &mut Context<Self>,
898 ) -> oneshot::Receiver<acp::PermissionOptionId> {
899 let (tx, rx) = oneshot::channel();
900
901 let status = ToolCallStatus::WaitingForConfirmation {
902 options,
903 respond_tx: tx,
904 };
905
906 self.upsert_tool_call_inner(tool_call, status, cx);
907 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
908 rx
909 }
910
911 pub fn authorize_tool_call(
912 &mut self,
913 id: acp::ToolCallId,
914 option_id: acp::PermissionOptionId,
915 option_kind: acp::PermissionOptionKind,
916 cx: &mut Context<Self>,
917 ) {
918 let Some((ix, call)) = self.tool_call_mut(&id) else {
919 return;
920 };
921
922 let new_status = match option_kind {
923 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
924 ToolCallStatus::Rejected
925 }
926 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
927 ToolCallStatus::Allowed {
928 status: acp::ToolCallStatus::InProgress,
929 }
930 }
931 };
932
933 let curr_status = mem::replace(&mut call.status, new_status);
934
935 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
936 respond_tx.send(option_id).log_err();
937 } else if cfg!(debug_assertions) {
938 panic!("tried to authorize an already authorized tool call");
939 }
940
941 cx.emit(AcpThreadEvent::EntryUpdated(ix));
942 }
943
944 /// Returns true if the last turn is awaiting tool authorization
945 pub fn waiting_for_tool_confirmation(&self) -> bool {
946 for entry in self.entries.iter().rev() {
947 match &entry {
948 AgentThreadEntry::ToolCall(call) => match call.status {
949 ToolCallStatus::WaitingForConfirmation { .. } => return true,
950 ToolCallStatus::Allowed { .. }
951 | ToolCallStatus::Rejected
952 | ToolCallStatus::Canceled => continue,
953 },
954 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
955 // Reached the beginning of the turn
956 return false;
957 }
958 }
959 }
960 false
961 }
962
963 pub fn plan(&self) -> &Plan {
964 &self.plan
965 }
966
967 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
968 self.plan = Plan {
969 entries: request
970 .entries
971 .into_iter()
972 .map(|entry| PlanEntry::from_acp(entry, cx))
973 .collect(),
974 };
975
976 cx.notify();
977 }
978
979 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
980 self.plan
981 .entries
982 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
983 cx.notify();
984 }
985
986 #[cfg(any(test, feature = "test-support"))]
987 pub fn send_raw(
988 &mut self,
989 message: &str,
990 cx: &mut Context<Self>,
991 ) -> BoxFuture<'static, Result<()>> {
992 self.send(
993 vec![acp::ContentBlock::Text(acp::TextContent {
994 text: message.to_string(),
995 annotations: None,
996 })],
997 cx,
998 )
999 }
1000
1001 pub fn send(
1002 &mut self,
1003 message: Vec<acp::ContentBlock>,
1004 cx: &mut Context<Self>,
1005 ) -> BoxFuture<'static, Result<()>> {
1006 let block = ContentBlock::new_combined(
1007 message.clone(),
1008 self.project.read(cx).languages().clone(),
1009 cx,
1010 );
1011 self.push_entry(
1012 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1013 cx,
1014 );
1015 self.clear_completed_plan_entries(cx);
1016
1017 let (tx, rx) = oneshot::channel();
1018 let cancel_task = self.cancel(cx);
1019
1020 self.send_task = Some(cx.spawn(async move |this, cx| {
1021 async {
1022 cancel_task.await;
1023
1024 let result = this
1025 .update(cx, |this, cx| {
1026 this.connection.prompt(
1027 acp::PromptRequest {
1028 prompt: message,
1029 session_id: this.session_id.clone(),
1030 },
1031 cx,
1032 )
1033 })?
1034 .await;
1035 tx.send(result).log_err();
1036 this.update(cx, |this, _cx| this.send_task.take())?;
1037 anyhow::Ok(())
1038 }
1039 .await
1040 .log_err();
1041 }));
1042
1043 cx.spawn(async move |this, cx| match rx.await {
1044 Ok(Err(e)) => {
1045 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1046 .log_err();
1047 Err(e)?
1048 }
1049 _ => {
1050 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1051 .log_err();
1052 Ok(())
1053 }
1054 })
1055 .boxed()
1056 }
1057
1058 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1059 let Some(send_task) = self.send_task.take() else {
1060 return Task::ready(());
1061 };
1062
1063 for entry in self.entries.iter_mut() {
1064 if let AgentThreadEntry::ToolCall(call) = entry {
1065 let cancel = matches!(
1066 call.status,
1067 ToolCallStatus::WaitingForConfirmation { .. }
1068 | ToolCallStatus::Allowed {
1069 status: acp::ToolCallStatus::InProgress
1070 }
1071 );
1072
1073 if cancel {
1074 call.status = ToolCallStatus::Canceled;
1075 }
1076 }
1077 }
1078
1079 self.connection.cancel(&self.session_id, cx);
1080
1081 // Wait for the send task to complete
1082 cx.foreground_executor().spawn(send_task)
1083 }
1084
1085 pub fn read_text_file(
1086 &self,
1087 path: PathBuf,
1088 line: Option<u32>,
1089 limit: Option<u32>,
1090 reuse_shared_snapshot: bool,
1091 cx: &mut Context<Self>,
1092 ) -> Task<Result<String>> {
1093 let project = self.project.clone();
1094 let action_log = self.action_log.clone();
1095 cx.spawn(async move |this, cx| {
1096 let load = project.update(cx, |project, cx| {
1097 let path = project
1098 .project_path_for_absolute_path(&path, cx)
1099 .context("invalid path")?;
1100 anyhow::Ok(project.open_buffer(path, cx))
1101 });
1102 let buffer = load??.await?;
1103
1104 let snapshot = if reuse_shared_snapshot {
1105 this.read_with(cx, |this, _| {
1106 this.shared_buffers.get(&buffer.clone()).cloned()
1107 })
1108 .log_err()
1109 .flatten()
1110 } else {
1111 None
1112 };
1113
1114 let snapshot = if let Some(snapshot) = snapshot {
1115 snapshot
1116 } else {
1117 action_log.update(cx, |action_log, cx| {
1118 action_log.buffer_read(buffer.clone(), cx);
1119 })?;
1120 project.update(cx, |project, cx| {
1121 let position = buffer
1122 .read(cx)
1123 .snapshot()
1124 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1125 project.set_agent_location(
1126 Some(AgentLocation {
1127 buffer: buffer.downgrade(),
1128 position,
1129 }),
1130 cx,
1131 );
1132 })?;
1133
1134 buffer.update(cx, |buffer, _| buffer.snapshot())?
1135 };
1136
1137 this.update(cx, |this, _| {
1138 let text = snapshot.text();
1139 this.shared_buffers.insert(buffer.clone(), snapshot);
1140 if line.is_none() && limit.is_none() {
1141 return Ok(text);
1142 }
1143 let limit = limit.unwrap_or(u32::MAX) as usize;
1144 let Some(line) = line else {
1145 return Ok(text.lines().take(limit).collect::<String>());
1146 };
1147
1148 let count = text.lines().count();
1149 if count < line as usize {
1150 anyhow::bail!("There are only {} lines", count);
1151 }
1152 Ok(text
1153 .lines()
1154 .skip(line as usize + 1)
1155 .take(limit)
1156 .collect::<String>())
1157 })?
1158 })
1159 }
1160
1161 pub fn write_text_file(
1162 &self,
1163 path: PathBuf,
1164 content: String,
1165 cx: &mut Context<Self>,
1166 ) -> Task<Result<()>> {
1167 let project = self.project.clone();
1168 let action_log = self.action_log.clone();
1169 cx.spawn(async move |this, cx| {
1170 let load = project.update(cx, |project, cx| {
1171 let path = project
1172 .project_path_for_absolute_path(&path, cx)
1173 .context("invalid path")?;
1174 anyhow::Ok(project.open_buffer(path, cx))
1175 });
1176 let buffer = load??.await?;
1177 let snapshot = this.update(cx, |this, cx| {
1178 this.shared_buffers
1179 .get(&buffer)
1180 .cloned()
1181 .unwrap_or_else(|| buffer.read(cx).snapshot())
1182 })?;
1183 let edits = cx
1184 .background_executor()
1185 .spawn(async move {
1186 let old_text = snapshot.text();
1187 text_diff(old_text.as_str(), &content)
1188 .into_iter()
1189 .map(|(range, replacement)| {
1190 (
1191 snapshot.anchor_after(range.start)
1192 ..snapshot.anchor_before(range.end),
1193 replacement,
1194 )
1195 })
1196 .collect::<Vec<_>>()
1197 })
1198 .await;
1199 cx.update(|cx| {
1200 project.update(cx, |project, cx| {
1201 project.set_agent_location(
1202 Some(AgentLocation {
1203 buffer: buffer.downgrade(),
1204 position: edits
1205 .last()
1206 .map(|(range, _)| range.end)
1207 .unwrap_or(Anchor::MIN),
1208 }),
1209 cx,
1210 );
1211 });
1212
1213 action_log.update(cx, |action_log, cx| {
1214 action_log.buffer_read(buffer.clone(), cx);
1215 });
1216 buffer.update(cx, |buffer, cx| {
1217 buffer.edit(edits, None, cx);
1218 });
1219 action_log.update(cx, |action_log, cx| {
1220 action_log.buffer_edited(buffer.clone(), cx);
1221 });
1222 })?;
1223 project
1224 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1225 .await
1226 })
1227 }
1228
1229 pub fn to_markdown(&self, cx: &App) -> String {
1230 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1231 }
1232}
1233
1234#[cfg(test)]
1235mod tests {
1236 use super::*;
1237 use anyhow::anyhow;
1238 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1239 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1240 use indoc::indoc;
1241 use project::FakeFs;
1242 use rand::Rng as _;
1243 use serde_json::json;
1244 use settings::SettingsStore;
1245 use smol::stream::StreamExt as _;
1246 use std::{cell::RefCell, rc::Rc, time::Duration};
1247
1248 use util::path;
1249
1250 fn init_test(cx: &mut TestAppContext) {
1251 env_logger::try_init().ok();
1252 cx.update(|cx| {
1253 let settings_store = SettingsStore::test(cx);
1254 cx.set_global(settings_store);
1255 Project::init_settings(cx);
1256 language::init(cx);
1257 });
1258 }
1259
1260 #[gpui::test]
1261 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1262 init_test(cx);
1263
1264 let fs = FakeFs::new(cx.executor());
1265 let project = Project::test(fs, [], cx).await;
1266 let connection = Rc::new(FakeAgentConnection::new());
1267 let thread = cx
1268 .spawn(async move |mut cx| {
1269 connection
1270 .new_thread(project, Path::new(path!("/test")), &mut cx)
1271 .await
1272 })
1273 .await
1274 .unwrap();
1275
1276 // Test creating a new user message
1277 thread.update(cx, |thread, cx| {
1278 thread.push_user_content_block(
1279 acp::ContentBlock::Text(acp::TextContent {
1280 annotations: None,
1281 text: "Hello, ".to_string(),
1282 }),
1283 cx,
1284 );
1285 });
1286
1287 thread.update(cx, |thread, cx| {
1288 assert_eq!(thread.entries.len(), 1);
1289 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1290 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1291 } else {
1292 panic!("Expected UserMessage");
1293 }
1294 });
1295
1296 // Test appending to existing user message
1297 thread.update(cx, |thread, cx| {
1298 thread.push_user_content_block(
1299 acp::ContentBlock::Text(acp::TextContent {
1300 annotations: None,
1301 text: "world!".to_string(),
1302 }),
1303 cx,
1304 );
1305 });
1306
1307 thread.update(cx, |thread, cx| {
1308 assert_eq!(thread.entries.len(), 1);
1309 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1310 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1311 } else {
1312 panic!("Expected UserMessage");
1313 }
1314 });
1315
1316 // Test creating new user message after assistant message
1317 thread.update(cx, |thread, cx| {
1318 thread.push_assistant_content_block(
1319 acp::ContentBlock::Text(acp::TextContent {
1320 annotations: None,
1321 text: "Assistant response".to_string(),
1322 }),
1323 false,
1324 cx,
1325 );
1326 });
1327
1328 thread.update(cx, |thread, cx| {
1329 thread.push_user_content_block(
1330 acp::ContentBlock::Text(acp::TextContent {
1331 annotations: None,
1332 text: "New user message".to_string(),
1333 }),
1334 cx,
1335 );
1336 });
1337
1338 thread.update(cx, |thread, cx| {
1339 assert_eq!(thread.entries.len(), 3);
1340 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1341 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1342 } else {
1343 panic!("Expected UserMessage at index 2");
1344 }
1345 });
1346 }
1347
1348 #[gpui::test]
1349 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1350 init_test(cx);
1351
1352 let fs = FakeFs::new(cx.executor());
1353 let project = Project::test(fs, [], cx).await;
1354 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1355 |_, thread, mut cx| {
1356 async move {
1357 thread.update(&mut cx, |thread, cx| {
1358 thread
1359 .handle_session_update(
1360 acp::SessionUpdate::AgentThoughtChunk {
1361 content: "Thinking ".into(),
1362 },
1363 cx,
1364 )
1365 .unwrap();
1366 thread
1367 .handle_session_update(
1368 acp::SessionUpdate::AgentThoughtChunk {
1369 content: "hard!".into(),
1370 },
1371 cx,
1372 )
1373 .unwrap();
1374 })
1375 }
1376 .boxed_local()
1377 },
1378 ));
1379
1380 let thread = cx
1381 .spawn(async move |mut cx| {
1382 connection
1383 .new_thread(project, Path::new(path!("/test")), &mut cx)
1384 .await
1385 })
1386 .await
1387 .unwrap();
1388
1389 thread
1390 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1391 .await
1392 .unwrap();
1393
1394 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1395 assert_eq!(
1396 output,
1397 indoc! {r#"
1398 ## User
1399
1400 Hello from Zed!
1401
1402 ## Assistant
1403
1404 <thinking>
1405 Thinking hard!
1406 </thinking>
1407
1408 "#}
1409 );
1410 }
1411
/// Checks that an agent's `write_text_file` does not clobber an edit the user
/// made to the same buffer after the agent read the file. Both the in-memory
/// buffer and the saved file must end up with the user's line *and* the
/// agent's appended lines.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    // Fired by the fake agent right after it has read the file, so the test
    // body can make its "user" edit between the agent's read and its write.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    // The prompt handler must be `Fn`, so the one-shot sender is shared
    // through `Rc<RefCell<Option<_>>>` and taken out on first use.
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        move |_, thread, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                // Step 1: the agent reads the original file contents.
                let content = thread
                    .update(&mut cx, |thread, cx| {
                        thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                    })
                    .unwrap()
                    .await
                    .unwrap();
                assert_eq!(content, "one\ntwo\nthree\n");
                // Step 2: let the test body perform the concurrent user edit.
                read_file_tx.take().unwrap().send(()).unwrap();
                // Step 3: the agent writes back its version, which only
                // appends lines relative to what it read.
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.write_text_file(
                            path!("/tmp/foo").into(),
                            "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            cx,
                        )
                    })
                    .unwrap()
                    .await
                    .unwrap();
                Ok(())
            }
            .boxed_local()
        },
    ));

    // Open /tmp/foo as a project buffer so the test can edit it directly.
    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    let thread = cx
        .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    // Wait until the agent has read the file, then prepend a line as the user.
    read_file_rx.await.ok();
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    // Let the agent's write (and the save that follows it) run to completion.
    cx.run_until_parked();
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
1488
/// Checks that a tool call marked `Canceled` by `AcpThread::cancel` is still
/// moved to `Allowed { status: Completed }` when the agent later reports that
/// it completed.
#[gpui::test]
async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let id = acp::ToolCallId("test".into());

    // The fake agent answers the prompt by starting a single in-progress
    // `Fetch` tool call and then returning without completing it.
    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        let id = id.clone();
        move |_, thread, mut cx| {
            let id = id.clone();
            async move {
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.handle_session_update(
                            acp::SessionUpdate::ToolCall(acp::ToolCall {
                                id: id.clone(),
                                title: "Label".into(),
                                kind: acp::ToolKind::Fetch,
                                status: acp::ToolCallStatus::InProgress,
                                content: vec![],
                                locations: vec![],
                                raw_input: None,
                            }),
                            cx,
                        )
                    })
                    .unwrap()
                    .unwrap();
                Ok(())
            }
            .boxed_local()
        }
    }));

    let thread = cx
        .spawn(async move |mut cx| {
            connection
                .new_thread(project, Path::new(path!("/test")), &mut cx)
                .await
        })
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Fetch https://example.com", cx)
    });

    run_until_first_tool_call(&thread, cx).await;

    // The tool call is expected at entry 1 (entry 0 presumably holds the
    // user prompt) and is still in progress.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed {
                    status: acp::ToolCallStatus::InProgress,
                    ..
                },
                ..
            })
        ));
    });

    // Canceling the thread flips the in-progress tool call to `Canceled`.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;

    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            &thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            })
        ));
    });

    // The agent reports completion after the cancellation was recorded.
    thread
        .update(cx, |thread, cx| {
            thread.handle_session_update(
                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                    id,
                    fields: acp::ToolCallUpdateFields {
                        status: Some(acp::ToolCallStatus::Completed),
                        ..Default::default()
                    },
                }),
                cx,
            )
        })
        .unwrap();

    request.await.unwrap();

    // The late completion wins over the earlier cancellation.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed {
                    status: acp::ToolCallStatus::Completed,
                    ..
                },
                ..
            })
        ));
    });
}
1595
1596 #[gpui::test]
1597 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1598 init_test(cx);
1599 let fs = FakeFs::new(cx.background_executor.clone());
1600 fs.insert_tree(path!("/test"), json!({})).await;
1601 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1602
1603 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1604 move |_, thread, mut cx| {
1605 async move {
1606 thread
1607 .update(&mut cx, |thread, cx| {
1608 thread.handle_session_update(
1609 acp::SessionUpdate::ToolCall(acp::ToolCall {
1610 id: acp::ToolCallId("test".into()),
1611 title: "Label".into(),
1612 kind: acp::ToolKind::Edit,
1613 status: acp::ToolCallStatus::Completed,
1614 content: vec![acp::ToolCallContent::Diff {
1615 diff: acp::Diff {
1616 path: "/test/test.txt".into(),
1617 old_text: None,
1618 new_text: "foo".into(),
1619 },
1620 }],
1621 locations: vec![],
1622 raw_input: None,
1623 }),
1624 cx,
1625 )
1626 })
1627 .unwrap()
1628 .unwrap();
1629 Ok(())
1630 }
1631 .boxed_local()
1632 }
1633 }));
1634
1635 let thread = connection
1636 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1637 .await
1638 .unwrap();
1639 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1640 .await
1641 .unwrap();
1642
1643 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1644 }
1645
1646 async fn run_until_first_tool_call(
1647 thread: &Entity<AcpThread>,
1648 cx: &mut TestAppContext,
1649 ) -> usize {
1650 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1651
1652 let subscription = cx.update(|cx| {
1653 cx.subscribe(thread, move |thread, _, cx| {
1654 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1655 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1656 return tx.try_send(ix).unwrap();
1657 }
1658 }
1659 })
1660 });
1661
1662 select! {
1663 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1664 panic!("Timeout waiting for tool call")
1665 }
1666 ix = rx.next().fuse() => {
1667 drop(subscription);
1668 ix.unwrap()
1669 }
1670 }
1671 }
1672
1673 #[derive(Clone, Default)]
1674 struct FakeAgentConnection {
1675 auth_methods: Vec<acp::AuthMethod>,
1676 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1677 on_user_message: Option<
1678 Rc<
1679 dyn Fn(
1680 acp::PromptRequest,
1681 WeakEntity<AcpThread>,
1682 AsyncApp,
1683 ) -> LocalBoxFuture<'static, Result<()>>
1684 + 'static,
1685 >,
1686 >,
1687 }
1688
1689 impl FakeAgentConnection {
1690 fn new() -> Self {
1691 Self {
1692 auth_methods: Vec::new(),
1693 on_user_message: None,
1694 sessions: Arc::default(),
1695 }
1696 }
1697
1698 #[expect(unused)]
1699 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1700 self.auth_methods = auth_methods;
1701 self
1702 }
1703
1704 fn on_user_message(
1705 mut self,
1706 handler: impl Fn(
1707 acp::PromptRequest,
1708 WeakEntity<AcpThread>,
1709 AsyncApp,
1710 ) -> LocalBoxFuture<'static, Result<()>>
1711 + 'static,
1712 ) -> Self {
1713 self.on_user_message.replace(Rc::new(handler));
1714 self
1715 }
1716 }
1717
1718 impl AgentConnection for FakeAgentConnection {
1719 fn auth_methods(&self) -> &[acp::AuthMethod] {
1720 &self.auth_methods
1721 }
1722
1723 fn new_thread(
1724 self: Rc<Self>,
1725 project: Entity<Project>,
1726 _cwd: &Path,
1727 cx: &mut gpui::AsyncApp,
1728 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1729 let session_id = acp::SessionId(
1730 rand::thread_rng()
1731 .sample_iter(&rand::distributions::Alphanumeric)
1732 .take(7)
1733 .map(char::from)
1734 .collect::<String>()
1735 .into(),
1736 );
1737 let thread = cx
1738 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1739 .unwrap();
1740 self.sessions.lock().insert(session_id, thread.downgrade());
1741 Task::ready(Ok(thread))
1742 }
1743
1744 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1745 if self.auth_methods().iter().any(|m| m.id == method) {
1746 Task::ready(Ok(()))
1747 } else {
1748 Task::ready(Err(anyhow!("Invalid Auth Method")))
1749 }
1750 }
1751
1752 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
1753 let sessions = self.sessions.lock();
1754 let thread = sessions.get(¶ms.session_id).unwrap();
1755 if let Some(handler) = &self.on_user_message {
1756 let handler = handler.clone();
1757 let thread = thread.clone();
1758 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1759 } else {
1760 Task::ready(Ok(()))
1761 }
1762 }
1763
1764 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1765 let sessions = self.sessions.lock();
1766 let thread = sessions.get(&session_id).unwrap().clone();
1767
1768 cx.spawn(async move |cx| {
1769 thread
1770 .update(cx, |thread, cx| thread.cancel(cx))
1771 .unwrap()
1772 .await
1773 })
1774 .detach();
1775 }
1776 }
1777}