1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::process::ExitStatus;
22use std::rc::Rc;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
52 }
53}
54
/// A borrowed file path that is rendered and parsed as an `@file:` mention link.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats the mention as the markdown link `[@<file name>](@file:<path>)`.
    ///
    /// A path without a final component is shown with an empty file name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        let target = self.0.display();
        write!(f, "[@{}]({}{})", file_name.display(), Self::PREFIX, target)
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(chunk.into(), language_registry, cx),
114 }
115 }
116
117 fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::Message { block } => block.to_markdown(cx).to_string(),
120 Self::Thought { block } => {
121 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
128pub enum AgentThreadEntry {
129 UserMessage(UserMessage),
130 AssistantMessage(AssistantMessage),
131 ToolCall(ToolCall),
132}
133
134impl AgentThreadEntry {
135 fn to_markdown(&self, cx: &App) -> String {
136 match self {
137 Self::UserMessage(message) => message.to_markdown(cx),
138 Self::AssistantMessage(message) => message.to_markdown(cx),
139 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
140 }
141 }
142
143 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.diffs())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
152 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
153 Some(locations)
154 } else {
155 None
156 }
157 }
158}
159
160#[derive(Debug)]
161pub struct ToolCall {
162 pub id: acp::ToolCallId,
163 pub label: Entity<Markdown>,
164 pub kind: acp::ToolKind,
165 pub content: Vec<ToolCallContent>,
166 pub status: ToolCallStatus,
167 pub locations: Vec<acp::ToolCallLocation>,
168 pub raw_input: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 }
198 }
199
200 fn update(
201 &mut self,
202 fields: acp::ToolCallUpdateFields,
203 language_registry: Arc<LanguageRegistry>,
204 cx: &mut App,
205 ) {
206 let acp::ToolCallUpdateFields {
207 kind,
208 status,
209 title,
210 content,
211 locations,
212 raw_input,
213 } = fields;
214
215 if let Some(kind) = kind {
216 self.kind = kind;
217 }
218
219 if let Some(status) = status {
220 self.status = ToolCallStatus::Allowed { status };
221 }
222
223 if let Some(title) = title {
224 self.label.update(cx, |label, cx| {
225 label.replace(title, cx);
226 });
227 }
228
229 if let Some(content) = content {
230 self.content = content
231 .into_iter()
232 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
233 .collect();
234 }
235
236 if let Some(locations) = locations {
237 self.locations = locations;
238 }
239
240 if let Some(raw_input) = raw_input {
241 self.raw_input = Some(raw_input);
242 }
243 }
244
245 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
246 self.content.iter().filter_map(|content| match content {
247 ToolCallContent::ContentBlock { .. } => None,
248 ToolCallContent::Diff { diff } => Some(diff),
249 })
250 }
251
252 fn to_markdown(&self, cx: &App) -> String {
253 let mut markdown = format!(
254 "**Tool Call: {}**\nStatus: {}\n\n",
255 self.label.read(cx).source(),
256 self.status
257 );
258 for content in &self.content {
259 markdown.push_str(content.to_markdown(cx).as_str());
260 markdown.push_str("\n\n");
261 }
262 markdown
263 }
264}
265
266#[derive(Debug)]
267pub enum ToolCallStatus {
268 WaitingForConfirmation {
269 options: Vec<acp::PermissionOption>,
270 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
271 },
272 Allowed {
273 status: acp::ToolCallStatus,
274 },
275 Rejected,
276 Canceled,
277}
278
279impl Display for ToolCallStatus {
280 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281 write!(
282 f,
283 "{}",
284 match self {
285 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
286 ToolCallStatus::Allowed { status } => match status {
287 acp::ToolCallStatus::Pending => "Pending",
288 acp::ToolCallStatus::InProgress => "In Progress",
289 acp::ToolCallStatus::Completed => "Completed",
290 acp::ToolCallStatus::Failed => "Failed",
291 },
292 ToolCallStatus::Rejected => "Rejected",
293 ToolCallStatus::Canceled => "Canceled",
294 }
295 )
296 }
297}
298
299#[derive(Debug, PartialEq, Clone)]
300pub enum ContentBlock {
301 Empty,
302 Markdown { markdown: Entity<Markdown> },
303}
304
305impl ContentBlock {
306 pub fn new(
307 block: acp::ContentBlock,
308 language_registry: &Arc<LanguageRegistry>,
309 cx: &mut App,
310 ) -> Self {
311 let mut this = Self::Empty;
312 this.append(block, language_registry, cx);
313 this
314 }
315
316 pub fn new_combined(
317 blocks: impl IntoIterator<Item = acp::ContentBlock>,
318 language_registry: Arc<LanguageRegistry>,
319 cx: &mut App,
320 ) -> Self {
321 let mut this = Self::Empty;
322 for block in blocks {
323 this.append(block, &language_registry, cx);
324 }
325 this
326 }
327
328 pub fn append(
329 &mut self,
330 block: acp::ContentBlock,
331 language_registry: &Arc<LanguageRegistry>,
332 cx: &mut App,
333 ) {
334 let new_content = match block {
335 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
336 acp::ContentBlock::ResourceLink(resource_link) => {
337 if let Some(path) = resource_link.uri.strip_prefix("file://") {
338 format!("{}", MentionPath(path.as_ref()))
339 } else {
340 resource_link.uri.clone()
341 }
342 }
343 acp::ContentBlock::Image(_)
344 | acp::ContentBlock::Audio(_)
345 | acp::ContentBlock::Resource(_) => String::new(),
346 };
347
348 match self {
349 ContentBlock::Empty => {
350 *self = ContentBlock::Markdown {
351 markdown: cx.new(|cx| {
352 Markdown::new(
353 new_content.into(),
354 Some(language_registry.clone()),
355 None,
356 cx,
357 )
358 }),
359 };
360 }
361 ContentBlock::Markdown { markdown } => {
362 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
363 }
364 }
365 }
366
367 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
368 match self {
369 ContentBlock::Empty => "",
370 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
371 }
372 }
373
374 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
375 match self {
376 ContentBlock::Empty => None,
377 ContentBlock::Markdown { markdown } => Some(markdown),
378 }
379 }
380}
381
382#[derive(Debug)]
383pub enum ToolCallContent {
384 ContentBlock { content: ContentBlock },
385 Diff { diff: Diff },
386}
387
388impl ToolCallContent {
389 pub fn from_acp(
390 content: acp::ToolCallContent,
391 language_registry: Arc<LanguageRegistry>,
392 cx: &mut App,
393 ) -> Self {
394 match content {
395 acp::ToolCallContent::Content { content } => Self::ContentBlock {
396 content: ContentBlock::new(content, &language_registry, cx),
397 },
398 acp::ToolCallContent::Diff { diff } => Self::Diff {
399 diff: Diff::from_acp(diff, language_registry, cx),
400 },
401 }
402 }
403
404 pub fn to_markdown(&self, cx: &App) -> String {
405 match self {
406 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
407 Self::Diff { diff } => diff.to_markdown(cx),
408 }
409 }
410}
411
412#[derive(Debug)]
413pub struct Diff {
414 pub multibuffer: Entity<MultiBuffer>,
415 pub path: PathBuf,
416 _task: Task<Result<()>>,
417}
418
419impl Diff {
420 pub fn from_acp(
421 diff: acp::Diff,
422 language_registry: Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let acp::Diff {
426 path,
427 old_text,
428 new_text,
429 } = diff;
430
431 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
432
433 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
434 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
435 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
436 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
437
438 let task = cx.spawn({
439 let multibuffer = multibuffer.clone();
440 let path = path.clone();
441 async move |cx| {
442 let language = language_registry
443 .language_for_file_path(&path)
444 .await
445 .log_err();
446
447 new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
448
449 let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
450 buffer.set_language(language, cx);
451 buffer.snapshot()
452 })?;
453
454 buffer_diff
455 .update(cx, |diff, cx| {
456 diff.set_base_text(
457 old_buffer_snapshot,
458 Some(language_registry),
459 new_buffer_snapshot,
460 cx,
461 )
462 })?
463 .await?;
464
465 multibuffer
466 .update(cx, |multibuffer, cx| {
467 let hunk_ranges = {
468 let buffer = new_buffer.read(cx);
469 let diff = buffer_diff.read(cx);
470 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
471 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
472 .collect::<Vec<_>>()
473 };
474
475 multibuffer.set_excerpts_for_path(
476 PathKey::for_buffer(&new_buffer, cx),
477 new_buffer.clone(),
478 hunk_ranges,
479 editor::DEFAULT_MULTIBUFFER_CONTEXT,
480 cx,
481 );
482 multibuffer.add_diff(buffer_diff, cx);
483 })
484 .log_err();
485
486 anyhow::Ok(())
487 }
488 });
489
490 Self {
491 multibuffer,
492 path,
493 _task: task,
494 }
495 }
496
497 fn to_markdown(&self, cx: &App) -> String {
498 let buffer_text = self
499 .multibuffer
500 .read(cx)
501 .all_buffers()
502 .iter()
503 .map(|buffer| buffer.read(cx).text())
504 .join("\n");
505 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
506 }
507}
508
509#[derive(Debug, Default)]
510pub struct Plan {
511 pub entries: Vec<PlanEntry>,
512}
513
514#[derive(Debug)]
515pub struct PlanStats<'a> {
516 pub in_progress_entry: Option<&'a PlanEntry>,
517 pub pending: u32,
518 pub completed: u32,
519}
520
521impl Plan {
522 pub fn is_empty(&self) -> bool {
523 self.entries.is_empty()
524 }
525
526 pub fn stats(&self) -> PlanStats<'_> {
527 let mut stats = PlanStats {
528 in_progress_entry: None,
529 pending: 0,
530 completed: 0,
531 };
532
533 for entry in &self.entries {
534 match &entry.status {
535 acp::PlanEntryStatus::Pending => {
536 stats.pending += 1;
537 }
538 acp::PlanEntryStatus::InProgress => {
539 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
540 }
541 acp::PlanEntryStatus::Completed => {
542 stats.completed += 1;
543 }
544 }
545 }
546
547 stats
548 }
549}
550
551#[derive(Debug)]
552pub struct PlanEntry {
553 pub content: Entity<Markdown>,
554 pub priority: acp::PlanEntryPriority,
555 pub status: acp::PlanEntryStatus,
556}
557
558impl PlanEntry {
559 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
560 Self {
561 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
562 priority: entry.priority,
563 status: entry.status,
564 }
565 }
566}
567
/// State of one agent session: its transcript, plan, and the handles needed to
/// talk to the agent and to the project.
pub struct AcpThread {
    /// Title given at construction.
    title: SharedString,
    /// Transcript entries in the order they were pushed.
    entries: Vec<AgentThreadEntry>,
    /// Plan last applied through `update_plan`.
    plan: Plan,
    project: Entity<Project>,
    /// Log that is told about buffer reads and edits made through this thread.
    action_log: Entity<ActionLog>,
    /// Snapshot of each buffer as last returned by `read_text_file`; used as
    /// the diff base in `write_text_file`.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Task driving the current prompt; `None` when the thread is idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
579
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
    /// A tool call is awaiting a permission choice.
    ToolAuthorizationRequired,
    /// A prompt finished without reporting an error.
    Stopped,
    /// A prompt finished with an error.
    Error,
    /// Emitted by `emit_server_exited` with the given exit status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
590
/// Coarse status of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call in the current turn awaits confirmation.
    WaitingForToolConfirmation,
    /// A send task is running and nothing awaits confirmation.
    Generating,
}
597
598#[derive(Debug, Clone)]
599pub enum LoadError {
600 Unsupported {
601 error_message: SharedString,
602 upgrade_message: SharedString,
603 upgrade_command: String,
604 },
605 Exited(i32),
606 Other(SharedString),
607}
608
609impl Display for LoadError {
610 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
611 match self {
612 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
613 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
614 LoadError::Other(msg) => write!(f, "{}", msg),
615 }
616 }
617}
618
619impl Error for LoadError {}
620
621impl AcpThread {
622 pub fn new(
623 title: impl Into<SharedString>,
624 connection: Rc<dyn AgentConnection>,
625 project: Entity<Project>,
626 session_id: acp::SessionId,
627 cx: &mut Context<Self>,
628 ) -> Self {
629 let action_log = cx.new(|_| ActionLog::new(project.clone()));
630
631 Self {
632 action_log,
633 shared_buffers: Default::default(),
634 entries: Default::default(),
635 plan: Default::default(),
636 title: title.into(),
637 project,
638 send_task: None,
639 connection,
640 session_id,
641 }
642 }
643
    /// The action log created for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
647
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
651
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
655
    /// The transcript entries, oldest first.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
659
    /// The ACP session id this thread was created with.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
663
664 pub fn status(&self) -> ThreadStatus {
665 if self.send_task.is_some() {
666 if self.waiting_for_tool_confirmation() {
667 ThreadStatus::WaitingForToolConfirmation
668 } else {
669 ThreadStatus::Generating
670 }
671 } else {
672 ThreadStatus::Idle
673 }
674 }
675
676 pub fn has_pending_edit_tool_calls(&self) -> bool {
677 for entry in self.entries.iter().rev() {
678 match entry {
679 AgentThreadEntry::UserMessage(_) => return false,
680 AgentThreadEntry::ToolCall(
681 call @ ToolCall {
682 status:
683 ToolCallStatus::Allowed {
684 status:
685 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
686 },
687 ..
688 },
689 ) if call.diffs().next().is_some() => {
690 return true;
691 }
692 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
693 }
694 }
695
696 false
697 }
698
699 pub fn used_tools_since_last_user_message(&self) -> bool {
700 for entry in self.entries.iter().rev() {
701 match entry {
702 AgentThreadEntry::UserMessage(..) => return false,
703 AgentThreadEntry::AssistantMessage(..) => continue,
704 AgentThreadEntry::ToolCall(..) => return true,
705 }
706 }
707
708 false
709 }
710
711 pub fn handle_session_update(
712 &mut self,
713 update: acp::SessionUpdate,
714 cx: &mut Context<Self>,
715 ) -> Result<()> {
716 match update {
717 acp::SessionUpdate::UserMessageChunk { content } => {
718 self.push_user_content_block(content, cx);
719 }
720 acp::SessionUpdate::AgentMessageChunk { content } => {
721 self.push_assistant_content_block(content, false, cx);
722 }
723 acp::SessionUpdate::AgentThoughtChunk { content } => {
724 self.push_assistant_content_block(content, true, cx);
725 }
726 acp::SessionUpdate::ToolCall(tool_call) => {
727 self.upsert_tool_call(tool_call, cx);
728 }
729 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
730 self.update_tool_call(tool_call_update, cx)?;
731 }
732 acp::SessionUpdate::Plan(plan) => {
733 self.update_plan(plan, cx);
734 }
735 }
736 Ok(())
737 }
738
739 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
740 let language_registry = self.project.read(cx).languages().clone();
741 let entries_len = self.entries.len();
742
743 if let Some(last_entry) = self.entries.last_mut()
744 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
745 {
746 content.append(chunk, &language_registry, cx);
747 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
748 } else {
749 let content = ContentBlock::new(chunk, &language_registry, cx);
750 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
751 }
752 }
753
754 pub fn push_assistant_content_block(
755 &mut self,
756 chunk: acp::ContentBlock,
757 is_thought: bool,
758 cx: &mut Context<Self>,
759 ) {
760 let language_registry = self.project.read(cx).languages().clone();
761 let entries_len = self.entries.len();
762 if let Some(last_entry) = self.entries.last_mut()
763 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
764 {
765 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
766 match (chunks.last_mut(), is_thought) {
767 (Some(AssistantMessageChunk::Message { block }), false)
768 | (Some(AssistantMessageChunk::Thought { block }), true) => {
769 block.append(chunk, &language_registry, cx)
770 }
771 _ => {
772 let block = ContentBlock::new(chunk, &language_registry, cx);
773 if is_thought {
774 chunks.push(AssistantMessageChunk::Thought { block })
775 } else {
776 chunks.push(AssistantMessageChunk::Message { block })
777 }
778 }
779 }
780 } else {
781 let block = ContentBlock::new(chunk, &language_registry, cx);
782 let chunk = if is_thought {
783 AssistantMessageChunk::Thought { block }
784 } else {
785 AssistantMessageChunk::Message { block }
786 };
787
788 self.push_entry(
789 AgentThreadEntry::AssistantMessage(AssistantMessage {
790 chunks: vec![chunk],
791 }),
792 cx,
793 );
794 }
795 }
796
    /// Appends `entry` to the transcript and emits [`AcpThreadEvent::NewEntry`].
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
801
802 pub fn update_tool_call(
803 &mut self,
804 update: acp::ToolCallUpdate,
805 cx: &mut Context<Self>,
806 ) -> Result<()> {
807 let languages = self.project.read(cx).languages().clone();
808
809 let (ix, current_call) = self
810 .tool_call_mut(&update.id)
811 .context("Tool call not found")?;
812 current_call.update(update.fields, languages, cx);
813
814 cx.emit(AcpThreadEvent::EntryUpdated(ix));
815
816 Ok(())
817 }
818
819 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
820 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
821 let status = ToolCallStatus::Allowed {
822 status: tool_call.status,
823 };
824 self.upsert_tool_call_inner(tool_call, status, cx)
825 }
826
827 pub fn upsert_tool_call_inner(
828 &mut self,
829 tool_call: acp::ToolCall,
830 status: ToolCallStatus,
831 cx: &mut Context<Self>,
832 ) {
833 let language_registry = self.project.read(cx).languages().clone();
834 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
835
836 let location = call.locations.last().cloned();
837
838 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
839 *current_call = call;
840
841 cx.emit(AcpThreadEvent::EntryUpdated(ix));
842 } else {
843 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
844 }
845
846 if let Some(location) = location {
847 self.set_project_location(location, cx)
848 }
849 }
850
851 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
852 // The tool call we are looking for is typically the last one, or very close to the end.
853 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
854 self.entries
855 .iter_mut()
856 .enumerate()
857 .rev()
858 .find_map(|(index, tool_call)| {
859 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
860 && &tool_call.id == id
861 {
862 Some((index, tool_call))
863 } else {
864 None
865 }
866 })
867 }
868
869 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
870 self.project.update(cx, |project, cx| {
871 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
872 return;
873 };
874 let buffer = project.open_buffer(path, cx);
875 cx.spawn(async move |project, cx| {
876 let buffer = buffer.await?;
877
878 project.update(cx, |project, cx| {
879 let position = if let Some(line) = location.line {
880 let snapshot = buffer.read(cx).snapshot();
881 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
882 snapshot.anchor_before(point)
883 } else {
884 Anchor::MIN
885 };
886
887 project.set_agent_location(
888 Some(AgentLocation {
889 buffer: buffer.downgrade(),
890 position,
891 }),
892 cx,
893 );
894 })
895 })
896 .detach_and_log_err(cx);
897 });
898 }
899
900 pub fn request_tool_call_permission(
901 &mut self,
902 tool_call: acp::ToolCall,
903 options: Vec<acp::PermissionOption>,
904 cx: &mut Context<Self>,
905 ) -> oneshot::Receiver<acp::PermissionOptionId> {
906 let (tx, rx) = oneshot::channel();
907
908 let status = ToolCallStatus::WaitingForConfirmation {
909 options,
910 respond_tx: tx,
911 };
912
913 self.upsert_tool_call_inner(tool_call, status, cx);
914 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
915 rx
916 }
917
918 pub fn authorize_tool_call(
919 &mut self,
920 id: acp::ToolCallId,
921 option_id: acp::PermissionOptionId,
922 option_kind: acp::PermissionOptionKind,
923 cx: &mut Context<Self>,
924 ) {
925 let Some((ix, call)) = self.tool_call_mut(&id) else {
926 return;
927 };
928
929 let new_status = match option_kind {
930 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
931 ToolCallStatus::Rejected
932 }
933 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
934 ToolCallStatus::Allowed {
935 status: acp::ToolCallStatus::InProgress,
936 }
937 }
938 };
939
940 let curr_status = mem::replace(&mut call.status, new_status);
941
942 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
943 respond_tx.send(option_id).log_err();
944 } else if cfg!(debug_assertions) {
945 panic!("tried to authorize an already authorized tool call");
946 }
947
948 cx.emit(AcpThreadEvent::EntryUpdated(ix));
949 }
950
951 /// Returns true if the last turn is awaiting tool authorization
952 pub fn waiting_for_tool_confirmation(&self) -> bool {
953 for entry in self.entries.iter().rev() {
954 match &entry {
955 AgentThreadEntry::ToolCall(call) => match call.status {
956 ToolCallStatus::WaitingForConfirmation { .. } => return true,
957 ToolCallStatus::Allowed { .. }
958 | ToolCallStatus::Rejected
959 | ToolCallStatus::Canceled => continue,
960 },
961 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
962 // Reached the beginning of the turn
963 return false;
964 }
965 }
966 }
967 false
968 }
969
    /// The thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
