1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use itertools::Itertools;
12use language::{
13 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
14 text_diff,
15};
16use markdown::Markdown;
17use project::{AgentLocation, Project};
18use std::collections::HashMap;
19use std::error::Error;
20use std::fmt::Formatter;
21use std::process::ExitStatus;
22use std::rc::Rc;
23use std::{
24 fmt::Display,
25 mem,
26 path::{Path, PathBuf},
27 sync::Arc,
28};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub content: ContentBlock,
35}
36
37impl UserMessage {
38 pub fn from_acp(
39 message: impl IntoIterator<Item = acp::ContentBlock>,
40 language_registry: Arc<LanguageRegistry>,
41 cx: &mut App,
42 ) -> Self {
43 let mut content = ContentBlock::Empty;
44 for chunk in message {
45 content.append(chunk, &language_registry, cx)
46 }
47 Self { content: content }
48 }
49
50 fn to_markdown(&self, cx: &App) -> String {
51 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
52 }
53}
54
/// A reference to a file, rendered as a Markdown link whose target carries the
/// `@file:` prefix so it can be recognized and parsed back with [`MentionPath::try_parse`].
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Link-target prefix that marks a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats as `[@<file name>](@file:<full path>)`; a path without a file name
    /// (e.g. `/`) yields an empty link text after the `@`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self
            .0
            .file_name()
            .map(Path::new)
            .unwrap_or_else(|| Path::new(""));
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
86
87#[derive(Debug, PartialEq)]
88pub struct AssistantMessage {
89 pub chunks: Vec<AssistantMessageChunk>,
90}
91
92impl AssistantMessage {
93 pub fn to_markdown(&self, cx: &App) -> String {
94 format!(
95 "## Assistant\n\n{}\n\n",
96 self.chunks
97 .iter()
98 .map(|chunk| chunk.to_markdown(cx))
99 .join("\n\n")
100 )
101 }
102}
103
104#[derive(Debug, PartialEq)]
105pub enum AssistantMessageChunk {
106 Message { block: ContentBlock },
107 Thought { block: ContentBlock },
108}
109
110impl AssistantMessageChunk {
111 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
112 Self::Message {
113 block: ContentBlock::new(chunk.into(), language_registry, cx),
114 }
115 }
116
117 fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::Message { block } => block.to_markdown(cx).to_string(),
120 Self::Thought { block } => {
121 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
128pub enum AgentThreadEntry {
129 UserMessage(UserMessage),
130 AssistantMessage(AssistantMessage),
131 ToolCall(ToolCall),
132}
133
134impl AgentThreadEntry {
135 fn to_markdown(&self, cx: &App) -> String {
136 match self {
137 Self::UserMessage(message) => message.to_markdown(cx),
138 Self::AssistantMessage(message) => message.to_markdown(cx),
139 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
140 }
141 }
142
143 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.diffs())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
152 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
153 Some(locations)
154 } else {
155 None
156 }
157 }
158}
159
160#[derive(Debug)]
161pub struct ToolCall {
162 pub id: acp::ToolCallId,
163 pub label: Entity<Markdown>,
164 pub kind: acp::ToolKind,
165 pub content: Vec<ToolCallContent>,
166 pub status: ToolCallStatus,
167 pub locations: Vec<acp::ToolCallLocation>,
168 pub raw_input: Option<serde_json::Value>,
169 pub raw_output: Option<serde_json::Value>,
170}
171
172impl ToolCall {
173 fn from_acp(
174 tool_call: acp::ToolCall,
175 status: ToolCallStatus,
176 language_registry: Arc<LanguageRegistry>,
177 cx: &mut App,
178 ) -> Self {
179 Self {
180 id: tool_call.id,
181 label: cx.new(|cx| {
182 Markdown::new(
183 tool_call.title.into(),
184 Some(language_registry.clone()),
185 None,
186 cx,
187 )
188 }),
189 kind: tool_call.kind,
190 content: tool_call
191 .content
192 .into_iter()
193 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
194 .collect(),
195 locations: tool_call.locations,
196 status,
197 raw_input: tool_call.raw_input,
198 raw_output: tool_call.raw_output,
199 }
200 }
201
202 fn update(
203 &mut self,
204 fields: acp::ToolCallUpdateFields,
205 language_registry: Arc<LanguageRegistry>,
206 cx: &mut App,
207 ) {
208 let acp::ToolCallUpdateFields {
209 kind,
210 status,
211 title,
212 content,
213 locations,
214 raw_input,
215 raw_output,
216 } = fields;
217
218 if let Some(kind) = kind {
219 self.kind = kind;
220 }
221
222 if let Some(status) = status {
223 self.status = ToolCallStatus::Allowed { status };
224 }
225
226 if let Some(title) = title {
227 self.label.update(cx, |label, cx| {
228 label.replace(title, cx);
229 });
230 }
231
232 if let Some(content) = content {
233 self.content = content
234 .into_iter()
235 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
236 .collect();
237 }
238
239 if let Some(locations) = locations {
240 self.locations = locations;
241 }
242
243 if let Some(raw_input) = raw_input {
244 self.raw_input = Some(raw_input);
245 }
246
247 if let Some(raw_output) = raw_output {
248 self.raw_output = Some(raw_output);
249 }
250 }
251
252 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
253 self.content.iter().filter_map(|content| match content {
254 ToolCallContent::ContentBlock { .. } => None,
255 ToolCallContent::Diff { diff } => Some(diff),
256 })
257 }
258
259 fn to_markdown(&self, cx: &App) -> String {
260 let mut markdown = format!(
261 "**Tool Call: {}**\nStatus: {}\n\n",
262 self.label.read(cx).source(),
263 self.status
264 );
265 for content in &self.content {
266 markdown.push_str(content.to_markdown(cx).as_str());
267 markdown.push_str("\n\n");
268 }
269 markdown
270 }
271}
272
273#[derive(Debug)]
274pub enum ToolCallStatus {
275 WaitingForConfirmation {
276 options: Vec<acp::PermissionOption>,
277 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
278 },
279 Allowed {
280 status: acp::ToolCallStatus,
281 },
282 Rejected,
283 Canceled,
284}
285
286impl Display for ToolCallStatus {
287 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
288 write!(
289 f,
290 "{}",
291 match self {
292 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
293 ToolCallStatus::Allowed { status } => match status {
294 acp::ToolCallStatus::Pending => "Pending",
295 acp::ToolCallStatus::InProgress => "In Progress",
296 acp::ToolCallStatus::Completed => "Completed",
297 acp::ToolCallStatus::Failed => "Failed",
298 },
299 ToolCallStatus::Rejected => "Rejected",
300 ToolCallStatus::Canceled => "Canceled",
301 }
302 )
303 }
304}
305
306#[derive(Debug, PartialEq, Clone)]
307pub enum ContentBlock {
308 Empty,
309 Markdown { markdown: Entity<Markdown> },
310}
311
312impl ContentBlock {
313 pub fn new(
314 block: acp::ContentBlock,
315 language_registry: &Arc<LanguageRegistry>,
316 cx: &mut App,
317 ) -> Self {
318 let mut this = Self::Empty;
319 this.append(block, language_registry, cx);
320 this
321 }
322
323 pub fn new_combined(
324 blocks: impl IntoIterator<Item = acp::ContentBlock>,
325 language_registry: Arc<LanguageRegistry>,
326 cx: &mut App,
327 ) -> Self {
328 let mut this = Self::Empty;
329 for block in blocks {
330 this.append(block, &language_registry, cx);
331 }
332 this
333 }
334
335 pub fn append(
336 &mut self,
337 block: acp::ContentBlock,
338 language_registry: &Arc<LanguageRegistry>,
339 cx: &mut App,
340 ) {
341 let new_content = match block {
342 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
343 acp::ContentBlock::ResourceLink(resource_link) => {
344 if let Some(path) = resource_link.uri.strip_prefix("file://") {
345 format!("{}", MentionPath(path.as_ref()))
346 } else {
347 resource_link.uri.clone()
348 }
349 }
350 acp::ContentBlock::Image(_)
351 | acp::ContentBlock::Audio(_)
352 | acp::ContentBlock::Resource(_) => String::new(),
353 };
354
355 match self {
356 ContentBlock::Empty => {
357 *self = ContentBlock::Markdown {
358 markdown: cx.new(|cx| {
359 Markdown::new(
360 new_content.into(),
361 Some(language_registry.clone()),
362 None,
363 cx,
364 )
365 }),
366 };
367 }
368 ContentBlock::Markdown { markdown } => {
369 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
370 }
371 }
372 }
373
374 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
375 match self {
376 ContentBlock::Empty => "",
377 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
378 }
379 }
380
381 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
382 match self {
383 ContentBlock::Empty => None,
384 ContentBlock::Markdown { markdown } => Some(markdown),
385 }
386 }
387}
388
389#[derive(Debug)]
390pub enum ToolCallContent {
391 ContentBlock { content: ContentBlock },
392 Diff { diff: Diff },
393}
394
395impl ToolCallContent {
396 pub fn from_acp(
397 content: acp::ToolCallContent,
398 language_registry: Arc<LanguageRegistry>,
399 cx: &mut App,
400 ) -> Self {
401 match content {
402 acp::ToolCallContent::Content { content } => Self::ContentBlock {
403 content: ContentBlock::new(content, &language_registry, cx),
404 },
405 acp::ToolCallContent::Diff { diff } => Self::Diff {
406 diff: Diff::from_acp(diff, language_registry, cx),
407 },
408 }
409 }
410
411 pub fn to_markdown(&self, cx: &App) -> String {
412 match self {
413 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
414 Self::Diff { diff } => diff.to_markdown(cx),
415 }
416 }
417}
418
419#[derive(Debug)]
420pub struct Diff {
421 pub multibuffer: Entity<MultiBuffer>,
422 pub path: PathBuf,
423 _task: Task<Result<()>>,
424}
425
426impl Diff {
427 pub fn from_acp(
428 diff: acp::Diff,
429 language_registry: Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) -> Self {
432 let acp::Diff {
433 path,
434 old_text,
435 new_text,
436 } = diff;
437
438 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
439
440 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
441 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
442 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
443 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
444
445 let task = cx.spawn({
446 let multibuffer = multibuffer.clone();
447 let path = path.clone();
448 async move |cx| {
449 let language = language_registry
450 .language_for_file_path(&path)
451 .await
452 .log_err();
453
454 new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
455
456 let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
457 buffer.set_language(language, cx);
458 buffer.snapshot()
459 })?;
460
461 buffer_diff
462 .update(cx, |diff, cx| {
463 diff.set_base_text(
464 old_buffer_snapshot,
465 Some(language_registry),
466 new_buffer_snapshot,
467 cx,
468 )
469 })?
470 .await?;
471
472 multibuffer
473 .update(cx, |multibuffer, cx| {
474 let hunk_ranges = {
475 let buffer = new_buffer.read(cx);
476 let diff = buffer_diff.read(cx);
477 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
478 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
479 .collect::<Vec<_>>()
480 };
481
482 multibuffer.set_excerpts_for_path(
483 PathKey::for_buffer(&new_buffer, cx),
484 new_buffer.clone(),
485 hunk_ranges,
486 editor::DEFAULT_MULTIBUFFER_CONTEXT,
487 cx,
488 );
489 multibuffer.add_diff(buffer_diff, cx);
490 })
491 .log_err();
492
493 anyhow::Ok(())
494 }
495 });
496
497 Self {
498 multibuffer,
499 path,
500 _task: task,
501 }
502 }
503
504 fn to_markdown(&self, cx: &App) -> String {
505 let buffer_text = self
506 .multibuffer
507 .read(cx)
508 .all_buffers()
509 .iter()
510 .map(|buffer| buffer.read(cx).text())
511 .join("\n");
512 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
513 }
514}
515
516#[derive(Debug, Default)]
517pub struct Plan {
518 pub entries: Vec<PlanEntry>,
519}
520
521#[derive(Debug)]
522pub struct PlanStats<'a> {
523 pub in_progress_entry: Option<&'a PlanEntry>,
524 pub pending: u32,
525 pub completed: u32,
526}
527
528impl Plan {
529 pub fn is_empty(&self) -> bool {
530 self.entries.is_empty()
531 }
532
533 pub fn stats(&self) -> PlanStats<'_> {
534 let mut stats = PlanStats {
535 in_progress_entry: None,
536 pending: 0,
537 completed: 0,
538 };
539
540 for entry in &self.entries {
541 match &entry.status {
542 acp::PlanEntryStatus::Pending => {
543 stats.pending += 1;
544 }
545 acp::PlanEntryStatus::InProgress => {
546 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
547 }
548 acp::PlanEntryStatus::Completed => {
549 stats.completed += 1;
550 }
551 }
552 }
553
554 stats
555 }
556}
557
558#[derive(Debug)]
559pub struct PlanEntry {
560 pub content: Entity<Markdown>,
561 pub priority: acp::PlanEntryPriority,
562 pub status: acp::PlanEntryStatus,
563}
564
565impl PlanEntry {
566 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
567 Self {
568 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
569 priority: entry.priority,
570 status: entry.status,
571 }
572 }
573}
574
575pub struct AcpThread {
576 title: SharedString,
577 entries: Vec<AgentThreadEntry>,
578 plan: Plan,
579 project: Entity<Project>,
580 action_log: Entity<ActionLog>,
581 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
582 send_task: Option<Task<()>>,
583 connection: Rc<dyn AgentConnection>,
584 session_id: acp::SessionId,
585}
586
587pub enum AcpThreadEvent {
588 NewEntry,
589 EntryUpdated(usize),
590 ToolAuthorizationRequired,
591 Stopped,
592 Error,
593 ServerExited(ExitStatus),
594}
595
596impl EventEmitter<AcpThreadEvent> for AcpThread {}
597
598#[derive(PartialEq, Eq)]
599pub enum ThreadStatus {
600 Idle,
601 WaitingForToolConfirmation,
602 Generating,
603}
604
605#[derive(Debug, Clone)]
606pub enum LoadError {
607 Unsupported {
608 error_message: SharedString,
609 upgrade_message: SharedString,
610 upgrade_command: String,
611 },
612 Exited(i32),
613 Other(SharedString),
614}
615
616impl Display for LoadError {
617 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
618 match self {
619 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
620 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
621 LoadError::Other(msg) => write!(f, "{}", msg),
622 }
623 }
624}
625
626impl Error for LoadError {}
627
628impl AcpThread {
629 pub fn new(
630 title: impl Into<SharedString>,
631 connection: Rc<dyn AgentConnection>,
632 project: Entity<Project>,
633 session_id: acp::SessionId,
634 cx: &mut Context<Self>,
635 ) -> Self {
636 let action_log = cx.new(|_| ActionLog::new(project.clone()));
637
638 Self {
639 action_log,
640 shared_buffers: Default::default(),
641 entries: Default::default(),
642 plan: Default::default(),
643 title: title.into(),
644 project,
645 send_task: None,
646 connection,
647 session_id,
648 }
649 }
650
651 pub fn action_log(&self) -> &Entity<ActionLog> {
652 &self.action_log
653 }
654
655 pub fn project(&self) -> &Entity<Project> {
656 &self.project
657 }
658
659 pub fn title(&self) -> SharedString {
660 self.title.clone()
661 }
662
663 pub fn entries(&self) -> &[AgentThreadEntry] {
664 &self.entries
665 }
666
667 pub fn session_id(&self) -> &acp::SessionId {
668 &self.session_id
669 }
670
671 pub fn status(&self) -> ThreadStatus {
672 if self.send_task.is_some() {
673 if self.waiting_for_tool_confirmation() {
674 ThreadStatus::WaitingForToolConfirmation
675 } else {
676 ThreadStatus::Generating
677 }
678 } else {
679 ThreadStatus::Idle
680 }
681 }
682
683 pub fn has_pending_edit_tool_calls(&self) -> bool {
684 for entry in self.entries.iter().rev() {
685 match entry {
686 AgentThreadEntry::UserMessage(_) => return false,
687 AgentThreadEntry::ToolCall(
688 call @ ToolCall {
689 status:
690 ToolCallStatus::Allowed {
691 status:
692 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
693 },
694 ..
695 },
696 ) if call.diffs().next().is_some() => {
697 return true;
698 }
699 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
700 }
701 }
702
703 false
704 }
705
706 pub fn used_tools_since_last_user_message(&self) -> bool {
707 for entry in self.entries.iter().rev() {
708 match entry {
709 AgentThreadEntry::UserMessage(..) => return false,
710 AgentThreadEntry::AssistantMessage(..) => continue,
711 AgentThreadEntry::ToolCall(..) => return true,
712 }
713 }
714
715 false
716 }
717
718 pub fn handle_session_update(
719 &mut self,
720 update: acp::SessionUpdate,
721 cx: &mut Context<Self>,
722 ) -> Result<()> {
723 match update {
724 acp::SessionUpdate::UserMessageChunk { content } => {
725 self.push_user_content_block(content, cx);
726 }
727 acp::SessionUpdate::AgentMessageChunk { content } => {
728 self.push_assistant_content_block(content, false, cx);
729 }
730 acp::SessionUpdate::AgentThoughtChunk { content } => {
731 self.push_assistant_content_block(content, true, cx);
732 }
733 acp::SessionUpdate::ToolCall(tool_call) => {
734 self.upsert_tool_call(tool_call, cx);
735 }
736 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
737 self.update_tool_call(tool_call_update, cx)?;
738 }
739 acp::SessionUpdate::Plan(plan) => {
740 self.update_plan(plan, cx);
741 }
742 }
743 Ok(())
744 }
745
746 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
747 let language_registry = self.project.read(cx).languages().clone();
748 let entries_len = self.entries.len();
749
750 if let Some(last_entry) = self.entries.last_mut()
751 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
752 {
753 content.append(chunk, &language_registry, cx);
754 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
755 } else {
756 let content = ContentBlock::new(chunk, &language_registry, cx);
757 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
758 }
759 }
760
761 pub fn push_assistant_content_block(
762 &mut self,
763 chunk: acp::ContentBlock,
764 is_thought: bool,
765 cx: &mut Context<Self>,
766 ) {
767 let language_registry = self.project.read(cx).languages().clone();
768 let entries_len = self.entries.len();
769 if let Some(last_entry) = self.entries.last_mut()
770 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
771 {
772 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
773 match (chunks.last_mut(), is_thought) {
774 (Some(AssistantMessageChunk::Message { block }), false)
775 | (Some(AssistantMessageChunk::Thought { block }), true) => {
776 block.append(chunk, &language_registry, cx)
777 }
778 _ => {
779 let block = ContentBlock::new(chunk, &language_registry, cx);
780 if is_thought {
781 chunks.push(AssistantMessageChunk::Thought { block })
782 } else {
783 chunks.push(AssistantMessageChunk::Message { block })
784 }
785 }
786 }
787 } else {
788 let block = ContentBlock::new(chunk, &language_registry, cx);
789 let chunk = if is_thought {
790 AssistantMessageChunk::Thought { block }
791 } else {
792 AssistantMessageChunk::Message { block }
793 };
794
795 self.push_entry(
796 AgentThreadEntry::AssistantMessage(AssistantMessage {
797 chunks: vec![chunk],
798 }),
799 cx,
800 );
801 }
802 }
803
804 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
805 self.entries.push(entry);
806 cx.emit(AcpThreadEvent::NewEntry);
807 }
808
809 pub fn update_tool_call(
810 &mut self,
811 update: acp::ToolCallUpdate,
812 cx: &mut Context<Self>,
813 ) -> Result<()> {
814 let languages = self.project.read(cx).languages().clone();
815
816 let (ix, current_call) = self
817 .tool_call_mut(&update.id)
818 .context("Tool call not found")?;
819 current_call.update(update.fields, languages, cx);
820
821 cx.emit(AcpThreadEvent::EntryUpdated(ix));
822
823 Ok(())
824 }
825
826 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
827 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
828 let status = ToolCallStatus::Allowed {
829 status: tool_call.status,
830 };
831 self.upsert_tool_call_inner(tool_call, status, cx)
832 }
833
834 pub fn upsert_tool_call_inner(
835 &mut self,
836 tool_call: acp::ToolCall,
837 status: ToolCallStatus,
838 cx: &mut Context<Self>,
839 ) {
840 let language_registry = self.project.read(cx).languages().clone();
841 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
842
843 let location = call.locations.last().cloned();
844
845 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
846 *current_call = call;
847
848 cx.emit(AcpThreadEvent::EntryUpdated(ix));
849 } else {
850 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
851 }
852
853 if let Some(location) = location {
854 self.set_project_location(location, cx)
855 }
856 }
857
858 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
859 // The tool call we are looking for is typically the last one, or very close to the end.
860 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
861 self.entries
862 .iter_mut()
863 .enumerate()
864 .rev()
865 .find_map(|(index, tool_call)| {
866 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
867 && &tool_call.id == id
868 {
869 Some((index, tool_call))
870 } else {
871 None
872 }
873 })
874 }
875
876 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
877 self.project.update(cx, |project, cx| {
878 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
879 return;
880 };
881 let buffer = project.open_buffer(path, cx);
882 cx.spawn(async move |project, cx| {
883 let buffer = buffer.await?;
884
885 project.update(cx, |project, cx| {
886 let position = if let Some(line) = location.line {
887 let snapshot = buffer.read(cx).snapshot();
888 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
889 snapshot.anchor_before(point)
890 } else {
891 Anchor::MIN
892 };
893
894 project.set_agent_location(
895 Some(AgentLocation {
896 buffer: buffer.downgrade(),
897 position,
898 }),
899 cx,
900 );
901 })
902 })
903 .detach_and_log_err(cx);
904 });
905 }
906
907 pub fn request_tool_call_authorization(
908 &mut self,
909 tool_call: acp::ToolCall,
910 options: Vec<acp::PermissionOption>,
911 cx: &mut Context<Self>,
912 ) -> oneshot::Receiver<acp::PermissionOptionId> {
913 let (tx, rx) = oneshot::channel();
914
915 let status = ToolCallStatus::WaitingForConfirmation {
916 options,
917 respond_tx: tx,
918 };
919
920 self.upsert_tool_call_inner(tool_call, status, cx);
921 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
922 rx
923 }
924
925 pub fn authorize_tool_call(
926 &mut self,
927 id: acp::ToolCallId,
928 option_id: acp::PermissionOptionId,
929 option_kind: acp::PermissionOptionKind,
930 cx: &mut Context<Self>,
931 ) {
932 let Some((ix, call)) = self.tool_call_mut(&id) else {
933 return;
934 };
935
936 let new_status = match option_kind {
937 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
938 ToolCallStatus::Rejected
939 }
940 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
941 ToolCallStatus::Allowed {
942 status: acp::ToolCallStatus::InProgress,
943 }
944 }
945 };
946
947 let curr_status = mem::replace(&mut call.status, new_status);
948
949 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
950 respond_tx.send(option_id).log_err();
951 } else if cfg!(debug_assertions) {
952 panic!("tried to authorize an already authorized tool call");
953 }
954
955 cx.emit(AcpThreadEvent::EntryUpdated(ix));
956 }
957
958 /// Returns true if the last turn is awaiting tool authorization
959 pub fn waiting_for_tool_confirmation(&self) -> bool {
960 for entry in self.entries.iter().rev() {
961 match &entry {
962 AgentThreadEntry::ToolCall(call) => match call.status {
963 ToolCallStatus::WaitingForConfirmation { .. } => return true,
964 ToolCallStatus::Allowed { .. }
965 | ToolCallStatus::Rejected
966 | ToolCallStatus::Canceled => continue,
967 },
968 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
969 // Reached the beginning of the turn
970 return false;
971 }
972 }
973 }
974 false
975 }
976
977 pub fn plan(&self) -> &Plan {
978 &self.plan
979 }
980
981 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
982 let new_entries_len = request.entries.len();
983 let mut new_entries = request.entries.into_iter();
984
985 // Reuse existing markdown to prevent flickering
986 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
987 let PlanEntry {
988 content,
989 priority,
990 status,
991 } = old;
992 content.update(cx, |old, cx| {
993 old.replace(new.content, cx);
994 });
995 *priority = new.priority;
996 *status = new.status;
997 }
998 for new in new_entries {
999 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1000 }
1001 self.plan.entries.truncate(new_entries_len);
1002
1003 cx.notify();
1004 }
1005
1006 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1007 self.plan
1008 .entries
1009 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1010 cx.notify();
1011 }
1012
1013 #[cfg(any(test, feature = "test-support"))]
1014 pub fn send_raw(
1015 &mut self,
1016 message: &str,
1017 cx: &mut Context<Self>,
1018 ) -> BoxFuture<'static, Result<()>> {
1019 self.send(
1020 vec![acp::ContentBlock::Text(acp::TextContent {
1021 text: message.to_string(),
1022 annotations: None,
1023 })],
1024 cx,
1025 )
1026 }
1027
1028 pub fn send(
1029 &mut self,
1030 message: Vec<acp::ContentBlock>,
1031 cx: &mut Context<Self>,
1032 ) -> BoxFuture<'static, Result<()>> {
1033 let block = ContentBlock::new_combined(
1034 message.clone(),
1035 self.project.read(cx).languages().clone(),
1036 cx,
1037 );
1038 self.push_entry(
1039 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1040 cx,
1041 );
1042 self.clear_completed_plan_entries(cx);
1043
1044 let (tx, rx) = oneshot::channel();
1045 let cancel_task = self.cancel(cx);
1046
1047 self.send_task = Some(cx.spawn(async move |this, cx| {
1048 async {
1049 cancel_task.await;
1050
1051 let result = this
1052 .update(cx, |this, cx| {
1053 this.connection.prompt(
1054 acp::PromptRequest {
1055 prompt: message,
1056 session_id: this.session_id.clone(),
1057 },
1058 cx,
1059 )
1060 })?
1061 .await;
1062
1063 tx.send(result).log_err();
1064
1065 anyhow::Ok(())
1066 }
1067 .await
1068 .log_err();
1069 }));
1070
1071 cx.spawn(async move |this, cx| match rx.await {
1072 Ok(Err(e)) => {
1073 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1074 .log_err();
1075 Err(e)?
1076 }
1077 result => {
1078 let cancelled = matches!(
1079 result,
1080 Ok(Ok(acp::PromptResponse {
1081 stop_reason: acp::StopReason::Cancelled
1082 }))
1083 );
1084
1085 // We only take the task if the current prompt wasn't cancelled.
1086 //
1087 // This prompt may have been cancelled because another one was sent
1088 // while it was still generating. In these cases, dropping `send_task`
1089 // would cause the next generation to be cancelled.
1090 if !cancelled {
1091 this.update(cx, |this, _cx| this.send_task.take()).ok();
1092 }
1093
1094 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1095 .log_err();
1096 Ok(())
1097 }
1098 })
1099 .boxed()
1100 }
1101
1102 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1103 let Some(send_task) = self.send_task.take() else {
1104 return Task::ready(());
1105 };
1106
1107 for entry in self.entries.iter_mut() {
1108 if let AgentThreadEntry::ToolCall(call) = entry {
1109 let cancel = matches!(
1110 call.status,
1111 ToolCallStatus::WaitingForConfirmation { .. }
1112 | ToolCallStatus::Allowed {
1113 status: acp::ToolCallStatus::InProgress
1114 }
1115 );
1116
1117 if cancel {
1118 call.status = ToolCallStatus::Canceled;
1119 }
1120 }
1121 }
1122
1123 self.connection.cancel(&self.session_id, cx);
1124
1125 // Wait for the send task to complete
1126 cx.foreground_executor().spawn(send_task)
1127 }
1128
1129 pub fn read_text_file(
1130 &self,
1131 path: PathBuf,
1132 line: Option<u32>,
1133 limit: Option<u32>,
1134 reuse_shared_snapshot: bool,
1135 cx: &mut Context<Self>,
1136 ) -> Task<Result<String>> {
1137 let project = self.project.clone();
1138 let action_log = self.action_log.clone();
1139 cx.spawn(async move |this, cx| {
1140 let load = project.update(cx, |project, cx| {
1141 let path = project
1142 .project_path_for_absolute_path(&path, cx)
1143 .context("invalid path")?;
1144 anyhow::Ok(project.open_buffer(path, cx))
1145 });
1146 let buffer = load??.await?;
1147
1148 let snapshot = if reuse_shared_snapshot {
1149 this.read_with(cx, |this, _| {
1150 this.shared_buffers.get(&buffer.clone()).cloned()
1151 })
1152 .log_err()
1153 .flatten()
1154 } else {
1155 None
1156 };
1157
1158 let snapshot = if let Some(snapshot) = snapshot {
1159 snapshot
1160 } else {
1161 action_log.update(cx, |action_log, cx| {
1162 action_log.buffer_read(buffer.clone(), cx);
1163 })?;
1164 project.update(cx, |project, cx| {
1165 let position = buffer
1166 .read(cx)
1167 .snapshot()
1168 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1169 project.set_agent_location(
1170 Some(AgentLocation {
1171 buffer: buffer.downgrade(),
1172 position,
1173 }),
1174 cx,
1175 );
1176 })?;
1177
1178 buffer.update(cx, |buffer, _| buffer.snapshot())?
1179 };
1180
1181 this.update(cx, |this, _| {
1182 let text = snapshot.text();
1183 this.shared_buffers.insert(buffer.clone(), snapshot);
1184 if line.is_none() && limit.is_none() {
1185 return Ok(text);
1186 }
1187 let limit = limit.unwrap_or(u32::MAX) as usize;
1188 let Some(line) = line else {
1189 return Ok(text.lines().take(limit).collect::<String>());
1190 };
1191
1192 let count = text.lines().count();
1193 if count < line as usize {
1194 anyhow::bail!("There are only {} lines", count);
1195 }
1196 Ok(text
1197 .lines()
1198 .skip(line as usize + 1)
1199 .take(limit)
1200 .collect::<String>())
1201 })?
1202 })
1203 }
1204
    /// Replaces the contents of the file at `path` with `content` on the agent's
    /// behalf and saves the buffer through the project.
    ///
    /// The new text is diffed against the snapshot stored in `shared_buffers` for
    /// this buffer (falling back to the buffer's current snapshot when none is
    /// stored), and each changed range is applied through anchors taken from that
    /// snapshot. The agent location is moved to the end of the last edit, and the
    /// action log is notified before and after the edit is applied.
    ///
    /// # Errors
    ///
    /// Returns an "invalid path" error when `path` does not map to a project path,
    /// and propagates failures from opening the buffer, from updating the thread or
    /// app entities (e.g. if they have been released), and from saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            // `load` is nested: the outer `Result` comes from the entity update, the
            // inner one from resolving the absolute path to a project path.
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the snapshot recorded for this buffer (e.g. by
            // `read_text_file`), so edits made to the live buffer since then are
            // not part of the diff.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor and turn each offset range
            // into anchors in `snapshot`, so the edits still apply after the buffer
            // has changed.
            // NOTE(review): start uses `anchor_after` and end uses `anchor_before`,
            // which appears intended to keep text inserted exactly at a range
            // boundary outside the replaced span — confirm against anchor bias
            // semantics.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Point the agent location at the end of the last edit, or at
                // `Anchor::MIN` when the diff produced no edits.
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Bracket the edit with a `buffer_read` notification before it and a
                // `buffer_edited` notification after it.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            // Persist the result; the returned task resolves with the save's outcome.
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1272
1273 pub fn to_markdown(&self, cx: &App) -> String {
1274 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1275 }
1276
1277 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1278 cx.emit(AcpThreadEvent::ServerExited(status));
1279 }
1280}
1281
1282#[cfg(test)]
1283mod tests {
1284 use super::*;
1285 use anyhow::anyhow;
1286 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1287 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1288 use indoc::indoc;
1289 use project::FakeFs;
1290 use rand::Rng as _;
1291 use serde_json::json;
1292 use settings::SettingsStore;
1293 use smol::stream::StreamExt as _;
1294 use std::{cell::RefCell, rc::Rc, time::Duration};
1295
1296 use util::path;
1297
1298 fn init_test(cx: &mut TestAppContext) {
1299 env_logger::try_init().ok();
1300 cx.update(|cx| {
1301 let settings_store = SettingsStore::test(cx);
1302 cx.set_global(settings_store);
1303 Project::init_settings(cx);
1304 language::init(cx);
1305 });
1306 }
1307
1308 #[gpui::test]
1309 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1310 init_test(cx);
1311
1312 let fs = FakeFs::new(cx.executor());
1313 let project = Project::test(fs, [], cx).await;
1314 let connection = Rc::new(FakeAgentConnection::new());
1315 let thread = cx
1316 .spawn(async move |mut cx| {
1317 connection
1318 .new_thread(project, Path::new(path!("/test")), &mut cx)
1319 .await
1320 })
1321 .await
1322 .unwrap();
1323
1324 // Test creating a new user message
1325 thread.update(cx, |thread, cx| {
1326 thread.push_user_content_block(
1327 acp::ContentBlock::Text(acp::TextContent {
1328 annotations: None,
1329 text: "Hello, ".to_string(),
1330 }),
1331 cx,
1332 );
1333 });
1334
1335 thread.update(cx, |thread, cx| {
1336 assert_eq!(thread.entries.len(), 1);
1337 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1338 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1339 } else {
1340 panic!("Expected UserMessage");
1341 }
1342 });
1343
1344 // Test appending to existing user message
1345 thread.update(cx, |thread, cx| {
1346 thread.push_user_content_block(
1347 acp::ContentBlock::Text(acp::TextContent {
1348 annotations: None,
1349 text: "world!".to_string(),
1350 }),
1351 cx,
1352 );
1353 });
1354
1355 thread.update(cx, |thread, cx| {
1356 assert_eq!(thread.entries.len(), 1);
1357 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1358 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1359 } else {
1360 panic!("Expected UserMessage");
1361 }
1362 });
1363
1364 // Test creating new user message after assistant message
1365 thread.update(cx, |thread, cx| {
1366 thread.push_assistant_content_block(
1367 acp::ContentBlock::Text(acp::TextContent {
1368 annotations: None,
1369 text: "Assistant response".to_string(),
1370 }),
1371 false,
1372 cx,
1373 );
1374 });
1375
1376 thread.update(cx, |thread, cx| {
1377 thread.push_user_content_block(
1378 acp::ContentBlock::Text(acp::TextContent {
1379 annotations: None,
1380 text: "New user message".to_string(),
1381 }),
1382 cx,
1383 );
1384 });
1385
1386 thread.update(cx, |thread, cx| {
1387 assert_eq!(thread.entries.len(), 3);
1388 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1389 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1390 } else {
1391 panic!("Expected UserMessage at index 2");
1392 }
1393 });
1394 }
1395
1396 #[gpui::test]
1397 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1398 init_test(cx);
1399
1400 let fs = FakeFs::new(cx.executor());
1401 let project = Project::test(fs, [], cx).await;
1402 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1403 |_, thread, mut cx| {
1404 async move {
1405 thread.update(&mut cx, |thread, cx| {
1406 thread
1407 .handle_session_update(
1408 acp::SessionUpdate::AgentThoughtChunk {
1409 content: "Thinking ".into(),
1410 },
1411 cx,
1412 )
1413 .unwrap();
1414 thread
1415 .handle_session_update(
1416 acp::SessionUpdate::AgentThoughtChunk {
1417 content: "hard!".into(),
1418 },
1419 cx,
1420 )
1421 .unwrap();
1422 })?;
1423 Ok(acp::PromptResponse {
1424 stop_reason: acp::StopReason::EndTurn,
1425 })
1426 }
1427 .boxed_local()
1428 },
1429 ));
1430
1431 let thread = cx
1432 .spawn(async move |mut cx| {
1433 connection
1434 .new_thread(project, Path::new(path!("/test")), &mut cx)
1435 .await
1436 })
1437 .await
1438 .unwrap();
1439
1440 thread
1441 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1442 .await
1443 .unwrap();
1444
1445 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1446 assert_eq!(
1447 output,
1448 indoc! {r#"
1449 ## User
1450
1451 Hello from Zed!
1452
1453 ## Assistant
1454
1455 <thinking>
1456 Thinking hard!
1457 </thinking>
1458
1459 "#}
1460 );
1461 }
1462
    /// A user edit made after the agent has read a file must survive the agent's
    /// subsequent write of that file, both in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the fake agent right after its read, so the test body can make
        // its own edit between the agent's read and the agent's write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // Agent reads the whole file (no line offset or limit).
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Agent writes back what it read plus two appended lines.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same buffer from the test side so it can be edited as the "user".
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Once the agent has read the file, prepend a line as the user.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's prepended line and the agent's appended lines are kept,
        // and the agent's save wrote that merged text to disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1541
    /// A tool call marked `Canceled` by `AcpThread::cancel` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 1 (presumably after the user message at 0) is the tool call, still
        // in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        // Canceling the turn marks the in-progress tool call as canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion update from the agent is still accepted.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1651
