1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::process::ExitStatus;
22use std::rc::Rc;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
52 }
53}
54
/// A borrowed file path that renders as a markdown "mention" link of the form
/// `[@file_name](@file:full/path)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Scheme prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target produced by the `Display` impl.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A path without a final component (e.g. `/`) yields an empty label.
        let file_name = self.0.file_name().unwrap_or_default();
        let target = self.0.display();
        write!(f, "[@{}]({}{})", file_name.display(), Self::PREFIX, target)
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(chunk.into(), language_registry, cx),
114 }
115 }
116
117 fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::Message { block } => block.to_markdown(cx).to_string(),
120 Self::Thought { block } => {
121 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
128pub enum AgentThreadEntry {
129 UserMessage(UserMessage),
130 AssistantMessage(AssistantMessage),
131 ToolCall(ToolCall),
132}
133
134impl AgentThreadEntry {
135 fn to_markdown(&self, cx: &App) -> String {
136 match self {
137 Self::UserMessage(message) => message.to_markdown(cx),
138 Self::AssistantMessage(message) => message.to_markdown(cx),
139 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
140 }
141 }
142
143 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.diffs())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
152 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
153 Some(locations)
154 } else {
155 None
156 }
157 }
158}
159
160#[derive(Debug)]
161pub struct ToolCall {
162 pub id: acp::ToolCallId,
163 pub label: Entity<Markdown>,
164 pub kind: acp::ToolKind,
165 pub content: Vec<ToolCallContent>,
166 pub status: ToolCallStatus,
167 pub locations: Vec<acp::ToolCallLocation>,
168 pub raw_input: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 }
198 }
199
200 fn update(
201 &mut self,
202 fields: acp::ToolCallUpdateFields,
203 language_registry: Arc<LanguageRegistry>,
204 cx: &mut App,
205 ) {
206 let acp::ToolCallUpdateFields {
207 kind,
208 status,
209 title,
210 content,
211 locations,
212 raw_input,
213 } = fields;
214
215 if let Some(kind) = kind {
216 self.kind = kind;
217 }
218
219 if let Some(status) = status {
220 self.status = ToolCallStatus::Allowed { status };
221 }
222
223 if let Some(title) = title {
224 self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
225 }
226
227 if let Some(content) = content {
228 self.content = content
229 .into_iter()
230 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
231 .collect();
232 }
233
234 if let Some(locations) = locations {
235 self.locations = locations;
236 }
237
238 if let Some(raw_input) = raw_input {
239 self.raw_input = Some(raw_input);
240 }
241 }
242
243 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::ContentBlock { .. } => None,
246 ToolCallContent::Diff { diff } => Some(diff),
247 })
248 }
249
250 fn to_markdown(&self, cx: &App) -> String {
251 let mut markdown = format!(
252 "**Tool Call: {}**\nStatus: {}\n\n",
253 self.label.read(cx).source(),
254 self.status
255 );
256 for content in &self.content {
257 markdown.push_str(content.to_markdown(cx).as_str());
258 markdown.push_str("\n\n");
259 }
260 markdown
261 }
262}
263
264#[derive(Debug)]
265pub enum ToolCallStatus {
266 WaitingForConfirmation {
267 options: Vec<acp::PermissionOption>,
268 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
269 },
270 Allowed {
271 status: acp::ToolCallStatus,
272 },
273 Rejected,
274 Canceled,
275}
276
277impl Display for ToolCallStatus {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 write!(
280 f,
281 "{}",
282 match self {
283 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
284 ToolCallStatus::Allowed { status } => match status {
285 acp::ToolCallStatus::Pending => "Pending",
286 acp::ToolCallStatus::InProgress => "In Progress",
287 acp::ToolCallStatus::Completed => "Completed",
288 acp::ToolCallStatus::Failed => "Failed",
289 },
290 ToolCallStatus::Rejected => "Rejected",
291 ToolCallStatus::Canceled => "Canceled",
292 }
293 )
294 }
295}
296
297#[derive(Debug, PartialEq, Clone)]
298pub enum ContentBlock {
299 Empty,
300 Markdown { markdown: Entity<Markdown> },
301}
302
303impl ContentBlock {
304 pub fn new(
305 block: acp::ContentBlock,
306 language_registry: &Arc<LanguageRegistry>,
307 cx: &mut App,
308 ) -> Self {
309 let mut this = Self::Empty;
310 this.append(block, language_registry, cx);
311 this
312 }
313
314 pub fn new_combined(
315 blocks: impl IntoIterator<Item = acp::ContentBlock>,
316 language_registry: Arc<LanguageRegistry>,
317 cx: &mut App,
318 ) -> Self {
319 let mut this = Self::Empty;
320 for block in blocks {
321 this.append(block, &language_registry, cx);
322 }
323 this
324 }
325
326 pub fn append(
327 &mut self,
328 block: acp::ContentBlock,
329 language_registry: &Arc<LanguageRegistry>,
330 cx: &mut App,
331 ) {
332 let new_content = match block {
333 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
334 acp::ContentBlock::ResourceLink(resource_link) => {
335 if let Some(path) = resource_link.uri.strip_prefix("file://") {
336 format!("{}", MentionPath(path.as_ref()))
337 } else {
338 resource_link.uri.clone()
339 }
340 }
341 acp::ContentBlock::Image(_)
342 | acp::ContentBlock::Audio(_)
343 | acp::ContentBlock::Resource(_) => String::new(),
344 };
345
346 match self {
347 ContentBlock::Empty => {
348 *self = ContentBlock::Markdown {
349 markdown: cx.new(|cx| {
350 Markdown::new(
351 new_content.into(),
352 Some(language_registry.clone()),
353 None,
354 cx,
355 )
356 }),
357 };
358 }
359 ContentBlock::Markdown { markdown } => {
360 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
361 }
362 }
363 }
364
365 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
366 match self {
367 ContentBlock::Empty => "",
368 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
369 }
370 }
371
372 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
373 match self {
374 ContentBlock::Empty => None,
375 ContentBlock::Markdown { markdown } => Some(markdown),
376 }
377 }
378}
379
380#[derive(Debug)]
381pub enum ToolCallContent {
382 ContentBlock { content: ContentBlock },
383 Diff { diff: Diff },
384}
385
386impl ToolCallContent {
387 pub fn from_acp(
388 content: acp::ToolCallContent,
389 language_registry: Arc<LanguageRegistry>,
390 cx: &mut App,
391 ) -> Self {
392 match content {
393 acp::ToolCallContent::Content { content } => Self::ContentBlock {
394 content: ContentBlock::new(content, &language_registry, cx),
395 },
396 acp::ToolCallContent::Diff { diff } => Self::Diff {
397 diff: Diff::from_acp(diff, language_registry, cx),
398 },
399 }
400 }
401
402 pub fn to_markdown(&self, cx: &App) -> String {
403 match self {
404 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
405 Self::Diff { diff } => diff.to_markdown(cx),
406 }
407 }
408}
409
410#[derive(Debug)]
411pub struct Diff {
412 pub multibuffer: Entity<MultiBuffer>,
413 pub path: PathBuf,
414 _task: Task<Result<()>>,
415}
416
417impl Diff {
418 pub fn from_acp(
419 diff: acp::Diff,
420 language_registry: Arc<LanguageRegistry>,
421 cx: &mut App,
422 ) -> Self {
423 let acp::Diff {
424 path,
425 old_text,
426 new_text,
427 } = diff;
428
429 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
430
431 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
432 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
433 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
434 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
435
436 let task = cx.spawn({
437 let multibuffer = multibuffer.clone();
438 let path = path.clone();
439 async move |cx| {
440 let language = language_registry
441 .language_for_file_path(&path)
442 .await
443 .log_err();
444
445 new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
446
447 let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
448 buffer.set_language(language, cx);
449 buffer.snapshot()
450 })?;
451
452 buffer_diff
453 .update(cx, |diff, cx| {
454 diff.set_base_text(
455 old_buffer_snapshot,
456 Some(language_registry),
457 new_buffer_snapshot,
458 cx,
459 )
460 })?
461 .await?;
462
463 multibuffer
464 .update(cx, |multibuffer, cx| {
465 let hunk_ranges = {
466 let buffer = new_buffer.read(cx);
467 let diff = buffer_diff.read(cx);
468 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
469 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
470 .collect::<Vec<_>>()
471 };
472
473 multibuffer.set_excerpts_for_path(
474 PathKey::for_buffer(&new_buffer, cx),
475 new_buffer.clone(),
476 hunk_ranges,
477 editor::DEFAULT_MULTIBUFFER_CONTEXT,
478 cx,
479 );
480 multibuffer.add_diff(buffer_diff, cx);
481 })
482 .log_err();
483
484 anyhow::Ok(())
485 }
486 });
487
488 Self {
489 multibuffer,
490 path,
491 _task: task,
492 }
493 }
494
495 fn to_markdown(&self, cx: &App) -> String {
496 let buffer_text = self
497 .multibuffer
498 .read(cx)
499 .all_buffers()
500 .iter()
501 .map(|buffer| buffer.read(cx).text())
502 .join("\n");
503 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
504 }
505}
506
507#[derive(Debug, Default)]
508pub struct Plan {
509 pub entries: Vec<PlanEntry>,
510}
511
512#[derive(Debug)]
513pub struct PlanStats<'a> {
514 pub in_progress_entry: Option<&'a PlanEntry>,
515 pub pending: u32,
516 pub completed: u32,
517}
518
519impl Plan {
520 pub fn is_empty(&self) -> bool {
521 self.entries.is_empty()
522 }
523
524 pub fn stats(&self) -> PlanStats<'_> {
525 let mut stats = PlanStats {
526 in_progress_entry: None,
527 pending: 0,
528 completed: 0,
529 };
530
531 for entry in &self.entries {
532 match &entry.status {
533 acp::PlanEntryStatus::Pending => {
534 stats.pending += 1;
535 }
536 acp::PlanEntryStatus::InProgress => {
537 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
538 }
539 acp::PlanEntryStatus::Completed => {
540 stats.completed += 1;
541 }
542 }
543 }
544
545 stats
546 }
547}
548
549#[derive(Debug)]
550pub struct PlanEntry {
551 pub content: Entity<Markdown>,
552 pub priority: acp::PlanEntryPriority,
553 pub status: acp::PlanEntryStatus,
554}
555
556impl PlanEntry {
557 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
558 Self {
559 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
560 priority: entry.priority,
561 status: entry.status,
562 }
563 }
564}
565
566pub struct AcpThread {
567 title: SharedString,
568 entries: Vec<AgentThreadEntry>,
569 plan: Plan,
570 project: Entity<Project>,
571 action_log: Entity<ActionLog>,
572 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
573 send_task: Option<Task<()>>,
574 connection: Rc<dyn AgentConnection>,
575 session_id: acp::SessionId,
576}
577
578pub enum AcpThreadEvent {
579 NewEntry,
580 EntryUpdated(usize),
581 ToolAuthorizationRequired,
582 Stopped,
583 Error,
584 ServerExited(ExitStatus),
585}
586
587impl EventEmitter<AcpThreadEvent> for AcpThread {}
588
/// Coarse state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and a tool call awaits the user's decision.
    WaitingForToolConfirmation,
    /// A prompt is in flight and nothing is awaiting confirmation.
    Generating,
}
595
596#[derive(Debug, Clone)]
597pub enum LoadError {
598 Unsupported {
599 error_message: SharedString,
600 upgrade_message: SharedString,
601 upgrade_command: String,
602 },
603 Exited(i32),
604 Other(SharedString),
605}
606
607impl Display for LoadError {
608 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
609 match self {
610 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
611 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
612 LoadError::Other(msg) => write!(f, "{}", msg),
613 }
614 }
615}
616
617impl Error for LoadError {}
618
619impl AcpThread {
620 pub fn new(
621 title: impl Into<SharedString>,
622 connection: Rc<dyn AgentConnection>,
623 project: Entity<Project>,
624 session_id: acp::SessionId,
625 cx: &mut Context<Self>,
626 ) -> Self {
627 let action_log = cx.new(|_| ActionLog::new(project.clone()));
628
629 Self {
630 action_log,
631 shared_buffers: Default::default(),
632 entries: Default::default(),
633 plan: Default::default(),
634 title: title.into(),
635 project,
636 send_task: None,
637 connection,
638 session_id,
639 }
640 }
641
642 pub fn action_log(&self) -> &Entity<ActionLog> {
643 &self.action_log
644 }
645
646 pub fn project(&self) -> &Entity<Project> {
647 &self.project
648 }
649
650 pub fn title(&self) -> SharedString {
651 self.title.clone()
652 }
653
654 pub fn entries(&self) -> &[AgentThreadEntry] {
655 &self.entries
656 }
657
658 pub fn session_id(&self) -> &acp::SessionId {
659 &self.session_id
660 }
661
662 pub fn status(&self) -> ThreadStatus {
663 if self.send_task.is_some() {
664 if self.waiting_for_tool_confirmation() {
665 ThreadStatus::WaitingForToolConfirmation
666 } else {
667 ThreadStatus::Generating
668 }
669 } else {
670 ThreadStatus::Idle
671 }
672 }
673
674 pub fn has_pending_edit_tool_calls(&self) -> bool {
675 for entry in self.entries.iter().rev() {
676 match entry {
677 AgentThreadEntry::UserMessage(_) => return false,
678 AgentThreadEntry::ToolCall(
679 call @ ToolCall {
680 status:
681 ToolCallStatus::Allowed {
682 status:
683 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
684 },
685 ..
686 },
687 ) if call.diffs().next().is_some() => {
688 return true;
689 }
690 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
691 }
692 }
693
694 false
695 }
696
697 pub fn used_tools_since_last_user_message(&self) -> bool {
698 for entry in self.entries.iter().rev() {
699 match entry {
700 AgentThreadEntry::UserMessage(..) => return false,
701 AgentThreadEntry::AssistantMessage(..) => continue,
702 AgentThreadEntry::ToolCall(..) => return true,
703 }
704 }
705
706 false
707 }
708
709 pub fn handle_session_update(
710 &mut self,
711 update: acp::SessionUpdate,
712 cx: &mut Context<Self>,
713 ) -> Result<()> {
714 match update {
715 acp::SessionUpdate::UserMessageChunk { content } => {
716 self.push_user_content_block(content, cx);
717 }
718 acp::SessionUpdate::AgentMessageChunk { content } => {
719 self.push_assistant_content_block(content, false, cx);
720 }
721 acp::SessionUpdate::AgentThoughtChunk { content } => {
722 self.push_assistant_content_block(content, true, cx);
723 }
724 acp::SessionUpdate::ToolCall(tool_call) => {
725 self.upsert_tool_call(tool_call, cx);
726 }
727 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
728 self.update_tool_call(tool_call_update, cx)?;
729 }
730 acp::SessionUpdate::Plan(plan) => {
731 self.update_plan(plan, cx);
732 }
733 }
734 Ok(())
735 }
736
737 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
738 let language_registry = self.project.read(cx).languages().clone();
739 let entries_len = self.entries.len();
740
741 if let Some(last_entry) = self.entries.last_mut()
742 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
743 {
744 content.append(chunk, &language_registry, cx);
745 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
746 } else {
747 let content = ContentBlock::new(chunk, &language_registry, cx);
748 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
749 }
750 }
751
752 pub fn push_assistant_content_block(
753 &mut self,
754 chunk: acp::ContentBlock,
755 is_thought: bool,
756 cx: &mut Context<Self>,
757 ) {
758 let language_registry = self.project.read(cx).languages().clone();
759 let entries_len = self.entries.len();
760 if let Some(last_entry) = self.entries.last_mut()
761 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
762 {
763 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
764 match (chunks.last_mut(), is_thought) {
765 (Some(AssistantMessageChunk::Message { block }), false)
766 | (Some(AssistantMessageChunk::Thought { block }), true) => {
767 block.append(chunk, &language_registry, cx)
768 }
769 _ => {
770 let block = ContentBlock::new(chunk, &language_registry, cx);
771 if is_thought {
772 chunks.push(AssistantMessageChunk::Thought { block })
773 } else {
774 chunks.push(AssistantMessageChunk::Message { block })
775 }
776 }
777 }
778 } else {
779 let block = ContentBlock::new(chunk, &language_registry, cx);
780 let chunk = if is_thought {
781 AssistantMessageChunk::Thought { block }
782 } else {
783 AssistantMessageChunk::Message { block }
784 };
785
786 self.push_entry(
787 AgentThreadEntry::AssistantMessage(AssistantMessage {
788 chunks: vec![chunk],
789 }),
790 cx,
791 );
792 }
793 }
794
795 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
796 self.entries.push(entry);
797 cx.emit(AcpThreadEvent::NewEntry);
798 }
799
800 pub fn update_tool_call(
801 &mut self,
802 update: acp::ToolCallUpdate,
803 cx: &mut Context<Self>,
804 ) -> Result<()> {
805 let languages = self.project.read(cx).languages().clone();
806
807 let (ix, current_call) = self
808 .tool_call_mut(&update.id)
809 .context("Tool call not found")?;
810 current_call.update(update.fields, languages, cx);
811
812 cx.emit(AcpThreadEvent::EntryUpdated(ix));
813
814 Ok(())
815 }
816
817 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
818 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
819 let status = ToolCallStatus::Allowed {
820 status: tool_call.status,
821 };
822 self.upsert_tool_call_inner(tool_call, status, cx)
823 }
824
825 pub fn upsert_tool_call_inner(
826 &mut self,
827 tool_call: acp::ToolCall,
828 status: ToolCallStatus,
829 cx: &mut Context<Self>,
830 ) {
831 let language_registry = self.project.read(cx).languages().clone();
832 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
833
834 let location = call.locations.last().cloned();
835
836 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
837 *current_call = call;
838
839 cx.emit(AcpThreadEvent::EntryUpdated(ix));
840 } else {
841 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
842 }
843
844 if let Some(location) = location {
845 self.set_project_location(location, cx)
846 }
847 }
848
849 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
850 // The tool call we are looking for is typically the last one, or very close to the end.
851 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
852 self.entries
853 .iter_mut()
854 .enumerate()
855 .rev()
856 .find_map(|(index, tool_call)| {
857 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
858 && &tool_call.id == id
859 {
860 Some((index, tool_call))
861 } else {
862 None
863 }
864 })
865 }
866
867 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
868 self.project.update(cx, |project, cx| {
869 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
870 return;
871 };
872 let buffer = project.open_buffer(path, cx);
873 cx.spawn(async move |project, cx| {
874 let buffer = buffer.await?;
875
876 project.update(cx, |project, cx| {
877 let position = if let Some(line) = location.line {
878 let snapshot = buffer.read(cx).snapshot();
879 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
880 snapshot.anchor_before(point)
881 } else {
882 Anchor::MIN
883 };
884
885 project.set_agent_location(
886 Some(AgentLocation {
887 buffer: buffer.downgrade(),
888 position,
889 }),
890 cx,
891 );
892 })
893 })
894 .detach_and_log_err(cx);
895 });
896 }
897
898 pub fn request_tool_call_permission(
899 &mut self,
900 tool_call: acp::ToolCall,
901 options: Vec<acp::PermissionOption>,
902 cx: &mut Context<Self>,
903 ) -> oneshot::Receiver<acp::PermissionOptionId> {
904 let (tx, rx) = oneshot::channel();
905
906 let status = ToolCallStatus::WaitingForConfirmation {
907 options,
908 respond_tx: tx,
909 };
910
911 self.upsert_tool_call_inner(tool_call, status, cx);
912 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
913 rx
914 }
915
916 pub fn authorize_tool_call(
917 &mut self,
918 id: acp::ToolCallId,
919 option_id: acp::PermissionOptionId,
920 option_kind: acp::PermissionOptionKind,
921 cx: &mut Context<Self>,
922 ) {
923 let Some((ix, call)) = self.tool_call_mut(&id) else {
924 return;
925 };
926
927 let new_status = match option_kind {
928 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
929 ToolCallStatus::Rejected
930 }
931 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
932 ToolCallStatus::Allowed {
933 status: acp::ToolCallStatus::InProgress,
934 }
935 }
936 };
937
938 let curr_status = mem::replace(&mut call.status, new_status);
939
940 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
941 respond_tx.send(option_id).log_err();
942 } else if cfg!(debug_assertions) {
943 panic!("tried to authorize an already authorized tool call");
944 }
945
946 cx.emit(AcpThreadEvent::EntryUpdated(ix));
947 }
948
949 /// Returns true if the last turn is awaiting tool authorization
950 pub fn waiting_for_tool_confirmation(&self) -> bool {
951 for entry in self.entries.iter().rev() {
952 match &entry {
953 AgentThreadEntry::ToolCall(call) => match call.status {
954 ToolCallStatus::WaitingForConfirmation { .. } => return true,
955 ToolCallStatus::Allowed { .. }
956 | ToolCallStatus::Rejected
957 | ToolCallStatus::Canceled => continue,
958 },
959 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
960 // Reached the beginning of the turn
961 return false;
962 }
963 }
964 }
965 false
966 }
967
968 pub fn plan(&self) -> &Plan {
969 &self.plan
970 }
971
972 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
973 self.plan = Plan {
974 entries: request
975 .entries
976 .into_iter()
977 .map(|entry| PlanEntry::from_acp(entry, cx))
978 .collect(),
979 };
980
981 cx.notify();
982 }
983
984 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
985 self.plan
986 .entries
987 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
988 cx.notify();
989 }
990
991 #[cfg(any(test, feature = "test-support"))]
992 pub fn send_raw(
993 &mut self,
994 message: &str,
995 cx: &mut Context<Self>,
996 ) -> BoxFuture<'static, Result<()>> {
997 self.send(
998 vec![acp::ContentBlock::Text(acp::TextContent {
999 text: message.to_string(),
1000 annotations: None,
1001 })],
1002 cx,
1003 )
1004 }
1005
1006 pub fn send(
1007 &mut self,
1008 message: Vec<acp::ContentBlock>,
1009 cx: &mut Context<Self>,
1010 ) -> BoxFuture<'static, Result<()>> {
1011 let block = ContentBlock::new_combined(
1012 message.clone(),
1013 self.project.read(cx).languages().clone(),
1014 cx,
1015 );
1016 self.push_entry(
1017 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1018 cx,
1019 );
1020 self.clear_completed_plan_entries(cx);
1021
1022 let (tx, rx) = oneshot::channel();
1023 let cancel_task = self.cancel(cx);
1024
1025 self.send_task = Some(cx.spawn(async move |this, cx| {
1026 async {
1027 cancel_task.await;
1028
1029 let result = this
1030 .update(cx, |this, cx| {
1031 this.connection.prompt(
1032 acp::PromptRequest {
1033 prompt: message,
1034 session_id: this.session_id.clone(),
1035 },
1036 cx,
1037 )
1038 })?
1039 .await;
1040 tx.send(result).log_err();
1041 this.update(cx, |this, _cx| this.send_task.take())?;
1042 anyhow::Ok(())
1043 }
1044 .await
1045 .log_err();
1046 }));
1047
1048 cx.spawn(async move |this, cx| match rx.await {
1049 Ok(Err(e)) => {
1050 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1051 .log_err();
1052 Err(e)?
1053 }
1054 _ => {
1055 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1056 .log_err();
1057 Ok(())
1058 }
1059 })
1060 .boxed()
1061 }
1062
1063 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1064 let Some(send_task) = self.send_task.take() else {
1065 return Task::ready(());
1066 };
1067
1068 for entry in self.entries.iter_mut() {
1069 if let AgentThreadEntry::ToolCall(call) = entry {
1070 let cancel = matches!(
1071 call.status,
1072 ToolCallStatus::WaitingForConfirmation { .. }
1073 | ToolCallStatus::Allowed {
1074 status: acp::ToolCallStatus::InProgress
1075 }
1076 );
1077
1078 if cancel {
1079 call.status = ToolCallStatus::Canceled;
1080 }
1081 }
1082 }
1083
1084 self.connection.cancel(&self.session_id, cx);
1085
1086 // Wait for the send task to complete
1087 cx.foreground_executor().spawn(send_task)
1088 }
1089
1090 pub fn read_text_file(
1091 &self,
1092 path: PathBuf,
1093 line: Option<u32>,
1094 limit: Option<u32>,
1095 reuse_shared_snapshot: bool,
1096 cx: &mut Context<Self>,
1097 ) -> Task<Result<String>> {
1098 let project = self.project.clone();
1099 let action_log = self.action_log.clone();
1100 cx.spawn(async move |this, cx| {
1101 let load = project.update(cx, |project, cx| {
1102 let path = project
1103 .project_path_for_absolute_path(&path, cx)
1104 .context("invalid path")?;
1105 anyhow::Ok(project.open_buffer(path, cx))
1106 });
1107 let buffer = load??.await?;
1108
1109 let snapshot = if reuse_shared_snapshot {
1110 this.read_with(cx, |this, _| {
1111 this.shared_buffers.get(&buffer.clone()).cloned()
1112 })
1113 .log_err()
1114 .flatten()
1115 } else {
1116 None
1117 };
1118
1119 let snapshot = if let Some(snapshot) = snapshot {
1120 snapshot
1121 } else {
1122 action_log.update(cx, |action_log, cx| {
1123 action_log.buffer_read(buffer.clone(), cx);
1124 })?;
1125 project.update(cx, |project, cx| {
1126 let position = buffer
1127 .read(cx)
1128 .snapshot()
1129 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1130 project.set_agent_location(
1131 Some(AgentLocation {
1132 buffer: buffer.downgrade(),
1133 position,
1134 }),
1135 cx,
1136 );
1137 })?;
1138
1139 buffer.update(cx, |buffer, _| buffer.snapshot())?
1140 };
1141
1142 this.update(cx, |this, _| {
1143 let text = snapshot.text();
1144 this.shared_buffers.insert(buffer.clone(), snapshot);
1145 if line.is_none() && limit.is_none() {
1146 return Ok(text);
1147 }
1148 let limit = limit.unwrap_or(u32::MAX) as usize;
1149 let Some(line) = line else {
1150 return Ok(text.lines().take(limit).collect::<String>());
1151 };
1152
1153 let count = text.lines().count();
1154 if count < line as usize {
1155 anyhow::bail!("There are only {} lines", count);
1156 }
1157 Ok(text
1158 .lines()
1159 .skip(line as usize + 1)
1160 .take(limit)
1161 .collect::<String>())
1162 })?
1163 })
1164 }
1165
1166 pub fn write_text_file(
1167 &self,
1168 path: PathBuf,
1169 content: String,
1170 cx: &mut Context<Self>,
1171 ) -> Task<Result<()>> {
1172 let project = self.project.clone();
1173 let action_log = self.action_log.clone();
1174 cx.spawn(async move |this, cx| {
1175 let load = project.update(cx, |project, cx| {
1176 let path = project
1177 .project_path_for_absolute_path(&path, cx)
1178 .context("invalid path")?;
1179 anyhow::Ok(project.open_buffer(path, cx))
1180 });
1181 let buffer = load??.await?;
1182 let snapshot = this.update(cx, |this, cx| {
1183 this.shared_buffers
1184 .get(&buffer)
1185 .cloned()
1186 .unwrap_or_else(|| buffer.read(cx).snapshot())
1187 })?;
1188 let edits = cx
1189 .background_executor()
1190 .spawn(async move {
1191 let old_text = snapshot.text();
1192 text_diff(old_text.as_str(), &content)
1193 .into_iter()
1194 .map(|(range, replacement)| {
1195 (
1196 snapshot.anchor_after(range.start)
1197 ..snapshot.anchor_before(range.end),
1198 replacement,
1199 )
1200 })
1201 .collect::<Vec<_>>()
1202 })
1203 .await;
1204 cx.update(|cx| {
1205 project.update(cx, |project, cx| {
1206 project.set_agent_location(
1207 Some(AgentLocation {
1208 buffer: buffer.downgrade(),
1209 position: edits
1210 .last()
1211 .map(|(range, _)| range.end)
1212 .unwrap_or(Anchor::MIN),
1213 }),
1214 cx,
1215 );
1216 });
1217
1218 action_log.update(cx, |action_log, cx| {
1219 action_log.buffer_read(buffer.clone(), cx);
1220 });
1221 buffer.update(cx, |buffer, cx| {
1222 buffer.edit(edits, None, cx);
1223 });
1224 action_log.update(cx, |action_log, cx| {
1225 action_log.buffer_edited(buffer.clone(), cx);
1226 });
1227 })?;
1228 project
1229 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1230 .await
1231 })
1232 }
1233
1234 pub fn to_markdown(&self, cx: &App) -> String {
1235 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1236 }
1237
1238 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1239 cx.emit(AcpThreadEvent::ServerExited(status));
1240 }
1241}
1242
1243#[cfg(test)]
1244mod tests {
1245 use super::*;
1246 use anyhow::anyhow;
1247 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1248 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1249 use indoc::indoc;
1250 use project::FakeFs;
1251 use rand::Rng as _;
1252 use serde_json::json;
1253 use settings::SettingsStore;
1254 use smol::stream::StreamExt as _;
1255 use std::{cell::RefCell, rc::Rc, time::Duration};
1256
1257 use util::path;
1258
1259 fn init_test(cx: &mut TestAppContext) {
1260 env_logger::try_init().ok();
1261 cx.update(|cx| {
1262 let settings_store = SettingsStore::test(cx);
1263 cx.set_global(settings_store);
1264 Project::init_settings(cx);
1265 language::init(cx);
1266 });
1267 }
1268
1269 #[gpui::test]
1270 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1271 init_test(cx);
1272
1273 let fs = FakeFs::new(cx.executor());
1274 let project = Project::test(fs, [], cx).await;
1275 let connection = Rc::new(FakeAgentConnection::new());
1276 let thread = cx
1277 .spawn(async move |mut cx| {
1278 connection
1279 .new_thread(project, Path::new(path!("/test")), &mut cx)
1280 .await
1281 })
1282 .await
1283 .unwrap();
1284
1285 // Test creating a new user message
1286 thread.update(cx, |thread, cx| {
1287 thread.push_user_content_block(
1288 acp::ContentBlock::Text(acp::TextContent {
1289 annotations: None,
1290 text: "Hello, ".to_string(),
1291 }),
1292 cx,
1293 );
1294 });
1295
1296 thread.update(cx, |thread, cx| {
1297 assert_eq!(thread.entries.len(), 1);
1298 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1299 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1300 } else {
1301 panic!("Expected UserMessage");
1302 }
1303 });
1304
1305 // Test appending to existing user message
1306 thread.update(cx, |thread, cx| {
1307 thread.push_user_content_block(
1308 acp::ContentBlock::Text(acp::TextContent {
1309 annotations: None,
1310 text: "world!".to_string(),
1311 }),
1312 cx,
1313 );
1314 });
1315
1316 thread.update(cx, |thread, cx| {
1317 assert_eq!(thread.entries.len(), 1);
1318 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1319 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1320 } else {
1321 panic!("Expected UserMessage");
1322 }
1323 });
1324
1325 // Test creating new user message after assistant message
1326 thread.update(cx, |thread, cx| {
1327 thread.push_assistant_content_block(
1328 acp::ContentBlock::Text(acp::TextContent {
1329 annotations: None,
1330 text: "Assistant response".to_string(),
1331 }),
1332 false,
1333 cx,
1334 );
1335 });
1336
1337 thread.update(cx, |thread, cx| {
1338 thread.push_user_content_block(
1339 acp::ContentBlock::Text(acp::TextContent {
1340 annotations: None,
1341 text: "New user message".to_string(),
1342 }),
1343 cx,
1344 );
1345 });
1346
1347 thread.update(cx, |thread, cx| {
1348 assert_eq!(thread.entries.len(), 3);
1349 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1350 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1351 } else {
1352 panic!("Expected UserMessage at index 2");
1353 }
1354 });
1355 }
1356
1357 #[gpui::test]
1358 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1359 init_test(cx);
1360
1361 let fs = FakeFs::new(cx.executor());
1362 let project = Project::test(fs, [], cx).await;
1363 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1364 |_, thread, mut cx| {
1365 async move {
1366 thread.update(&mut cx, |thread, cx| {
1367 thread
1368 .handle_session_update(
1369 acp::SessionUpdate::AgentThoughtChunk {
1370 content: "Thinking ".into(),
1371 },
1372 cx,
1373 )
1374 .unwrap();
1375 thread
1376 .handle_session_update(
1377 acp::SessionUpdate::AgentThoughtChunk {
1378 content: "hard!".into(),
1379 },
1380 cx,
1381 )
1382 .unwrap();
1383 })?;
1384 Ok(acp::PromptResponse {
1385 stop_reason: acp::StopReason::EndTurn,
1386 })
1387 }
1388 .boxed_local()
1389 },
1390 ));
1391
1392 let thread = cx
1393 .spawn(async move |mut cx| {
1394 connection
1395 .new_thread(project, Path::new(path!("/test")), &mut cx)
1396 .await
1397 })
1398 .await
1399 .unwrap();
1400
1401 thread
1402 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1403 .await
1404 .unwrap();
1405
1406 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1407 assert_eq!(
1408 output,
1409 indoc! {r#"
1410 ## User
1411
1412 Hello from Zed!
1413
1414 ## Assistant
1415
1416 <thinking>
1417 Thinking hard!
1418 </thinking>
1419
1420 "#}
1421 );
1422 }
1423
    /// Checks that an agent write to a file is combined with a user edit made
    /// to the same buffer between the agent's read and its write: both the
    /// buffer and the file on the fake filesystem must end up containing the
    /// user's inserted line as well as the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // One-shot signal the fake agent fires once it has read the file, so
        // the test body can make its edit before the agent writes.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Wrapped in `Rc<RefCell<Option<_>>>` so the `Fn` handler can take the
        // sender out exactly once.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // Step 1: the agent reads the file and sees the initial contents.
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Step 2: tell the test body that the read has completed.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Step 3: the agent writes the text it read plus two extra lines.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same file as a buffer on the "user" side.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        // Start the prompt but do not await it yet; the handler above runs
        // while the test continues below.
        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then make the user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        // Let the agent's pending write (and the subsequent save) run.
        cx.run_until_parked();
        // The user's "zero" line and the agent's "four"/"five" lines are both present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1502
    /// Checks the status transitions of a tool call that is canceled and then
    /// reported as completed by the agent: in progress -> canceled (after
    /// `cancel`) -> completed (after a late `ToolCallUpdate`).
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Start the prompt without awaiting it so the test can cancel mid-turn.
        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 1 is expected to be the tool call (presumably entry 0 is the
        // user message) and it starts out in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Canceling the thread marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion for the same tool call after the cancel.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // The late update wins: the tool call ends up completed, not canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1611
1612 #[gpui::test]
1613 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1614 init_test(cx);
1615 let fs = FakeFs::new(cx.background_executor.clone());
1616 fs.insert_tree(path!("/test"), json!({})).await;
1617 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1618
1619 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1620 move |_, thread, mut cx| {
1621 async move {
1622 thread
1623 .update(&mut cx, |thread, cx| {
1624 thread.handle_session_update(
1625 acp::SessionUpdate::ToolCall(acp::ToolCall {
1626 id: acp::ToolCallId("test".into()),
1627 title: "Label".into(),
1628 kind: acp::ToolKind::Edit,
1629 status: acp::ToolCallStatus::Completed,
1630 content: vec![acp::ToolCallContent::Diff {
1631 diff: acp::Diff {
1632 path: "/test/test.txt".into(),
1633 old_text: None,
1634 new_text: "foo".into(),
1635 },
1636 }],
1637 locations: vec![],
1638 raw_input: None,
1639 }),
1640 cx,
1641 )
1642 })
1643 .unwrap()
1644 .unwrap();
1645 Ok(acp::PromptResponse {
1646 stop_reason: acp::StopReason::EndTurn,
1647 })
1648 }
1649 .boxed_local()
1650 }
1651 }));
1652
1653 let thread = connection
1654 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1655 .await
1656 .unwrap();
1657 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1658 .await
1659 .unwrap();
1660
1661 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1662 }
1663
1664 async fn run_until_first_tool_call(
1665 thread: &Entity<AcpThread>,
1666 cx: &mut TestAppContext,
1667 ) -> usize {
1668 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1669
1670 let subscription = cx.update(|cx| {
1671 cx.subscribe(thread, move |thread, _, cx| {
1672 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1673 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1674 return tx.try_send(ix).unwrap();
1675 }
1676 }
1677 })
1678 });
1679
1680 select! {
1681 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1682 panic!("Timeout waiting for tool call")
1683 }
1684 ix = rx.next().fuse() => {
1685 drop(subscription);
1686 ix.unwrap()
1687 }
1688 }
1689 }
1690
1691 #[derive(Clone, Default)]
1692 struct FakeAgentConnection {
1693 auth_methods: Vec<acp::AuthMethod>,
1694 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1695 on_user_message: Option<
1696 Rc<
1697 dyn Fn(
1698 acp::PromptRequest,
1699 WeakEntity<AcpThread>,
1700 AsyncApp,
1701 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1702 + 'static,
1703 >,
1704 >,
1705 }
1706
1707 impl FakeAgentConnection {
1708 fn new() -> Self {
1709 Self {
1710 auth_methods: Vec::new(),
1711 on_user_message: None,
1712 sessions: Arc::default(),
1713 }
1714 }
1715
1716 #[expect(unused)]
1717 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1718 self.auth_methods = auth_methods;
1719 self
1720 }
1721
1722 fn on_user_message(
1723 mut self,
1724 handler: impl Fn(
1725 acp::PromptRequest,
1726 WeakEntity<AcpThread>,
1727 AsyncApp,
1728 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1729 + 'static,
1730 ) -> Self {
1731 self.on_user_message.replace(Rc::new(handler));
1732 self
1733 }
1734 }
1735
1736 impl AgentConnection for FakeAgentConnection {
1737 fn auth_methods(&self) -> &[acp::AuthMethod] {
1738 &self.auth_methods
1739 }
1740
1741 fn new_thread(
1742 self: Rc<Self>,
1743 project: Entity<Project>,
1744 _cwd: &Path,
1745 cx: &mut gpui::AsyncApp,
1746 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1747 let session_id = acp::SessionId(
1748 rand::thread_rng()
1749 .sample_iter(&rand::distributions::Alphanumeric)
1750 .take(7)
1751 .map(char::from)
1752 .collect::<String>()
1753 .into(),
1754 );
1755 let thread = cx
1756 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1757 .unwrap();
1758 self.sessions.lock().insert(session_id, thread.downgrade());
1759 Task::ready(Ok(thread))
1760 }
1761
1762 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1763 if self.auth_methods().iter().any(|m| m.id == method) {
1764 Task::ready(Ok(()))
1765 } else {
1766 Task::ready(Err(anyhow!("Invalid Auth Method")))
1767 }
1768 }
1769
1770 fn prompt(
1771 &self,
1772 params: acp::PromptRequest,
1773 cx: &mut App,
1774 ) -> Task<gpui::Result<acp::PromptResponse>> {
1775 let sessions = self.sessions.lock();
1776 let thread = sessions.get(¶ms.session_id).unwrap();
1777 if let Some(handler) = &self.on_user_message {
1778 let handler = handler.clone();
1779 let thread = thread.clone();
1780 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1781 } else {
1782 Task::ready(Ok(acp::PromptResponse {
1783 stop_reason: acp::StopReason::EndTurn,
1784 }))
1785 }
1786 }
1787
1788 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1789 let sessions = self.sessions.lock();
1790 let thread = sessions.get(&session_id).unwrap().clone();
1791
1792 cx.spawn(async move |cx| {
1793 thread
1794 .update(cx, |thread, cx| thread.cancel(cx))
1795 .unwrap()
1796 .await
1797 })
1798 .detach();
1799 }
1800 }
1801}