973
974 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
975 let new_entries_len = request.entries.len();
976 let mut new_entries = request.entries.into_iter();
977
978 // Reuse existing markdown to prevent flickering
979 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
980 let PlanEntry {
981 content,
982 priority,
983 status,
984 } = old;
985 content.update(cx, |old, cx| {
986 old.replace(new.content, cx);
987 });
988 *priority = new.priority;
989 *status = new.status;
990 }
991 for new in new_entries {
992 self.plan.entries.push(PlanEntry::from_acp(new, cx))
993 }
994 self.plan.entries.truncate(new_entries_len);
995
996 cx.notify();
997 }
998
999 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1000 self.plan
1001 .entries
1002 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1003 cx.notify();
1004 }
1005
1006 #[cfg(any(test, feature = "test-support"))]
1007 pub fn send_raw(
1008 &mut self,
1009 message: &str,
1010 cx: &mut Context<Self>,
1011 ) -> BoxFuture<'static, Result<()>> {
1012 self.send(
1013 vec![acp::ContentBlock::Text(acp::TextContent {
1014 text: message.to_string(),
1015 annotations: None,
1016 })],
1017 cx,
1018 )
1019 }
1020
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended to the transcript as a user entry, completed
    /// plan entries are dropped, and any turn already in flight is cancelled
    /// before the prompt is issued on the connection.
    ///
    /// The returned future resolves once the prompt has finished: it emits
    /// [`AcpThreadEvent::Error`] and yields the error when the prompt failed,
    /// otherwise it emits [`AcpThreadEvent::Stopped`] and yields `Ok(())`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` carries the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        // Takes the previous send task (if any); it is awaited inside the new one.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;
                tx.send(result).log_err();
                // Clear the stored task so `status` no longer reports a running turn.
                this.update(cx, |this, _cx| this.send_task.take())?;
                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Covers a successful prompt as well as `tx` being dropped without sending.
            _ => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1077