1652 #[gpui::test]
1653 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1654 init_test(cx);
1655 let fs = FakeFs::new(cx.background_executor.clone());
1656 fs.insert_tree(path!("/test"), json!({})).await;
1657 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1658
1659 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1660 move |_, thread, mut cx| {
1661 async move {
1662 thread
1663 .update(&mut cx, |thread, cx| {
1664 thread.handle_session_update(
1665 acp::SessionUpdate::ToolCall(acp::ToolCall {
1666 id: acp::ToolCallId("test".into()),
1667 title: "Label".into(),
1668 kind: acp::ToolKind::Edit,
1669 status: acp::ToolCallStatus::Completed,
1670 content: vec![acp::ToolCallContent::Diff {
1671 diff: acp::Diff {
1672 path: "/test/test.txt".into(),
1673 old_text: None,
1674 new_text: "foo".into(),
1675 },
1676 }],
1677 locations: vec![],
1678 raw_input: None,
1679 raw_output: None,
1680 }),
1681 cx,
1682 )
1683 })
1684 .unwrap()
1685 .unwrap();
1686 Ok(acp::PromptResponse {
1687 stop_reason: acp::StopReason::EndTurn,
1688 })
1689 }
1690 .boxed_local()
1691 }
1692 }));
1693
1694 let thread = connection
1695 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1696 .await
1697 .unwrap();
1698 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1699 .await
1700 .unwrap();
1701
1702 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1703 }
1704
1705 async fn run_until_first_tool_call(
1706 thread: &Entity<AcpThread>,
1707 cx: &mut TestAppContext,
1708 ) -> usize {
1709 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1710
1711 let subscription = cx.update(|cx| {
1712 cx.subscribe(thread, move |thread, _, cx| {
1713 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1714 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1715 return tx.try_send(ix).unwrap();
1716 }
1717 }
1718 })
1719 });
1720
1721 select! {
1722 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1723 panic!("Timeout waiting for tool call")
1724 }
1725 ix = rx.next().fuse() => {
1726 drop(subscription);
1727 ix.unwrap()
1728 }
1729 }
1730 }
1731
1732 #[derive(Clone, Default)]
1733 struct FakeAgentConnection {
1734 auth_methods: Vec<acp::AuthMethod>,
1735 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1736 on_user_message: Option<
1737 Rc<
1738 dyn Fn(
1739 acp::PromptRequest,
1740 WeakEntity<AcpThread>,
1741 AsyncApp,
1742 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1743 + 'static,
1744 >,
1745 >,
1746 }
1747
1748 impl FakeAgentConnection {
1749 fn new() -> Self {
1750 Self {
1751 auth_methods: Vec::new(),
1752 on_user_message: None,
1753 sessions: Arc::default(),
1754 }
1755 }
1756
1757 #[expect(unused)]
1758 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1759 self.auth_methods = auth_methods;
1760 self
1761 }
1762
1763 fn on_user_message(
1764 mut self,
1765 handler: impl Fn(
1766 acp::PromptRequest,
1767 WeakEntity<AcpThread>,
1768 AsyncApp,
1769 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1770 + 'static,
1771 ) -> Self {
1772 self.on_user_message.replace(Rc::new(handler));
1773 self
1774 }
1775 }
1776
1777 impl AgentConnection for FakeAgentConnection {
1778 fn auth_methods(&self) -> &[acp::AuthMethod] {
1779 &self.auth_methods
1780 }
1781
1782 fn new_thread(
1783 self: Rc<Self>,
1784 project: Entity<Project>,
1785 _cwd: &Path,
1786 cx: &mut gpui::AsyncApp,
1787 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1788 let session_id = acp::SessionId(
1789 rand::thread_rng()
1790 .sample_iter(&rand::distributions::Alphanumeric)
1791 .take(7)
1792 .map(char::from)
1793 .collect::<String>()
1794 .into(),
1795 );
1796 let thread = cx
1797 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1798 .unwrap();
1799 self.sessions.lock().insert(session_id, thread.downgrade());
1800 Task::ready(Ok(thread))
1801 }
1802
1803 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1804 if self.auth_methods().iter().any(|m| m.id == method) {
1805 Task::ready(Ok(()))
1806 } else {
1807 Task::ready(Err(anyhow!("Invalid Auth Method")))
1808 }
1809 }
1810
1811 fn prompt(
1812 &self,
1813 params: acp::PromptRequest,
1814 cx: &mut App,
1815 ) -> Task<gpui::Result<acp::PromptResponse>> {
1816 let sessions = self.sessions.lock();
1817 let thread = sessions.get(¶ms.session_id).unwrap();
1818 if let Some(handler) = &self.on_user_message {
1819 let handler = handler.clone();
1820 let thread = thread.clone();
1821 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1822 } else {
1823 Task::ready(Ok(acp::PromptResponse {
1824 stop_reason: acp::StopReason::EndTurn,
1825 }))
1826 }
1827 }
1828
1829 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1830 let sessions = self.sessions.lock();
1831 let thread = sessions.get(&session_id).unwrap().clone();
1832
1833 cx.spawn(async move |cx| {
1834 thread
1835 .update(cx, |thread, cx| thread.cancel(cx))
1836 .unwrap()
1837 .await
1838 })
1839 .detach();
1840 }
1841 }
1842}