1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::future::{Fuse, FusedFuture};
10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
11use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
12use itertools::Itertools;
13use language::{
14 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
15 text_diff,
16};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::process::ExitStatus;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub content: ContentBlock,
36}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
53 }
54}
55
/// A file mention inside message text.
///
/// It renders as a markdown link whose label is `@<file name>` and whose
/// target is the full path prefixed with `@file:`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix placed in front of the path in a rendered link target.
    const PREFIX: &'static str = "@file:";

    /// Wraps a borrowed path.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// Returns the wrapped path.
    pub fn path(&self) -> &Path {
        let Self(path) = *self;
        path
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A path with no final component (e.g. `/`) gets an empty label.
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
87
88#[derive(Debug, PartialEq)]
89pub struct AssistantMessage {
90 pub chunks: Vec<AssistantMessageChunk>,
91}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub enum AssistantMessageChunk {
107 Message { block: ContentBlock },
108 Thought { block: ContentBlock },
109}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
128#[derive(Debug)]
129pub enum AgentThreadEntry {
130 UserMessage(UserMessage),
131 AssistantMessage(AssistantMessage),
132 ToolCall(ToolCall),
133}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
145 if let AgentThreadEntry::ToolCall(call) = self {
146 itertools::Either::Left(call.diffs())
147 } else {
148 itertools::Either::Right(std::iter::empty())
149 }
150 }
151
152 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
153 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
154 Some(locations)
155 } else {
156 None
157 }
158 }
159}
160
161#[derive(Debug)]
162pub struct ToolCall {
163 pub id: acp::ToolCallId,
164 pub label: Entity<Markdown>,
165 pub kind: acp::ToolKind,
166 pub content: Vec<ToolCallContent>,
167 pub status: ToolCallStatus,
168 pub locations: Vec<acp::ToolCallLocation>,
169 pub raw_input: Option<serde_json::Value>,
170 pub raw_output: Option<serde_json::Value>,
171}
172
173impl ToolCall {
174 fn from_acp(
175 tool_call: acp::ToolCall,
176 status: ToolCallStatus,
177 language_registry: Arc<LanguageRegistry>,
178 cx: &mut App,
179 ) -> Self {
180 Self {
181 id: tool_call.id,
182 label: cx.new(|cx| {
183 Markdown::new(
184 tool_call.title.into(),
185 Some(language_registry.clone()),
186 None,
187 cx,
188 )
189 }),
190 kind: tool_call.kind,
191 content: tool_call
192 .content
193 .into_iter()
194 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
195 .collect(),
196 locations: tool_call.locations,
197 status,
198 raw_input: tool_call.raw_input,
199 raw_output: tool_call.raw_output,
200 }
201 }
202
203 fn update(
204 &mut self,
205 fields: acp::ToolCallUpdateFields,
206 language_registry: Arc<LanguageRegistry>,
207 cx: &mut App,
208 ) {
209 let acp::ToolCallUpdateFields {
210 kind,
211 status,
212 title,
213 content,
214 locations,
215 raw_input,
216 raw_output,
217 } = fields;
218
219 if let Some(kind) = kind {
220 self.kind = kind;
221 }
222
223 if let Some(status) = status {
224 self.status = ToolCallStatus::Allowed { status };
225 }
226
227 if let Some(title) = title {
228 self.label.update(cx, |label, cx| {
229 label.replace(title, cx);
230 });
231 }
232
233 if let Some(content) = content {
234 self.content = content
235 .into_iter()
236 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
237 .collect();
238 }
239
240 if let Some(locations) = locations {
241 self.locations = locations;
242 }
243
244 if let Some(raw_input) = raw_input {
245 self.raw_input = Some(raw_input);
246 }
247
248 if let Some(raw_output) = raw_output {
249 self.raw_output = Some(raw_output);
250 }
251 }
252
253 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
254 self.content.iter().filter_map(|content| match content {
255 ToolCallContent::ContentBlock { .. } => None,
256 ToolCallContent::Diff { diff } => Some(diff),
257 })
258 }
259
260 fn to_markdown(&self, cx: &App) -> String {
261 let mut markdown = format!(
262 "**Tool Call: {}**\nStatus: {}\n\n",
263 self.label.read(cx).source(),
264 self.status
265 );
266 for content in &self.content {
267 markdown.push_str(content.to_markdown(cx).as_str());
268 markdown.push_str("\n\n");
269 }
270 markdown
271 }
272}
273
274#[derive(Debug)]
275pub enum ToolCallStatus {
276 WaitingForConfirmation {
277 options: Vec<acp::PermissionOption>,
278 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
279 },
280 Allowed {
281 status: acp::ToolCallStatus,
282 },
283 Rejected,
284 Canceled,
285}
286
287impl Display for ToolCallStatus {
288 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
289 write!(
290 f,
291 "{}",
292 match self {
293 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
294 ToolCallStatus::Allowed { status } => match status {
295 acp::ToolCallStatus::Pending => "Pending",
296 acp::ToolCallStatus::InProgress => "In Progress",
297 acp::ToolCallStatus::Completed => "Completed",
298 acp::ToolCallStatus::Failed => "Failed",
299 },
300 ToolCallStatus::Rejected => "Rejected",
301 ToolCallStatus::Canceled => "Canceled",
302 }
303 )
304 }
305}
306
307#[derive(Debug, PartialEq, Clone)]
308pub enum ContentBlock {
309 Empty,
310 Markdown { markdown: Entity<Markdown> },
311}
312
313impl ContentBlock {
314 pub fn new(
315 block: acp::ContentBlock,
316 language_registry: &Arc<LanguageRegistry>,
317 cx: &mut App,
318 ) -> Self {
319 let mut this = Self::Empty;
320 this.append(block, language_registry, cx);
321 this
322 }
323
324 pub fn new_combined(
325 blocks: impl IntoIterator<Item = acp::ContentBlock>,
326 language_registry: Arc<LanguageRegistry>,
327 cx: &mut App,
328 ) -> Self {
329 let mut this = Self::Empty;
330 for block in blocks {
331 this.append(block, &language_registry, cx);
332 }
333 this
334 }
335
336 pub fn append(
337 &mut self,
338 block: acp::ContentBlock,
339 language_registry: &Arc<LanguageRegistry>,
340 cx: &mut App,
341 ) {
342 let new_content = match block {
343 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
344 acp::ContentBlock::ResourceLink(resource_link) => {
345 if let Some(path) = resource_link.uri.strip_prefix("file://") {
346 format!("{}", MentionPath(path.as_ref()))
347 } else {
348 resource_link.uri.clone()
349 }
350 }
351 acp::ContentBlock::Image(_)
352 | acp::ContentBlock::Audio(_)
353 | acp::ContentBlock::Resource(_) => String::new(),
354 };
355
356 match self {
357 ContentBlock::Empty => {
358 *self = ContentBlock::Markdown {
359 markdown: cx.new(|cx| {
360 Markdown::new(
361 new_content.into(),
362 Some(language_registry.clone()),
363 None,
364 cx,
365 )
366 }),
367 };
368 }
369 ContentBlock::Markdown { markdown } => {
370 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
371 }
372 }
373 }
374
375 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
376 match self {
377 ContentBlock::Empty => "",
378 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
379 }
380 }
381
382 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
383 match self {
384 ContentBlock::Empty => None,
385 ContentBlock::Markdown { markdown } => Some(markdown),
386 }
387 }
388}
389
390#[derive(Debug)]
391pub enum ToolCallContent {
392 ContentBlock { content: ContentBlock },
393 Diff { diff: Diff },
394}
395
396impl ToolCallContent {
397 pub fn from_acp(
398 content: acp::ToolCallContent,
399 language_registry: Arc<LanguageRegistry>,
400 cx: &mut App,
401 ) -> Self {
402 match content {
403 acp::ToolCallContent::Content { content } => Self::ContentBlock {
404 content: ContentBlock::new(content, &language_registry, cx),
405 },
406 acp::ToolCallContent::Diff { diff } => Self::Diff {
407 diff: Diff::from_acp(diff, language_registry, cx),
408 },
409 }
410 }
411
412 pub fn to_markdown(&self, cx: &App) -> String {
413 match self {
414 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
415 Self::Diff { diff } => diff.to_markdown(cx),
416 }
417 }
418}
419
420#[derive(Debug)]
421pub struct Diff {
422 pub multibuffer: Entity<MultiBuffer>,
423 pub path: PathBuf,
424 _task: Task<Result<()>>,
425}
426
427impl Diff {
428 pub fn from_acp(
429 diff: acp::Diff,
430 language_registry: Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) -> Self {
433 let acp::Diff {
434 path,
435 old_text,
436 new_text,
437 } = diff;
438
439 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
440
441 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
442 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
443 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
444 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
445
446 let task = cx.spawn({
447 let multibuffer = multibuffer.clone();
448 let path = path.clone();
449 async move |cx| {
450 let language = language_registry
451 .language_for_file_path(&path)
452 .await
453 .log_err();
454
455 new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
456
457 let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
458 buffer.set_language(language, cx);
459 buffer.snapshot()
460 })?;
461
462 buffer_diff
463 .update(cx, |diff, cx| {
464 diff.set_base_text(
465 old_buffer_snapshot,
466 Some(language_registry),
467 new_buffer_snapshot,
468 cx,
469 )
470 })?
471 .await?;
472
473 multibuffer
474 .update(cx, |multibuffer, cx| {
475 let hunk_ranges = {
476 let buffer = new_buffer.read(cx);
477 let diff = buffer_diff.read(cx);
478 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
479 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
480 .collect::<Vec<_>>()
481 };
482
483 multibuffer.set_excerpts_for_path(
484 PathKey::for_buffer(&new_buffer, cx),
485 new_buffer.clone(),
486 hunk_ranges,
487 editor::DEFAULT_MULTIBUFFER_CONTEXT,
488 cx,
489 );
490 multibuffer.add_diff(buffer_diff, cx);
491 })
492 .log_err();
493
494 anyhow::Ok(())
495 }
496 });
497
498 Self {
499 multibuffer,
500 path,
501 _task: task,
502 }
503 }
504
505 fn to_markdown(&self, cx: &App) -> String {
506 let buffer_text = self
507 .multibuffer
508 .read(cx)
509 .all_buffers()
510 .iter()
511 .map(|buffer| buffer.read(cx).text())
512 .join("\n");
513 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
514 }
515}
516
517#[derive(Debug, Default)]
518pub struct Plan {
519 pub entries: Vec<PlanEntry>,
520}
521
522#[derive(Debug)]
523pub struct PlanStats<'a> {
524 pub in_progress_entry: Option<&'a PlanEntry>,
525 pub pending: u32,
526 pub completed: u32,
527}
528
529impl Plan {
530 pub fn is_empty(&self) -> bool {
531 self.entries.is_empty()
532 }
533
534 pub fn stats(&self) -> PlanStats<'_> {
535 let mut stats = PlanStats {
536 in_progress_entry: None,
537 pending: 0,
538 completed: 0,
539 };
540
541 for entry in &self.entries {
542 match &entry.status {
543 acp::PlanEntryStatus::Pending => {
544 stats.pending += 1;
545 }
546 acp::PlanEntryStatus::InProgress => {
547 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
548 }
549 acp::PlanEntryStatus::Completed => {
550 stats.completed += 1;
551 }
552 }
553 }
554
555 stats
556 }
557}
558
559#[derive(Debug)]
560pub struct PlanEntry {
561 pub content: Entity<Markdown>,
562 pub priority: acp::PlanEntryPriority,
563 pub status: acp::PlanEntryStatus,
564}
565
566impl PlanEntry {
567 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
568 Self {
569 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
570 priority: entry.priority,
571 status: entry.status,
572 }
573 }
574}
575
576pub struct AcpThread {
577 title: SharedString,
578 entries: Vec<AgentThreadEntry>,
579 plan: Plan,
580 project: Entity<Project>,
581 action_log: Entity<ActionLog>,
582 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
583 send_task: Option<Fuse<Task<()>>>,
584 connection: Rc<dyn AgentConnection>,
585 session_id: acp::SessionId,
586}
587
588pub enum AcpThreadEvent {
589 NewEntry,
590 EntryUpdated(usize),
591 ToolAuthorizationRequired,
592 Stopped,
593 Error,
594 ServerExited(ExitStatus),
595}
596
597impl EventEmitter<AcpThreadEvent> for AcpThread {}
598
599#[derive(PartialEq, Eq)]
600pub enum ThreadStatus {
601 Idle,
602 WaitingForToolConfirmation,
603 Generating,
604}
605
606#[derive(Debug, Clone)]
607pub enum LoadError {
608 Unsupported {
609 error_message: SharedString,
610 upgrade_message: SharedString,
611 upgrade_command: String,
612 },
613 Exited(i32),
614 Other(SharedString),
615}
616
617impl Display for LoadError {
618 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
619 match self {
620 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
621 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
622 LoadError::Other(msg) => write!(f, "{}", msg),
623 }
624 }
625}
626
627impl Error for LoadError {}
628
629impl AcpThread {
630 pub fn new(
631 title: impl Into<SharedString>,
632 connection: Rc<dyn AgentConnection>,
633 project: Entity<Project>,
634 session_id: acp::SessionId,
635 cx: &mut Context<Self>,
636 ) -> Self {
637 let action_log = cx.new(|_| ActionLog::new(project.clone()));
638
639 Self {
640 action_log,
641 shared_buffers: Default::default(),
642 entries: Default::default(),
643 plan: Default::default(),
644 title: title.into(),
645 project,
646 send_task: None,
647 connection,
648 session_id,
649 }
650 }
651
652 pub fn action_log(&self) -> &Entity<ActionLog> {
653 &self.action_log
654 }
655
656 pub fn project(&self) -> &Entity<Project> {
657 &self.project
658 }
659
660 pub fn title(&self) -> SharedString {
661 self.title.clone()
662 }
663
664 pub fn entries(&self) -> &[AgentThreadEntry] {
665 &self.entries
666 }
667
668 pub fn session_id(&self) -> &acp::SessionId {
669 &self.session_id
670 }
671
672 pub fn status(&self) -> ThreadStatus {
673 if self
674 .send_task
675 .as_ref()
676 .map_or(false, |t| !t.is_terminated())
677 {
678 if self.waiting_for_tool_confirmation() {
679 ThreadStatus::WaitingForToolConfirmation
680 } else {
681 ThreadStatus::Generating
682 }
683 } else {
684 ThreadStatus::Idle
685 }
686 }
687
688 pub fn has_pending_edit_tool_calls(&self) -> bool {
689 for entry in self.entries.iter().rev() {
690 match entry {
691 AgentThreadEntry::UserMessage(_) => return false,
692 AgentThreadEntry::ToolCall(
693 call @ ToolCall {
694 status:
695 ToolCallStatus::Allowed {
696 status:
697 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
698 },
699 ..
700 },
701 ) if call.diffs().next().is_some() => {
702 return true;
703 }
704 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
705 }
706 }
707
708 false
709 }
710
711 pub fn used_tools_since_last_user_message(&self) -> bool {
712 for entry in self.entries.iter().rev() {
713 match entry {
714 AgentThreadEntry::UserMessage(..) => return false,
715 AgentThreadEntry::AssistantMessage(..) => continue,
716 AgentThreadEntry::ToolCall(..) => return true,
717 }
718 }
719
720 false
721 }
722
723 pub fn handle_session_update(
724 &mut self,
725 update: acp::SessionUpdate,
726 cx: &mut Context<Self>,
727 ) -> Result<()> {
728 match update {
729 acp::SessionUpdate::UserMessageChunk { content } => {
730 self.push_user_content_block(content, cx);
731 }
732 acp::SessionUpdate::AgentMessageChunk { content } => {
733 self.push_assistant_content_block(content, false, cx);
734 }
735 acp::SessionUpdate::AgentThoughtChunk { content } => {
736 self.push_assistant_content_block(content, true, cx);
737 }
738 acp::SessionUpdate::ToolCall(tool_call) => {
739 self.upsert_tool_call(tool_call, cx);
740 }
741 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
742 self.update_tool_call(tool_call_update, cx)?;
743 }
744 acp::SessionUpdate::Plan(plan) => {
745 self.update_plan(plan, cx);
746 }
747 }
748 Ok(())
749 }
750
751 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
752 let language_registry = self.project.read(cx).languages().clone();
753 let entries_len = self.entries.len();
754
755 if let Some(last_entry) = self.entries.last_mut()
756 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
757 {
758 content.append(chunk, &language_registry, cx);
759 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
760 } else {
761 let content = ContentBlock::new(chunk, &language_registry, cx);
762 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
763 }
764 }
765
766 pub fn push_assistant_content_block(
767 &mut self,
768 chunk: acp::ContentBlock,
769 is_thought: bool,
770 cx: &mut Context<Self>,
771 ) {
772 let language_registry = self.project.read(cx).languages().clone();
773 let entries_len = self.entries.len();
774 if let Some(last_entry) = self.entries.last_mut()
775 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
776 {
777 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
778 match (chunks.last_mut(), is_thought) {
779 (Some(AssistantMessageChunk::Message { block }), false)
780 | (Some(AssistantMessageChunk::Thought { block }), true) => {
781 block.append(chunk, &language_registry, cx)
782 }
783 _ => {
784 let block = ContentBlock::new(chunk, &language_registry, cx);
785 if is_thought {
786 chunks.push(AssistantMessageChunk::Thought { block })
787 } else {
788 chunks.push(AssistantMessageChunk::Message { block })
789 }
790 }
791 }
792 } else {
793 let block = ContentBlock::new(chunk, &language_registry, cx);
794 let chunk = if is_thought {
795 AssistantMessageChunk::Thought { block }
796 } else {
797 AssistantMessageChunk::Message { block }
798 };
799
800 self.push_entry(
801 AgentThreadEntry::AssistantMessage(AssistantMessage {
802 chunks: vec![chunk],
803 }),
804 cx,
805 );
806 }
807 }
808
809 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
810 self.entries.push(entry);
811 cx.emit(AcpThreadEvent::NewEntry);
812 }
813
814 pub fn update_tool_call(
815 &mut self,
816 update: acp::ToolCallUpdate,
817 cx: &mut Context<Self>,
818 ) -> Result<()> {
819 let languages = self.project.read(cx).languages().clone();
820
821 let (ix, current_call) = self
822 .tool_call_mut(&update.id)
823 .context("Tool call not found")?;
824 current_call.update(update.fields, languages, cx);
825
826 cx.emit(AcpThreadEvent::EntryUpdated(ix));
827
828 Ok(())
829 }
830
831 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
832 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
833 let status = ToolCallStatus::Allowed {
834 status: tool_call.status,
835 };
836 self.upsert_tool_call_inner(tool_call, status, cx)
837 }
838
839 pub fn upsert_tool_call_inner(
840 &mut self,
841 tool_call: acp::ToolCall,
842 status: ToolCallStatus,
843 cx: &mut Context<Self>,
844 ) {
845 let language_registry = self.project.read(cx).languages().clone();
846 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
847
848 let location = call.locations.last().cloned();
849
850 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
851 *current_call = call;
852
853 cx.emit(AcpThreadEvent::EntryUpdated(ix));
854 } else {
855 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
856 }
857
858 if let Some(location) = location {
859 self.set_project_location(location, cx)
860 }
861 }
862
863 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
864 // The tool call we are looking for is typically the last one, or very close to the end.
865 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
866 self.entries
867 .iter_mut()
868 .enumerate()
869 .rev()
870 .find_map(|(index, tool_call)| {
871 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
872 && &tool_call.id == id
873 {
874 Some((index, tool_call))
875 } else {
876 None
877 }
878 })
879 }
880
881 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
882 self.project.update(cx, |project, cx| {
883 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
884 return;
885 };
886 let buffer = project.open_buffer(path, cx);
887 cx.spawn(async move |project, cx| {
888 let buffer = buffer.await?;
889
890 project.update(cx, |project, cx| {
891 let position = if let Some(line) = location.line {
892 let snapshot = buffer.read(cx).snapshot();
893 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
894 snapshot.anchor_before(point)
895 } else {
896 Anchor::MIN
897 };
898
899 project.set_agent_location(
900 Some(AgentLocation {
901 buffer: buffer.downgrade(),
902 position,
903 }),
904 cx,
905 );
906 })
907 })
908 .detach_and_log_err(cx);
909 });
910 }
911
912 pub fn request_tool_call_authorization(
913 &mut self,
914 tool_call: acp::ToolCall,
915 options: Vec<acp::PermissionOption>,
916 cx: &mut Context<Self>,
917 ) -> oneshot::Receiver<acp::PermissionOptionId> {
918 let (tx, rx) = oneshot::channel();
919
920 let status = ToolCallStatus::WaitingForConfirmation {
921 options,
922 respond_tx: tx,
923 };
924
925 self.upsert_tool_call_inner(tool_call, status, cx);
926 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
927 rx
928 }
929
930 pub fn authorize_tool_call(
931 &mut self,
932 id: acp::ToolCallId,
933 option_id: acp::PermissionOptionId,
934 option_kind: acp::PermissionOptionKind,
935 cx: &mut Context<Self>,
936 ) {
937 let Some((ix, call)) = self.tool_call_mut(&id) else {
938 return;
939 };
940
941 let new_status = match option_kind {
942 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
943 ToolCallStatus::Rejected
944 }
945 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
946 ToolCallStatus::Allowed {
947 status: acp::ToolCallStatus::InProgress,
948 }
949 }
950 };
951
952 let curr_status = mem::replace(&mut call.status, new_status);
953
954 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
955 respond_tx.send(option_id).log_err();
956 } else if cfg!(debug_assertions) {
957 panic!("tried to authorize an already authorized tool call");
958 }
959
960 cx.emit(AcpThreadEvent::EntryUpdated(ix));
961 }
962
963 /// Returns true if the last turn is awaiting tool authorization
964 pub fn waiting_for_tool_confirmation(&self) -> bool {
965 for entry in self.entries.iter().rev() {
966 match &entry {
967 AgentThreadEntry::ToolCall(call) => match call.status {
968 ToolCallStatus::WaitingForConfirmation { .. } => return true,
969 ToolCallStatus::Allowed { .. }
970 | ToolCallStatus::Rejected
971 | ToolCallStatus::Canceled => continue,
972 },
973 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
974 // Reached the beginning of the turn
975 return false;
976 }
977 }
978 }
979 false
980 }
981
982 pub fn plan(&self) -> &Plan {
983 &self.plan
984 }
985
986 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
987 let new_entries_len = request.entries.len();
988 let mut new_entries = request.entries.into_iter();
989
990 // Reuse existing markdown to prevent flickering
991 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
992 let PlanEntry {
993 content,
994 priority,
995 status,
996 } = old;
997 content.update(cx, |old, cx| {
998 old.replace(new.content, cx);
999 });
1000 *priority = new.priority;
1001 *status = new.status;
1002 }
1003 for new in new_entries {
1004 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1005 }
1006 self.plan.entries.truncate(new_entries_len);
1007
1008 cx.notify();
1009 }
1010
1011 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1012 self.plan
1013 .entries
1014 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1015 cx.notify();
1016 }
1017
1018 #[cfg(any(test, feature = "test-support"))]
1019 pub fn send_raw(
1020 &mut self,
1021 message: &str,
1022 cx: &mut Context<Self>,
1023 ) -> BoxFuture<'static, Result<()>> {
1024 self.send(
1025 vec![acp::ContentBlock::Text(acp::TextContent {
1026 text: message.to_string(),
1027 annotations: None,
1028 })],
1029 cx,
1030 )
1031 }
1032
1033 pub fn send(
1034 &mut self,
1035 message: Vec<acp::ContentBlock>,
1036 cx: &mut Context<Self>,
1037 ) -> BoxFuture<'static, Result<()>> {
1038 let block = ContentBlock::new_combined(
1039 message.clone(),
1040 self.project.read(cx).languages().clone(),
1041 cx,
1042 );
1043 self.push_entry(
1044 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1045 cx,
1046 );
1047 self.clear_completed_plan_entries(cx);
1048
1049 let (tx, rx) = oneshot::channel();
1050 let cancel_task = self.cancel(cx);
1051
1052 self.send_task = Some(
1053 cx.spawn(async move |this, cx| {
1054 async {
1055 cancel_task.await;
1056
1057 let result = this
1058 .update(cx, |this, cx| {
1059 this.connection.prompt(
1060 acp::PromptRequest {
1061 prompt: message,
1062 session_id: this.session_id.clone(),
1063 },
1064 cx,
1065 )
1066 })?
1067 .await;
1068
1069 tx.send(result).log_err();
1070 anyhow::Ok(())
1071 }
1072 .await
1073 .log_err();
1074 })
1075 .fuse(),
1076 );
1077
1078 cx.spawn(async move |this, cx| match rx.await {
1079 Ok(Err(e)) => {
1080 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1081 .log_err();
1082 Err(e)?
1083 }
1084 _ => {
1085 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1086 .log_err();
1087 Ok(())
1088 }
1089 })
1090 .boxed()
1091 }
1092
1093 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1094 let Some(send_task) = self.send_task.take() else {
1095 return Task::ready(());
1096 };
1097
1098 for entry in self.entries.iter_mut() {
1099 if let AgentThreadEntry::ToolCall(call) = entry {
1100 let cancel = matches!(
1101 call.status,
1102 ToolCallStatus::WaitingForConfirmation { .. }
1103 | ToolCallStatus::Allowed {
1104 status: acp::ToolCallStatus::InProgress
1105 }
1106 );
1107
1108 if cancel {
1109 call.status = ToolCallStatus::Canceled;
1110 }
1111 }
1112 }
1113
1114 self.connection.cancel(&self.session_id, cx);
1115
1116 // Wait for the send task to complete
1117 cx.foreground_executor().spawn(send_task)
1118 }
1119
1120 pub fn read_text_file(
1121 &self,
1122 path: PathBuf,
1123 line: Option<u32>,
1124 limit: Option<u32>,
1125 reuse_shared_snapshot: bool,
1126 cx: &mut Context<Self>,
1127 ) -> Task<Result<String>> {
1128 let project = self.project.clone();
1129 let action_log = self.action_log.clone();
1130 cx.spawn(async move |this, cx| {
1131 let load = project.update(cx, |project, cx| {
1132 let path = project
1133 .project_path_for_absolute_path(&path, cx)
1134 .context("invalid path")?;
1135 anyhow::Ok(project.open_buffer(path, cx))
1136 });
1137 let buffer = load??.await?;
1138
1139 let snapshot = if reuse_shared_snapshot {
1140 this.read_with(cx, |this, _| {
1141 this.shared_buffers.get(&buffer.clone()).cloned()
1142 })
1143 .log_err()
1144 .flatten()
1145 } else {
1146 None
1147 };
1148
1149 let snapshot = if let Some(snapshot) = snapshot {
1150 snapshot
1151 } else {
1152 action_log.update(cx, |action_log, cx| {
1153 action_log.buffer_read(buffer.clone(), cx);
1154 })?;
1155 project.update(cx, |project, cx| {
1156 let position = buffer
1157 .read(cx)
1158 .snapshot()
1159 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1160 project.set_agent_location(
1161 Some(AgentLocation {
1162 buffer: buffer.downgrade(),
1163 position,
1164 }),
1165 cx,
1166 );
1167 })?;
1168
1169 buffer.update(cx, |buffer, _| buffer.snapshot())?
1170 };
1171
1172 this.update(cx, |this, _| {
1173 let text = snapshot.text();
1174 this.shared_buffers.insert(buffer.clone(), snapshot);
1175 if line.is_none() && limit.is_none() {
1176 return Ok(text);
1177 }
1178 let limit = limit.unwrap_or(u32::MAX) as usize;
1179 let Some(line) = line else {
1180 return Ok(text.lines().take(limit).collect::<String>());
1181 };
1182
1183 let count = text.lines().count();
1184 if count < line as usize {
1185 anyhow::bail!("There are only {} lines", count);
1186 }
1187 Ok(text
1188 .lines()
1189 .skip(line as usize + 1)
1190 .take(limit)
1191 .collect::<String>())
1192 })?
1193 })
1194 }
1195
1196 pub fn write_text_file(
1197 &self,
1198 path: PathBuf,
1199 content: String,
1200 cx: &mut Context<Self>,
1201 ) -> Task<Result<()>> {
1202 let project = self.project.clone();
1203 let action_log = self.action_log.clone();
1204 cx.spawn(async move |this, cx| {
1205 let load = project.update(cx, |project, cx| {
1206 let path = project
1207 .project_path_for_absolute_path(&path, cx)
1208 .context("invalid path")?;
1209 anyhow::Ok(project.open_buffer(path, cx))
1210 });
1211 let buffer = load??.await?;
1212 let snapshot = this.update(cx, |this, cx| {
1213 this.shared_buffers
1214 .get(&buffer)
1215 .cloned()
1216 .unwrap_or_else(|| buffer.read(cx).snapshot())
1217 })?;
1218 let edits = cx
1219 .background_executor()
1220 .spawn(async move {
1221 let old_text = snapshot.text();
1222 text_diff(old_text.as_str(), &content)
1223 .into_iter()
1224 .map(|(range, replacement)| {
1225 (
1226 snapshot.anchor_after(range.start)
1227 ..snapshot.anchor_before(range.end),
1228 replacement,
1229 )
1230 })
1231 .collect::<Vec<_>>()
1232 })
1233 .await;
1234 cx.update(|cx| {
1235 project.update(cx, |project, cx| {
1236 project.set_agent_location(
1237 Some(AgentLocation {
1238 buffer: buffer.downgrade(),
1239 position: edits
1240 .last()
1241 .map(|(range, _)| range.end)
1242 .unwrap_or(Anchor::MIN),
1243 }),
1244 cx,
1245 );
1246 });
1247
1248 action_log.update(cx, |action_log, cx| {
1249 action_log.buffer_read(buffer.clone(), cx);
1250 });
1251 buffer.update(cx, |buffer, cx| {
1252 buffer.edit(edits, None, cx);
1253 });
1254 action_log.update(cx, |action_log, cx| {
1255 action_log.buffer_edited(buffer.clone(), cx);
1256 });
1257 })?;
1258 project
1259 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1260 .await
1261 })
1262 }
1263
1264 pub fn to_markdown(&self, cx: &App) -> String {
1265 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1266 }
1267
1268 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1269 cx.emit(AcpThreadEvent::ServerExited(status));
1270 }
1271}
1272
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, rc::Rc, time::Duration};

    use util::path;

    /// Shared setup for every test in this module: installs a logger (ignoring the error when
    /// one is already installed), a test `SettingsStore` global, and project/language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user content blocks are merged into one `UserMessage` entry, while a user
    /// block that arrives after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive `AgentThoughtChunk` updates are concatenated into a single `<thinking>`
    /// block in the thread's Markdown rendering.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// The agent reads the file, the user then inserts a line at the top of the open buffer,
    /// and the agent writes an extended version of the text it read. The user's edit must be
    /// preserved in both the buffer and the saved file.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent finished reading, so the user edit can be made
        // between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user edit happens after the agent read the file but before its write lands.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call canceled while in progress can still be moved to `Completed` by a later
    /// `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call that the agent reports as already `Completed` leaves no pending edit
    /// tool calls on the thread once the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits for the first `ToolCall` entry to appear in `thread` and returns its index.
    ///
    /// # Panics
    ///
    /// Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // This callback runs for every event once a tool call exists, possibly
                        // several times before the receiver is polled. Only the first index is
                        // needed, so a full channel is not an error here.
                        tx.try_send(ix).ok();
                        return;
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory [`AgentConnection`] used by these tests.
    ///
    /// Threads it creates are recorded in `sessions` by session id so `prompt` and `cancel`
    /// can look them up. When `on_user_message` is set it scripts the agent's reaction to each
    /// prompt; otherwise every prompt ends the turn immediately.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions, and no prompt handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Sets the auth methods accepted by `authenticate`.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs `handler` to run for every prompt sent to a thread of this connection,
        /// replacing any previously installed handler.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread under a random 7-character alphanumeric session id and records it
        /// in `sessions`. `_cwd` is ignored.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        /// Succeeds only when `method` is one of the configured auth methods.
        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Runs the scripted handler for the prompt's session, or ends the turn immediately
        /// when no handler is installed.
        ///
        /// # Panics
        ///
        /// Panics if the session id was not created by this connection.
        fn prompt(
            &self,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        /// Spawns a detached task that calls `AcpThread::cancel` on the session's thread.
        ///
        /// # Panics
        ///
        /// Panics if the session id was not created by this connection.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            // `session_id` is already a reference; passing `&session_id` would ask for
            // `SessionId: Borrow<&SessionId>`, which does not hold.
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}