1078 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1079 let Some(send_task) = self.send_task.take() else {
1080 return Task::ready(());
1081 };
1082
1083 for entry in self.entries.iter_mut() {
1084 if let AgentThreadEntry::ToolCall(call) = entry {
1085 let cancel = matches!(
1086 call.status,
1087 ToolCallStatus::WaitingForConfirmation { .. }
1088 | ToolCallStatus::Allowed {
1089 status: acp::ToolCallStatus::InProgress
1090 }
1091 );
1092
1093 if cancel {
1094 call.status = ToolCallStatus::Canceled;
1095 }
1096 }
1097 }
1098
1099 self.connection.cancel(&self.session_id, cx);
1100
1101 // Wait for the send task to complete
1102 cx.foreground_executor().spawn(send_task)
1103 }
1104
1105 pub fn read_text_file(
1106 &self,
1107 path: PathBuf,
1108 line: Option<u32>,
1109 limit: Option<u32>,
1110 reuse_shared_snapshot: bool,
1111 cx: &mut Context<Self>,
1112 ) -> Task<Result<String>> {
1113 let project = self.project.clone();
1114 let action_log = self.action_log.clone();
1115 cx.spawn(async move |this, cx| {
1116 let load = project.update(cx, |project, cx| {
1117 let path = project
1118 .project_path_for_absolute_path(&path, cx)
1119 .context("invalid path")?;
1120 anyhow::Ok(project.open_buffer(path, cx))
1121 });
1122 let buffer = load??.await?;
1123
1124 let snapshot = if reuse_shared_snapshot {
1125 this.read_with(cx, |this, _| {
1126 this.shared_buffers.get(&buffer.clone()).cloned()
1127 })
1128 .log_err()
1129 .flatten()
1130 } else {
1131 None
1132 };
1133
1134 let snapshot = if let Some(snapshot) = snapshot {
1135 snapshot
1136 } else {
1137 action_log.update(cx, |action_log, cx| {
1138 action_log.buffer_read(buffer.clone(), cx);
1139 })?;
1140 project.update(cx, |project, cx| {
1141 let position = buffer
1142 .read(cx)
1143 .snapshot()
1144 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1145 project.set_agent_location(
1146 Some(AgentLocation {
1147 buffer: buffer.downgrade(),
1148 position,
1149 }),
1150 cx,
1151 );
1152 })?;
1153
1154 buffer.update(cx, |buffer, _| buffer.snapshot())?
1155 };
1156
1157 this.update(cx, |this, _| {
1158 let text = snapshot.text();
1159 this.shared_buffers.insert(buffer.clone(), snapshot);
1160 if line.is_none() && limit.is_none() {
1161 return Ok(text);
1162 }
1163 let limit = limit.unwrap_or(u32::MAX) as usize;
1164 let Some(line) = line else {
1165 return Ok(text.lines().take(limit).collect::<String>());
1166 };
1167
1168 let count = text.lines().count();
1169 if count < line as usize {
1170 anyhow::bail!("There are only {} lines", count);
1171 }
1172 Ok(text
1173 .lines()
1174 .skip(line as usize + 1)
1175 .take(limit)
1176 .collect::<String>())
1177 })?
1178 })
1179 }
1180
    /// Replaces the contents of the project file at `path` with `content` and
    /// saves the buffer.
    ///
    /// The edits are computed as a text diff between `content` and the
    /// snapshot stored for this buffer in `shared_buffers`, falling back to
    /// the buffer's current snapshot when none is stored. The project's agent
    /// location is moved to the end of the last edit, or to the start of the
    /// buffer when there are no edits.
    ///
    /// # Errors
    ///
    /// Fails when `path` is not inside the project, when the buffer cannot be
    /// opened, or when saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // The diff is computed on the background executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // The action log is told about a read before the edit and
                // about the edit afterwards.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1248
1249 pub fn to_markdown(&self, cx: &App) -> String {
1250 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1251 }
1252
1253 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1254 cx.emit(AcpThreadEvent::ServerExited(status));
1255 }
1256}
1257
1258#[cfg(test)]
1259mod tests {
1260 use super::*;
1261 use anyhow::anyhow;
1262 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1263 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1264 use indoc::indoc;
1265 use project::FakeFs;
1266 use rand::Rng as _;
1267 use serde_json::json;
1268 use settings::SettingsStore;
1269 use smol::stream::StreamExt as _;
1270 use std::{cell::RefCell, rc::Rc, time::Duration};
1271
1272 use util::path;
1273
1274 fn init_test(cx: &mut TestAppContext) {
1275 env_logger::try_init().ok();
1276 cx.update(|cx| {
1277 let settings_store = SettingsStore::test(cx);
1278 cx.set_global(settings_store);
1279 Project::init_settings(cx);
1280 language::init(cx);
1281 });
1282 }
1283
1284 #[gpui::test]
1285 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1286 init_test(cx);
1287
1288 let fs = FakeFs::new(cx.executor());
1289 let project = Project::test(fs, [], cx).await;
1290 let connection = Rc::new(FakeAgentConnection::new());
1291 let thread = cx
1292 .spawn(async move |mut cx| {
1293 connection
1294 .new_thread(project, Path::new(path!("/test")), &mut cx)
1295 .await
1296 })
1297 .await
1298 .unwrap();
1299
1300 // Test creating a new user message
1301 thread.update(cx, |thread, cx| {
1302 thread.push_user_content_block(
1303 acp::ContentBlock::Text(acp::TextContent {
1304 annotations: None,
1305 text: "Hello, ".to_string(),
1306 }),
1307 cx,
1308 );
1309 });
1310
1311 thread.update(cx, |thread, cx| {
1312 assert_eq!(thread.entries.len(), 1);
1313 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1314 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1315 } else {
1316 panic!("Expected UserMessage");
1317 }
1318 });
1319
1320 // Test appending to existing user message
1321 thread.update(cx, |thread, cx| {
1322 thread.push_user_content_block(
1323 acp::ContentBlock::Text(acp::TextContent {
1324 annotations: None,
1325 text: "world!".to_string(),
1326 }),
1327 cx,
1328 );
1329 });
1330
1331 thread.update(cx, |thread, cx| {
1332 assert_eq!(thread.entries.len(), 1);
1333 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1334 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1335 } else {
1336 panic!("Expected UserMessage");
1337 }
1338 });
1339
1340 // Test creating new user message after assistant message
1341 thread.update(cx, |thread, cx| {
1342 thread.push_assistant_content_block(
1343 acp::ContentBlock::Text(acp::TextContent {
1344 annotations: None,
1345 text: "Assistant response".to_string(),
1346 }),
1347 false,
1348 cx,
1349 );
1350 });
1351
1352 thread.update(cx, |thread, cx| {
1353 thread.push_user_content_block(
1354 acp::ContentBlock::Text(acp::TextContent {
1355 annotations: None,
1356 text: "New user message".to_string(),
1357 }),
1358 cx,
1359 );
1360 });
1361
1362 thread.update(cx, |thread, cx| {
1363 assert_eq!(thread.entries.len(), 3);
1364 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1365 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1366 } else {
1367 panic!("Expected UserMessage at index 2");
1368 }
1369 });
1370 }
1371
1372 #[gpui::test]
1373 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1374 init_test(cx);
1375
1376 let fs = FakeFs::new(cx.executor());
1377 let project = Project::test(fs, [], cx).await;
1378 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1379 |_, thread, mut cx| {
1380 async move {
1381 thread.update(&mut cx, |thread, cx| {
1382 thread
1383 .handle_session_update(
1384 acp::SessionUpdate::AgentThoughtChunk {
1385 content: "Thinking ".into(),
1386 },
1387 cx,
1388 )
1389 .unwrap();
1390 thread
1391 .handle_session_update(
1392 acp::SessionUpdate::AgentThoughtChunk {
1393 content: "hard!".into(),
1394 },
1395 cx,
1396 )
1397 .unwrap();
1398 })?;
1399 Ok(acp::PromptResponse {
1400 stop_reason: acp::StopReason::EndTurn,
1401 })
1402 }
1403 .boxed_local()
1404 },
1405 ));
1406
1407 let thread = cx
1408 .spawn(async move |mut cx| {
1409 connection
1410 .new_thread(project, Path::new(path!("/test")), &mut cx)
1411 .await
1412 })
1413 .await
1414 .unwrap();
1415
1416 thread
1417 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1418 .await
1419 .unwrap();
1420
1421 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1422 assert_eq!(
1423 output,
1424 indoc! {r#"
1425 ## User
1426
1427 Hello from Zed!
1428
1429 ## Assistant
1430
1431 <thinking>
1432 Thinking hard!
1433 </thinking>
1434
1435 "#}
1436 );
1437 }
1438
1439 #[gpui::test]
1440 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1441 init_test(cx);
1442
1443 let fs = FakeFs::new(cx.executor());
1444 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1445 .await;
1446 let project = Project::test(fs.clone(), [], cx).await;
1447 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1448 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1449 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1450 move |_, thread, mut cx| {
1451 let read_file_tx = read_file_tx.clone();
1452 async move {
1453 let content = thread
1454 .update(&mut cx, |thread, cx| {
1455 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1456 })
1457 .unwrap()
1458 .await
1459 .unwrap();
1460 assert_eq!(content, "one\ntwo\nthree\n");
1461 read_file_tx.take().unwrap().send(()).unwrap();
1462 thread
1463 .update(&mut cx, |thread, cx| {
1464 thread.write_text_file(
1465 path!("/tmp/foo").into(),
1466 "one\ntwo\nthree\nfour\nfive\n".to_string(),
1467 cx,
1468 )
1469 })
1470 .unwrap()
1471 .await
1472 .unwrap();
1473 Ok(acp::PromptResponse {
1474 stop_reason: acp::StopReason::EndTurn,
1475 })
1476 }
1477 .boxed_local()
1478 },
1479 ));
1480
1481 let (worktree, pathbuf) = project
1482 .update(cx, |project, cx| {
1483 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1484 })
1485 .await
1486 .unwrap();
1487 let buffer = project
1488 .update(cx, |project, cx| {
1489 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1490 })
1491 .await
1492 .unwrap();
1493
1494 let thread = cx
1495 .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1496 .await
1497 .unwrap();
1498
1499 let request = thread.update(cx, |thread, cx| {
1500 thread.send_raw("Extend the count in /tmp/foo", cx)
1501 });
1502 read_file_rx.await.ok();
1503 buffer.update(cx, |buffer, cx| {
1504 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1505 });
1506 cx.run_until_parked();
1507 assert_eq!(
1508 buffer.read_with(cx, |buffer, _| buffer.text()),
1509 "zero\none\ntwo\nthree\nfour\nfive\n"
1510 );
1511 assert_eq!(
1512 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1513 "zero\none\ntwo\nthree\nfour\nfive\n"
1514 );
1515 request.await.unwrap();
1516 }
1517
1518 #[gpui::test]
1519 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1520 init_test(cx);
1521
1522 let fs = FakeFs::new(cx.executor());
1523 let project = Project::test(fs, [], cx).await;
1524 let id = acp::ToolCallId("test".into());
1525
1526 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1527 let id = id.clone();
1528 move |_, thread, mut cx| {
1529 let id = id.clone();
1530 async move {
1531 thread
1532 .update(&mut cx, |thread, cx| {
1533 thread.handle_session_update(
1534 acp::SessionUpdate::ToolCall(acp::ToolCall {
1535 id: id.clone(),
1536 title: "Label".into(),
1537 kind: acp::ToolKind::Fetch,
1538 status: acp::ToolCallStatus::InProgress,
1539 content: vec![],
1540 locations: vec![],
1541 raw_input: None,
1542 }),
1543 cx,
1544 )
1545 })
1546 .unwrap()
1547 .unwrap();
1548 Ok(acp::PromptResponse {
1549 stop_reason: acp::StopReason::EndTurn,
1550 })
1551 }
1552 .boxed_local()
1553 }
1554 }));
1555
1556 let thread = cx
1557 .spawn(async move |mut cx| {
1558 connection
1559 .new_thread(project, Path::new(path!("/test")), &mut cx)
1560 .await
1561 })
1562 .await
1563 .unwrap();
1564
1565 let request = thread.update(cx, |thread, cx| {
1566 thread.send_raw("Fetch https://example.com", cx)
1567 });
1568
1569 run_until_first_tool_call(&thread, cx).await;
1570
1571 thread.read_with(cx, |thread, _| {
1572 assert!(matches!(
1573 thread.entries[1],
1574 AgentThreadEntry::ToolCall(ToolCall {
1575 status: ToolCallStatus::Allowed {
1576 status: acp::ToolCallStatus::InProgress,
1577 ..
1578 },
1579 ..
1580 })
1581 ));
1582 });
1583
1584 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1585
1586 thread.read_with(cx, |thread, _| {
1587 assert!(matches!(
1588 &thread.entries[1],
1589 AgentThreadEntry::ToolCall(ToolCall {
1590 status: ToolCallStatus::Canceled,
1591 ..
1592 })
1593 ));
1594 });
1595
1596 thread
1597 .update(cx, |thread, cx| {
1598 thread.handle_session_update(
1599 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1600 id,
1601 fields: acp::ToolCallUpdateFields {
1602 status: Some(acp::ToolCallStatus::Completed),
1603 ..Default::default()
1604 },
1605 }),
1606 cx,
1607 )
1608 })
1609 .unwrap();
1610
1611 request.await.unwrap();
1612
1613 thread.read_with(cx, |thread, _| {
1614 assert!(matches!(
1615 thread.entries[1],
1616 AgentThreadEntry::ToolCall(ToolCall {
1617 status: ToolCallStatus::Allowed {
1618 status: acp::ToolCallStatus::Completed,
1619 ..
1620 },
1621 ..
1622 })
1623 ));
1624 });
1625 }
1626
1627 #[gpui::test]
1628 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1629 init_test(cx);
1630 let fs = FakeFs::new(cx.background_executor.clone());
1631 fs.insert_tree(path!("/test"), json!({})).await;
1632 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1633
1634 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1635 move |_, thread, mut cx| {
1636 async move {
1637 thread
1638 .update(&mut cx, |thread, cx| {
1639 thread.handle_session_update(
1640 acp::SessionUpdate::ToolCall(acp::ToolCall {
1641 id: acp::ToolCallId("test".into()),
1642 title: "Label".into(),
1643 kind: acp::ToolKind::Edit,
1644 status: acp::ToolCallStatus::Completed,
1645 content: vec![acp::ToolCallContent::Diff {
1646 diff: acp::Diff {
1647 path: "/test/test.txt".into(),
1648 old_text: None,
1649 new_text: "foo".into(),
1650 },
1651 }],
1652 locations: vec![],
1653 raw_input: None,
1654 }),
1655 cx,
1656 )
1657 })
1658 .unwrap()
1659 .unwrap();
1660 Ok(acp::PromptResponse {
1661 stop_reason: acp::StopReason::EndTurn,
1662 })
1663 }
1664 .boxed_local()
1665 }
1666 }));
1667
1668 let thread = connection
1669 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1670 .await
1671 .unwrap();
1672 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1673 .await
1674 .unwrap();
1675
1676 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1677 }
1678
1679 async fn run_until_first_tool_call(
1680 thread: &Entity<AcpThread>,
1681 cx: &mut TestAppContext,
1682 ) -> usize {
1683 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1684
1685 let subscription = cx.update(|cx| {
1686 cx.subscribe(thread, move |thread, _, cx| {
1687 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1688 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1689 return tx.try_send(ix).unwrap();
1690 }
1691 }
1692 })
1693 });
1694
1695 select! {
1696 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1697 panic!("Timeout waiting for tool call")
1698 }
1699 ix = rx.next().fuse() => {
1700 drop(subscription);
1701 ix.unwrap()
1702 }
1703 }
1704 }
1705
1706 #[derive(Clone, Default)]
1707 struct FakeAgentConnection {
1708 auth_methods: Vec<acp::AuthMethod>,
1709 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1710 on_user_message: Option<
1711 Rc<
1712 dyn Fn(
1713 acp::PromptRequest,
1714 WeakEntity<AcpThread>,
1715 AsyncApp,
1716 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1717 + 'static,
1718 >,
1719 >,
1720 }
1721
1722 impl FakeAgentConnection {
1723 fn new() -> Self {
1724 Self {
1725 auth_methods: Vec::new(),
1726 on_user_message: None,
1727 sessions: Arc::default(),
1728 }
1729 }
1730
1731 #[expect(unused)]
1732 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1733 self.auth_methods = auth_methods;
1734 self
1735 }
1736
1737 fn on_user_message(
1738 mut self,
1739 handler: impl Fn(
1740 acp::PromptRequest,
1741 WeakEntity<AcpThread>,
1742 AsyncApp,
1743 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1744 + 'static,
1745 ) -> Self {
1746 self.on_user_message.replace(Rc::new(handler));
1747 self
1748 }
1749 }
1750
1751 impl AgentConnection for FakeAgentConnection {
1752 fn auth_methods(&self) -> &[acp::AuthMethod] {
1753 &self.auth_methods
1754 }
1755
1756 fn new_thread(
1757 self: Rc<Self>,
1758 project: Entity<Project>,
1759 _cwd: &Path,
1760 cx: &mut gpui::AsyncApp,
1761 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1762 let session_id = acp::SessionId(
1763 rand::thread_rng()
1764 .sample_iter(&rand::distributions::Alphanumeric)
1765 .take(7)
1766 .map(char::from)
1767 .collect::<String>()
1768 .into(),
1769 );
1770 let thread = cx
1771 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1772 .unwrap();
1773 self.sessions.lock().insert(session_id, thread.downgrade());
1774 Task::ready(Ok(thread))
1775 }
1776
1777 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1778 if self.auth_methods().iter().any(|m| m.id == method) {
1779 Task::ready(Ok(()))
1780 } else {
1781 Task::ready(Err(anyhow!("Invalid Auth Method")))
1782 }
1783 }
1784
1785 fn prompt(
1786 &self,
1787 params: acp::PromptRequest,
1788 cx: &mut App,
1789 ) -> Task<gpui::Result<acp::PromptResponse>> {
1790 let sessions = self.sessions.lock();
1791 let thread = sessions.get(¶ms.session_id).unwrap();
1792 if let Some(handler) = &self.on_user_message {
1793 let handler = handler.clone();
1794 let thread = thread.clone();
1795 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1796 } else {
1797 Task::ready(Ok(acp::PromptResponse {
1798 stop_reason: acp::StopReason::EndTurn,
1799 }))
1800 }
1801 }
1802
1803 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1804 let sessions = self.sessions.lock();
1805 let thread = sessions.get(&session_id).unwrap().clone();
1806
1807 cx.spawn(async move |cx| {
1808 thread
1809 .update(cx, |thread, cx| thread.cancel(cx))
1810 .unwrap()
1811 .await
1812 })
1813 .detach();
1814 }
1815 }
1816}