1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result};
6use assistant_tool::ActionLog;
7use buffer_diff::BufferDiff;
8use editor::{Bias, MultiBuffer, PathKey};
9use futures::future::{Fuse, FusedFuture};
10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
11use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
12use itertools::Itertools;
13use language::{
14 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
15 text_diff,
16};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::process::ExitStatus;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub content: ContentBlock,
36}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
53 }
54}
55
/// A borrowed file path that is rendered as an `@file:` mention link in markdown.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Link-target prefix that marks a markdown link as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<full path>)`. A path with no final component gets an
    /// empty label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.0.file_name().unwrap_or_default().to_string_lossy();
        write!(f, "[@{}]({}{})", label, Self::PREFIX, self.0.display())
    }
}
87
88#[derive(Debug, PartialEq)]
89pub struct AssistantMessage {
90 pub chunks: Vec<AssistantMessageChunk>,
91}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub enum AssistantMessageChunk {
107 Message { block: ContentBlock },
108 Thought { block: ContentBlock },
109}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
128#[derive(Debug)]
129pub enum AgentThreadEntry {
130 UserMessage(UserMessage),
131 AssistantMessage(AssistantMessage),
132 ToolCall(ToolCall),
133}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
145 if let AgentThreadEntry::ToolCall(call) = self {
146 itertools::Either::Left(call.diffs())
147 } else {
148 itertools::Either::Right(std::iter::empty())
149 }
150 }
151
152 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
153 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
154 Some(locations)
155 } else {
156 None
157 }
158 }
159}
160
161#[derive(Debug)]
162pub struct ToolCall {
163 pub id: acp::ToolCallId,
164 pub label: Entity<Markdown>,
165 pub kind: acp::ToolKind,
166 pub content: Vec<ToolCallContent>,
167 pub status: ToolCallStatus,
168 pub locations: Vec<acp::ToolCallLocation>,
169 pub raw_input: Option<serde_json::Value>,
170}
171
172impl ToolCall {
173 fn from_acp(
174 tool_call: acp::ToolCall,
175 status: ToolCallStatus,
176 language_registry: Arc<LanguageRegistry>,
177 cx: &mut App,
178 ) -> Self {
179 Self {
180 id: tool_call.id,
181 label: cx.new(|cx| {
182 Markdown::new(
183 tool_call.title.into(),
184 Some(language_registry.clone()),
185 None,
186 cx,
187 )
188 }),
189 kind: tool_call.kind,
190 content: tool_call
191 .content
192 .into_iter()
193 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
194 .collect(),
195 locations: tool_call.locations,
196 status,
197 raw_input: tool_call.raw_input,
198 }
199 }
200
201 fn update(
202 &mut self,
203 fields: acp::ToolCallUpdateFields,
204 language_registry: Arc<LanguageRegistry>,
205 cx: &mut App,
206 ) {
207 let acp::ToolCallUpdateFields {
208 kind,
209 status,
210 title,
211 content,
212 locations,
213 raw_input,
214 } = fields;
215
216 if let Some(kind) = kind {
217 self.kind = kind;
218 }
219
220 if let Some(status) = status {
221 self.status = ToolCallStatus::Allowed { status };
222 }
223
224 if let Some(title) = title {
225 self.label.update(cx, |label, cx| {
226 label.replace(title, cx);
227 });
228 }
229
230 if let Some(content) = content {
231 self.content = content
232 .into_iter()
233 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
234 .collect();
235 }
236
237 if let Some(locations) = locations {
238 self.locations = locations;
239 }
240
241 if let Some(raw_input) = raw_input {
242 self.raw_input = Some(raw_input);
243 }
244 }
245
246 pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
247 self.content.iter().filter_map(|content| match content {
248 ToolCallContent::ContentBlock { .. } => None,
249 ToolCallContent::Diff { diff } => Some(diff),
250 })
251 }
252
253 fn to_markdown(&self, cx: &App) -> String {
254 let mut markdown = format!(
255 "**Tool Call: {}**\nStatus: {}\n\n",
256 self.label.read(cx).source(),
257 self.status
258 );
259 for content in &self.content {
260 markdown.push_str(content.to_markdown(cx).as_str());
261 markdown.push_str("\n\n");
262 }
263 markdown
264 }
265}
266
267#[derive(Debug)]
268pub enum ToolCallStatus {
269 WaitingForConfirmation {
270 options: Vec<acp::PermissionOption>,
271 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
272 },
273 Allowed {
274 status: acp::ToolCallStatus,
275 },
276 Rejected,
277 Canceled,
278}
279
280impl Display for ToolCallStatus {
281 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
282 write!(
283 f,
284 "{}",
285 match self {
286 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
287 ToolCallStatus::Allowed { status } => match status {
288 acp::ToolCallStatus::Pending => "Pending",
289 acp::ToolCallStatus::InProgress => "In Progress",
290 acp::ToolCallStatus::Completed => "Completed",
291 acp::ToolCallStatus::Failed => "Failed",
292 },
293 ToolCallStatus::Rejected => "Rejected",
294 ToolCallStatus::Canceled => "Canceled",
295 }
296 )
297 }
298}
299
300#[derive(Debug, PartialEq, Clone)]
301pub enum ContentBlock {
302 Empty,
303 Markdown { markdown: Entity<Markdown> },
304}
305
306impl ContentBlock {
307 pub fn new(
308 block: acp::ContentBlock,
309 language_registry: &Arc<LanguageRegistry>,
310 cx: &mut App,
311 ) -> Self {
312 let mut this = Self::Empty;
313 this.append(block, language_registry, cx);
314 this
315 }
316
317 pub fn new_combined(
318 blocks: impl IntoIterator<Item = acp::ContentBlock>,
319 language_registry: Arc<LanguageRegistry>,
320 cx: &mut App,
321 ) -> Self {
322 let mut this = Self::Empty;
323 for block in blocks {
324 this.append(block, &language_registry, cx);
325 }
326 this
327 }
328
329 pub fn append(
330 &mut self,
331 block: acp::ContentBlock,
332 language_registry: &Arc<LanguageRegistry>,
333 cx: &mut App,
334 ) {
335 let new_content = match block {
336 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
337 acp::ContentBlock::ResourceLink(resource_link) => {
338 if let Some(path) = resource_link.uri.strip_prefix("file://") {
339 format!("{}", MentionPath(path.as_ref()))
340 } else {
341 resource_link.uri.clone()
342 }
343 }
344 acp::ContentBlock::Image(_)
345 | acp::ContentBlock::Audio(_)
346 | acp::ContentBlock::Resource(_) => String::new(),
347 };
348
349 match self {
350 ContentBlock::Empty => {
351 *self = ContentBlock::Markdown {
352 markdown: cx.new(|cx| {
353 Markdown::new(
354 new_content.into(),
355 Some(language_registry.clone()),
356 None,
357 cx,
358 )
359 }),
360 };
361 }
362 ContentBlock::Markdown { markdown } => {
363 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
364 }
365 }
366 }
367
368 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
369 match self {
370 ContentBlock::Empty => "",
371 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
372 }
373 }
374
375 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
376 match self {
377 ContentBlock::Empty => None,
378 ContentBlock::Markdown { markdown } => Some(markdown),
379 }
380 }
381}
382
383#[derive(Debug)]
384pub enum ToolCallContent {
385 ContentBlock { content: ContentBlock },
386 Diff { diff: Diff },
387}
388
389impl ToolCallContent {
390 pub fn from_acp(
391 content: acp::ToolCallContent,
392 language_registry: Arc<LanguageRegistry>,
393 cx: &mut App,
394 ) -> Self {
395 match content {
396 acp::ToolCallContent::Content { content } => Self::ContentBlock {
397 content: ContentBlock::new(content, &language_registry, cx),
398 },
399 acp::ToolCallContent::Diff { diff } => Self::Diff {
400 diff: Diff::from_acp(diff, language_registry, cx),
401 },
402 }
403 }
404
405 pub fn to_markdown(&self, cx: &App) -> String {
406 match self {
407 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
408 Self::Diff { diff } => diff.to_markdown(cx),
409 }
410 }
411}
412
413#[derive(Debug)]
414pub struct Diff {
415 pub multibuffer: Entity<MultiBuffer>,
416 pub path: PathBuf,
417 _task: Task<Result<()>>,
418}
419
420impl Diff {
421 pub fn from_acp(
422 diff: acp::Diff,
423 language_registry: Arc<LanguageRegistry>,
424 cx: &mut App,
425 ) -> Self {
426 let acp::Diff {
427 path,
428 old_text,
429 new_text,
430 } = diff;
431
432 let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
433
434 let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
435 let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
436 let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
437 let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
438
439 let task = cx.spawn({
440 let multibuffer = multibuffer.clone();
441 let path = path.clone();
442 async move |cx| {
443 let language = language_registry
444 .language_for_file_path(&path)
445 .await
446 .log_err();
447
448 new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
449
450 let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
451 buffer.set_language(language, cx);
452 buffer.snapshot()
453 })?;
454
455 buffer_diff
456 .update(cx, |diff, cx| {
457 diff.set_base_text(
458 old_buffer_snapshot,
459 Some(language_registry),
460 new_buffer_snapshot,
461 cx,
462 )
463 })?
464 .await?;
465
466 multibuffer
467 .update(cx, |multibuffer, cx| {
468 let hunk_ranges = {
469 let buffer = new_buffer.read(cx);
470 let diff = buffer_diff.read(cx);
471 diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
472 .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
473 .collect::<Vec<_>>()
474 };
475
476 multibuffer.set_excerpts_for_path(
477 PathKey::for_buffer(&new_buffer, cx),
478 new_buffer.clone(),
479 hunk_ranges,
480 editor::DEFAULT_MULTIBUFFER_CONTEXT,
481 cx,
482 );
483 multibuffer.add_diff(buffer_diff, cx);
484 })
485 .log_err();
486
487 anyhow::Ok(())
488 }
489 });
490
491 Self {
492 multibuffer,
493 path,
494 _task: task,
495 }
496 }
497
498 fn to_markdown(&self, cx: &App) -> String {
499 let buffer_text = self
500 .multibuffer
501 .read(cx)
502 .all_buffers()
503 .iter()
504 .map(|buffer| buffer.read(cx).text())
505 .join("\n");
506 format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
507 }
508}
509
510#[derive(Debug, Default)]
511pub struct Plan {
512 pub entries: Vec<PlanEntry>,
513}
514
515#[derive(Debug)]
516pub struct PlanStats<'a> {
517 pub in_progress_entry: Option<&'a PlanEntry>,
518 pub pending: u32,
519 pub completed: u32,
520}
521
522impl Plan {
523 pub fn is_empty(&self) -> bool {
524 self.entries.is_empty()
525 }
526
527 pub fn stats(&self) -> PlanStats<'_> {
528 let mut stats = PlanStats {
529 in_progress_entry: None,
530 pending: 0,
531 completed: 0,
532 };
533
534 for entry in &self.entries {
535 match &entry.status {
536 acp::PlanEntryStatus::Pending => {
537 stats.pending += 1;
538 }
539 acp::PlanEntryStatus::InProgress => {
540 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
541 }
542 acp::PlanEntryStatus::Completed => {
543 stats.completed += 1;
544 }
545 }
546 }
547
548 stats
549 }
550}
551
552#[derive(Debug)]
553pub struct PlanEntry {
554 pub content: Entity<Markdown>,
555 pub priority: acp::PlanEntryPriority,
556 pub status: acp::PlanEntryStatus,
557}
558
559impl PlanEntry {
560 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
561 Self {
562 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
563 priority: entry.priority,
564 status: entry.status,
565 }
566 }
567}
568
569pub struct AcpThread {
570 title: SharedString,
571 entries: Vec<AgentThreadEntry>,
572 plan: Plan,
573 project: Entity<Project>,
574 action_log: Entity<ActionLog>,
575 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
576 send_task: Option<Fuse<Task<()>>>,
577 connection: Rc<dyn AgentConnection>,
578 session_id: acp::SessionId,
579}
580
581pub enum AcpThreadEvent {
582 NewEntry,
583 EntryUpdated(usize),
584 ToolAuthorizationRequired,
585 Stopped,
586 Error,
587 ServerExited(ExitStatus),
588}
589
590impl EventEmitter<AcpThreadEvent> for AcpThread {}
591
592#[derive(PartialEq, Eq)]
593pub enum ThreadStatus {
594 Idle,
595 WaitingForToolConfirmation,
596 Generating,
597}
598
599#[derive(Debug, Clone)]
600pub enum LoadError {
601 Unsupported {
602 error_message: SharedString,
603 upgrade_message: SharedString,
604 upgrade_command: String,
605 },
606 Exited(i32),
607 Other(SharedString),
608}
609
610impl Display for LoadError {
611 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
612 match self {
613 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
614 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
615 LoadError::Other(msg) => write!(f, "{}", msg),
616 }
617 }
618}
619
620impl Error for LoadError {}
621
622impl AcpThread {
623 pub fn new(
624 title: impl Into<SharedString>,
625 connection: Rc<dyn AgentConnection>,
626 project: Entity<Project>,
627 session_id: acp::SessionId,
628 cx: &mut Context<Self>,
629 ) -> Self {
630 let action_log = cx.new(|_| ActionLog::new(project.clone()));
631
632 Self {
633 action_log,
634 shared_buffers: Default::default(),
635 entries: Default::default(),
636 plan: Default::default(),
637 title: title.into(),
638 project,
639 send_task: None,
640 connection,
641 session_id,
642 }
643 }
644
645 pub fn action_log(&self) -> &Entity<ActionLog> {
646 &self.action_log
647 }
648
649 pub fn project(&self) -> &Entity<Project> {
650 &self.project
651 }
652
653 pub fn title(&self) -> SharedString {
654 self.title.clone()
655 }
656
657 pub fn entries(&self) -> &[AgentThreadEntry] {
658 &self.entries
659 }
660
661 pub fn session_id(&self) -> &acp::SessionId {
662 &self.session_id
663 }
664
665 pub fn status(&self) -> ThreadStatus {
666 if self
667 .send_task
668 .as_ref()
669 .map_or(false, |t| !t.is_terminated())
670 {
671 if self.waiting_for_tool_confirmation() {
672 ThreadStatus::WaitingForToolConfirmation
673 } else {
674 ThreadStatus::Generating
675 }
676 } else {
677 ThreadStatus::Idle
678 }
679 }
680
681 pub fn has_pending_edit_tool_calls(&self) -> bool {
682 for entry in self.entries.iter().rev() {
683 match entry {
684 AgentThreadEntry::UserMessage(_) => return false,
685 AgentThreadEntry::ToolCall(
686 call @ ToolCall {
687 status:
688 ToolCallStatus::Allowed {
689 status:
690 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
691 },
692 ..
693 },
694 ) if call.diffs().next().is_some() => {
695 return true;
696 }
697 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
698 }
699 }
700
701 false
702 }
703
704 pub fn used_tools_since_last_user_message(&self) -> bool {
705 for entry in self.entries.iter().rev() {
706 match entry {
707 AgentThreadEntry::UserMessage(..) => return false,
708 AgentThreadEntry::AssistantMessage(..) => continue,
709 AgentThreadEntry::ToolCall(..) => return true,
710 }
711 }
712
713 false
714 }
715
716 pub fn handle_session_update(
717 &mut self,
718 update: acp::SessionUpdate,
719 cx: &mut Context<Self>,
720 ) -> Result<()> {
721 match update {
722 acp::SessionUpdate::UserMessageChunk { content } => {
723 self.push_user_content_block(content, cx);
724 }
725 acp::SessionUpdate::AgentMessageChunk { content } => {
726 self.push_assistant_content_block(content, false, cx);
727 }
728 acp::SessionUpdate::AgentThoughtChunk { content } => {
729 self.push_assistant_content_block(content, true, cx);
730 }
731 acp::SessionUpdate::ToolCall(tool_call) => {
732 self.upsert_tool_call(tool_call, cx);
733 }
734 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
735 self.update_tool_call(tool_call_update, cx)?;
736 }
737 acp::SessionUpdate::Plan(plan) => {
738 self.update_plan(plan, cx);
739 }
740 }
741 Ok(())
742 }
743
744 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
745 let language_registry = self.project.read(cx).languages().clone();
746 let entries_len = self.entries.len();
747
748 if let Some(last_entry) = self.entries.last_mut()
749 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
750 {
751 content.append(chunk, &language_registry, cx);
752 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
753 } else {
754 let content = ContentBlock::new(chunk, &language_registry, cx);
755 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
756 }
757 }
758
759 pub fn push_assistant_content_block(
760 &mut self,
761 chunk: acp::ContentBlock,
762 is_thought: bool,
763 cx: &mut Context<Self>,
764 ) {
765 let language_registry = self.project.read(cx).languages().clone();
766 let entries_len = self.entries.len();
767 if let Some(last_entry) = self.entries.last_mut()
768 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
769 {
770 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
771 match (chunks.last_mut(), is_thought) {
772 (Some(AssistantMessageChunk::Message { block }), false)
773 | (Some(AssistantMessageChunk::Thought { block }), true) => {
774 block.append(chunk, &language_registry, cx)
775 }
776 _ => {
777 let block = ContentBlock::new(chunk, &language_registry, cx);
778 if is_thought {
779 chunks.push(AssistantMessageChunk::Thought { block })
780 } else {
781 chunks.push(AssistantMessageChunk::Message { block })
782 }
783 }
784 }
785 } else {
786 let block = ContentBlock::new(chunk, &language_registry, cx);
787 let chunk = if is_thought {
788 AssistantMessageChunk::Thought { block }
789 } else {
790 AssistantMessageChunk::Message { block }
791 };
792
793 self.push_entry(
794 AgentThreadEntry::AssistantMessage(AssistantMessage {
795 chunks: vec![chunk],
796 }),
797 cx,
798 );
799 }
800 }
801
802 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
803 self.entries.push(entry);
804 cx.emit(AcpThreadEvent::NewEntry);
805 }
806
807 pub fn update_tool_call(
808 &mut self,
809 update: acp::ToolCallUpdate,
810 cx: &mut Context<Self>,
811 ) -> Result<()> {
812 let languages = self.project.read(cx).languages().clone();
813
814 let (ix, current_call) = self
815 .tool_call_mut(&update.id)
816 .context("Tool call not found")?;
817 current_call.update(update.fields, languages, cx);
818
819 cx.emit(AcpThreadEvent::EntryUpdated(ix));
820
821 Ok(())
822 }
823
824 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
825 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
826 let status = ToolCallStatus::Allowed {
827 status: tool_call.status,
828 };
829 self.upsert_tool_call_inner(tool_call, status, cx)
830 }
831
832 pub fn upsert_tool_call_inner(
833 &mut self,
834 tool_call: acp::ToolCall,
835 status: ToolCallStatus,
836 cx: &mut Context<Self>,
837 ) {
838 let language_registry = self.project.read(cx).languages().clone();
839 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
840
841 let location = call.locations.last().cloned();
842
843 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
844 *current_call = call;
845
846 cx.emit(AcpThreadEvent::EntryUpdated(ix));
847 } else {
848 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
849 }
850
851 if let Some(location) = location {
852 self.set_project_location(location, cx)
853 }
854 }
855
856 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
857 // The tool call we are looking for is typically the last one, or very close to the end.
858 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
859 self.entries
860 .iter_mut()
861 .enumerate()
862 .rev()
863 .find_map(|(index, tool_call)| {
864 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
865 && &tool_call.id == id
866 {
867 Some((index, tool_call))
868 } else {
869 None
870 }
871 })
872 }
873
874 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
875 self.project.update(cx, |project, cx| {
876 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
877 return;
878 };
879 let buffer = project.open_buffer(path, cx);
880 cx.spawn(async move |project, cx| {
881 let buffer = buffer.await?;
882
883 project.update(cx, |project, cx| {
884 let position = if let Some(line) = location.line {
885 let snapshot = buffer.read(cx).snapshot();
886 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
887 snapshot.anchor_before(point)
888 } else {
889 Anchor::MIN
890 };
891
892 project.set_agent_location(
893 Some(AgentLocation {
894 buffer: buffer.downgrade(),
895 position,
896 }),
897 cx,
898 );
899 })
900 })
901 .detach_and_log_err(cx);
902 });
903 }
904
905 pub fn request_tool_call_authorization(
906 &mut self,
907 tool_call: acp::ToolCall,
908 options: Vec<acp::PermissionOption>,
909 cx: &mut Context<Self>,
910 ) -> oneshot::Receiver<acp::PermissionOptionId> {
911 let (tx, rx) = oneshot::channel();
912
913 let status = ToolCallStatus::WaitingForConfirmation {
914 options,
915 respond_tx: tx,
916 };
917
918 self.upsert_tool_call_inner(tool_call, status, cx);
919 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
920 rx
921 }
922
923 pub fn authorize_tool_call(
924 &mut self,
925 id: acp::ToolCallId,
926 option_id: acp::PermissionOptionId,
927 option_kind: acp::PermissionOptionKind,
928 cx: &mut Context<Self>,
929 ) {
930 let Some((ix, call)) = self.tool_call_mut(&id) else {
931 return;
932 };
933
934 let new_status = match option_kind {
935 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
936 ToolCallStatus::Rejected
937 }
938 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
939 ToolCallStatus::Allowed {
940 status: acp::ToolCallStatus::InProgress,
941 }
942 }
943 };
944
945 let curr_status = mem::replace(&mut call.status, new_status);
946
947 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
948 respond_tx.send(option_id).log_err();
949 } else if cfg!(debug_assertions) {
950 panic!("tried to authorize an already authorized tool call");
951 }
952
953 cx.emit(AcpThreadEvent::EntryUpdated(ix));
954 }
955
956 /// Returns true if the last turn is awaiting tool authorization
957 pub fn waiting_for_tool_confirmation(&self) -> bool {
958 for entry in self.entries.iter().rev() {
959 match &entry {
960 AgentThreadEntry::ToolCall(call) => match call.status {
961 ToolCallStatus::WaitingForConfirmation { .. } => return true,
962 ToolCallStatus::Allowed { .. }
963 | ToolCallStatus::Rejected
964 | ToolCallStatus::Canceled => continue,
965 },
966 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
967 // Reached the beginning of the turn
968 return false;
969 }
970 }
971 }
972 false
973 }
974
975 pub fn plan(&self) -> &Plan {
976 &self.plan
977 }
978
979 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
980 let new_entries_len = request.entries.len();
981 let mut new_entries = request.entries.into_iter();
982
983 // Reuse existing markdown to prevent flickering
984 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
985 let PlanEntry {
986 content,
987 priority,
988 status,
989 } = old;
990 content.update(cx, |old, cx| {
991 old.replace(new.content, cx);
992 });
993 *priority = new.priority;
994 *status = new.status;
995 }
996 for new in new_entries {
997 self.plan.entries.push(PlanEntry::from_acp(new, cx))
998 }
999 self.plan.entries.truncate(new_entries_len);
1000
1001 cx.notify();
1002 }
1003
1004 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1005 self.plan
1006 .entries
1007 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1008 cx.notify();
1009 }
1010
1011 #[cfg(any(test, feature = "test-support"))]
1012 pub fn send_raw(
1013 &mut self,
1014 message: &str,
1015 cx: &mut Context<Self>,
1016 ) -> BoxFuture<'static, Result<()>> {
1017 self.send(
1018 vec![acp::ContentBlock::Text(acp::TextContent {
1019 text: message.to_string(),
1020 annotations: None,
1021 })],
1022 cx,
1023 )
1024 }
1025
1026 pub fn send(
1027 &mut self,
1028 message: Vec<acp::ContentBlock>,
1029 cx: &mut Context<Self>,
1030 ) -> BoxFuture<'static, Result<()>> {
1031 let block = ContentBlock::new_combined(
1032 message.clone(),
1033 self.project.read(cx).languages().clone(),
1034 cx,
1035 );
1036 self.push_entry(
1037 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1038 cx,
1039 );
1040 self.clear_completed_plan_entries(cx);
1041
1042 let (tx, rx) = oneshot::channel();
1043 let cancel_task = self.cancel(cx);
1044
1045 self.send_task = Some(
1046 cx.spawn(async move |this, cx| {
1047 async {
1048 cancel_task.await;
1049
1050 let result = this
1051 .update(cx, |this, cx| {
1052 this.connection.prompt(
1053 acp::PromptRequest {
1054 prompt: message,
1055 session_id: this.session_id.clone(),
1056 },
1057 cx,
1058 )
1059 })?
1060 .await;
1061
1062 tx.send(result).log_err();
1063 anyhow::Ok(())
1064 }
1065 .await
1066 .log_err();
1067 })
1068 .fuse(),
1069 );
1070
1071 cx.spawn(async move |this, cx| match rx.await {
1072 Ok(Err(e)) => {
1073 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1074 .log_err();
1075 Err(e)?
1076 }
1077 _ => {
1078 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1079 .log_err();
1080 Ok(())
1081 }
1082 })
1083 .boxed()
1084 }
1085
1086 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1087 let Some(send_task) = self.send_task.take() else {
1088 return Task::ready(());
1089 };
1090
1091 for entry in self.entries.iter_mut() {
1092 if let AgentThreadEntry::ToolCall(call) = entry {
1093 let cancel = matches!(
1094 call.status,
1095 ToolCallStatus::WaitingForConfirmation { .. }
1096 | ToolCallStatus::Allowed {
1097 status: acp::ToolCallStatus::InProgress
1098 }
1099 );
1100
1101 if cancel {
1102 call.status = ToolCallStatus::Canceled;
1103 }
1104 }
1105 }
1106
1107 self.connection.cancel(&self.session_id, cx);
1108
1109 // Wait for the send task to complete
1110 cx.foreground_executor().spawn(send_task)
1111 }
1112
1113 pub fn read_text_file(
1114 &self,
1115 path: PathBuf,
1116 line: Option<u32>,
1117 limit: Option<u32>,
1118 reuse_shared_snapshot: bool,
1119 cx: &mut Context<Self>,
1120 ) -> Task<Result<String>> {
1121 let project = self.project.clone();
1122 let action_log = self.action_log.clone();
1123 cx.spawn(async move |this, cx| {
1124 let load = project.update(cx, |project, cx| {
1125 let path = project
1126 .project_path_for_absolute_path(&path, cx)
1127 .context("invalid path")?;
1128 anyhow::Ok(project.open_buffer(path, cx))
1129 });
1130 let buffer = load??.await?;
1131
1132 let snapshot = if reuse_shared_snapshot {
1133 this.read_with(cx, |this, _| {
1134 this.shared_buffers.get(&buffer.clone()).cloned()
1135 })
1136 .log_err()
1137 .flatten()
1138 } else {
1139 None
1140 };
1141
1142 let snapshot = if let Some(snapshot) = snapshot {
1143 snapshot
1144 } else {
1145 action_log.update(cx, |action_log, cx| {
1146 action_log.buffer_read(buffer.clone(), cx);
1147 })?;
1148 project.update(cx, |project, cx| {
1149 let position = buffer
1150 .read(cx)
1151 .snapshot()
1152 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1153 project.set_agent_location(
1154 Some(AgentLocation {
1155 buffer: buffer.downgrade(),
1156 position,
1157 }),
1158 cx,
1159 );
1160 })?;
1161
1162 buffer.update(cx, |buffer, _| buffer.snapshot())?
1163 };
1164
1165 this.update(cx, |this, _| {
1166 let text = snapshot.text();
1167 this.shared_buffers.insert(buffer.clone(), snapshot);
1168 if line.is_none() && limit.is_none() {
1169 return Ok(text);
1170 }
1171 let limit = limit.unwrap_or(u32::MAX) as usize;
1172 let Some(line) = line else {
1173 return Ok(text.lines().take(limit).collect::<String>());
1174 };
1175
1176 let count = text.lines().count();
1177 if count < line as usize {
1178 anyhow::bail!("There are only {} lines", count);
1179 }
1180 Ok(text
1181 .lines()
1182 .skip(line as usize + 1)
1183 .take(limit)
1184 .collect::<String>())
1185 })?
1186 })
1187 }
1188
    /// Overwrites the file at the absolute `path` with `content` through the project's buffer.
    ///
    /// The buffer is opened via the project and diffed against a base snapshot on the
    /// background executor. The resulting minimal edits are applied to the buffer, which
    /// is then saved. Before editing, the agent location is set to the end of the last
    /// edit, or `Anchor::MIN` when there are no edits. The action log is notified of a
    /// read before the edit and of an edit after it.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" when `path` does not resolve to a project path.
    /// Propagates failures of the entity updates, of opening the buffer, and of saving it.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            // Outer `?`: `project.update` itself failed; inner `?`: the path was invalid.
            let buffer = load??.await?;
            // Prefer the snapshot stored in `shared_buffers` for this buffer over the live
            // text. Edits are then computed against that earlier text, and user edits made
            // since then survive; `test_edits_concurrently_to_user` exercises this.
            // NOTE(review): presumably `shared_buffers` is filled when the buffer is read
            // by the agent (e.g. `read_text_file`) — confirm.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diffing can be expensive for large files, so it runs off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            // Turn snapshot offsets into anchors so the edits can be
                            // applied to the live buffer even if it has diverged.
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Point the agent location at the end of the last edit (start of the
                // buffer when the content is unchanged).
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // NOTE(review): read is recorded before and the edit after `buffer.edit`,
                // presumably so the action log can attribute exactly this change to the
                // agent — keep this order.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            // Persist the edited buffer to disk.
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1256
1257 pub fn to_markdown(&self, cx: &App) -> String {
1258 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1259 }
1260
1261 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1262 cx.emit(AcpThreadEvent::ServerExited(status));
1263 }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268 use super::*;
1269 use anyhow::anyhow;
1270 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1271 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1272 use indoc::indoc;
1273 use project::FakeFs;
1274 use rand::Rng as _;
1275 use serde_json::json;
1276 use settings::SettingsStore;
1277 use smol::stream::StreamExt as _;
1278 use std::{cell::RefCell, rc::Rc, time::Duration};
1279
1280 use util::path;
1281
1282 fn init_test(cx: &mut TestAppContext) {
1283 env_logger::try_init().ok();
1284 cx.update(|cx| {
1285 let settings_store = SettingsStore::test(cx);
1286 cx.set_global(settings_store);
1287 Project::init_settings(cx);
1288 language::init(cx);
1289 });
1290 }
1291
1292 #[gpui::test]
1293 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1294 init_test(cx);
1295
1296 let fs = FakeFs::new(cx.executor());
1297 let project = Project::test(fs, [], cx).await;
1298 let connection = Rc::new(FakeAgentConnection::new());
1299 let thread = cx
1300 .spawn(async move |mut cx| {
1301 connection
1302 .new_thread(project, Path::new(path!("/test")), &mut cx)
1303 .await
1304 })
1305 .await
1306 .unwrap();
1307
1308 // Test creating a new user message
1309 thread.update(cx, |thread, cx| {
1310 thread.push_user_content_block(
1311 acp::ContentBlock::Text(acp::TextContent {
1312 annotations: None,
1313 text: "Hello, ".to_string(),
1314 }),
1315 cx,
1316 );
1317 });
1318
1319 thread.update(cx, |thread, cx| {
1320 assert_eq!(thread.entries.len(), 1);
1321 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1322 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1323 } else {
1324 panic!("Expected UserMessage");
1325 }
1326 });
1327
1328 // Test appending to existing user message
1329 thread.update(cx, |thread, cx| {
1330 thread.push_user_content_block(
1331 acp::ContentBlock::Text(acp::TextContent {
1332 annotations: None,
1333 text: "world!".to_string(),
1334 }),
1335 cx,
1336 );
1337 });
1338
1339 thread.update(cx, |thread, cx| {
1340 assert_eq!(thread.entries.len(), 1);
1341 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1342 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1343 } else {
1344 panic!("Expected UserMessage");
1345 }
1346 });
1347
1348 // Test creating new user message after assistant message
1349 thread.update(cx, |thread, cx| {
1350 thread.push_assistant_content_block(
1351 acp::ContentBlock::Text(acp::TextContent {
1352 annotations: None,
1353 text: "Assistant response".to_string(),
1354 }),
1355 false,
1356 cx,
1357 );
1358 });
1359
1360 thread.update(cx, |thread, cx| {
1361 thread.push_user_content_block(
1362 acp::ContentBlock::Text(acp::TextContent {
1363 annotations: None,
1364 text: "New user message".to_string(),
1365 }),
1366 cx,
1367 );
1368 });
1369
1370 thread.update(cx, |thread, cx| {
1371 assert_eq!(thread.entries.len(), 3);
1372 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1373 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1374 } else {
1375 panic!("Expected UserMessage at index 2");
1376 }
1377 });
1378 }
1379
1380 #[gpui::test]
1381 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1382 init_test(cx);
1383
1384 let fs = FakeFs::new(cx.executor());
1385 let project = Project::test(fs, [], cx).await;
1386 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1387 |_, thread, mut cx| {
1388 async move {
1389 thread.update(&mut cx, |thread, cx| {
1390 thread
1391 .handle_session_update(
1392 acp::SessionUpdate::AgentThoughtChunk {
1393 content: "Thinking ".into(),
1394 },
1395 cx,
1396 )
1397 .unwrap();
1398 thread
1399 .handle_session_update(
1400 acp::SessionUpdate::AgentThoughtChunk {
1401 content: "hard!".into(),
1402 },
1403 cx,
1404 )
1405 .unwrap();
1406 })?;
1407 Ok(acp::PromptResponse {
1408 stop_reason: acp::StopReason::EndTurn,
1409 })
1410 }
1411 .boxed_local()
1412 },
1413 ));
1414
1415 let thread = cx
1416 .spawn(async move |mut cx| {
1417 connection
1418 .new_thread(project, Path::new(path!("/test")), &mut cx)
1419 .await
1420 })
1421 .await
1422 .unwrap();
1423
1424 thread
1425 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1426 .await
1427 .unwrap();
1428
1429 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1430 assert_eq!(
1431 output,
1432 indoc! {r#"
1433 ## User
1434
1435 Hello from Zed!
1436
1437 ## Assistant
1438
1439 <thinking>
1440 Thinking hard!
1441 </thinking>
1442
1443 "#}
1444 );
1445 }
1446
    // Checks that a user edit made after the agent read a file survives the agent's
    // subsequent `write_text_file`. The agent reads `/tmp/foo`, the user inserts a line
    // at the top, and the agent appends two lines. Both the buffer and the file on disk
    // must end up containing both changes.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the agent once its read has completed, so the test body can make the
        // "user" edit after that read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is an `Fn`, but a oneshot sender can only be consumed once, so it
        // is stored in a shared `Option` and taken on first use.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same buffer the agent will edit so the test can act as the user.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the concurrent user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's line and the agent's lines are present, in memory and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1525
    // Checks that a tool call canceled by the user can still be marked `Completed`
    // by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports a single in-progress `Fetch` tool call and ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call, initially in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Canceling the thread marks the tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent later reports that the same tool call completed.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // The completion update wins over the earlier cancellation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1634
1635 #[gpui::test]
1636 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1637 init_test(cx);
1638 let fs = FakeFs::new(cx.background_executor.clone());
1639 fs.insert_tree(path!("/test"), json!({})).await;
1640 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1641
1642 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1643 move |_, thread, mut cx| {
1644 async move {
1645 thread
1646 .update(&mut cx, |thread, cx| {
1647 thread.handle_session_update(
1648 acp::SessionUpdate::ToolCall(acp::ToolCall {
1649 id: acp::ToolCallId("test".into()),
1650 title: "Label".into(),
1651 kind: acp::ToolKind::Edit,
1652 status: acp::ToolCallStatus::Completed,
1653 content: vec![acp::ToolCallContent::Diff {
1654 diff: acp::Diff {
1655 path: "/test/test.txt".into(),
1656 old_text: None,
1657 new_text: "foo".into(),
1658 },
1659 }],
1660 locations: vec![],
1661 raw_input: None,
1662 }),
1663 cx,
1664 )
1665 })
1666 .unwrap()
1667 .unwrap();
1668 Ok(acp::PromptResponse {
1669 stop_reason: acp::StopReason::EndTurn,
1670 })
1671 }
1672 .boxed_local()
1673 }
1674 }));
1675
1676 let thread = connection
1677 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1678 .await
1679 .unwrap();
1680 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1681 .await
1682 .unwrap();
1683
1684 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1685 }
1686
1687 async fn run_until_first_tool_call(
1688 thread: &Entity<AcpThread>,
1689 cx: &mut TestAppContext,
1690 ) -> usize {
1691 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1692
1693 let subscription = cx.update(|cx| {
1694 cx.subscribe(thread, move |thread, _, cx| {
1695 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1696 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1697 return tx.try_send(ix).unwrap();
1698 }
1699 }
1700 })
1701 });
1702
1703 select! {
1704 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1705 panic!("Timeout waiting for tool call")
1706 }
1707 ix = rx.next().fuse() => {
1708 drop(subscription);
1709 ix.unwrap()
1710 }
1711 }
1712 }
1713
1714 #[derive(Clone, Default)]
1715 struct FakeAgentConnection {
1716 auth_methods: Vec<acp::AuthMethod>,
1717 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1718 on_user_message: Option<
1719 Rc<
1720 dyn Fn(
1721 acp::PromptRequest,
1722 WeakEntity<AcpThread>,
1723 AsyncApp,
1724 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1725 + 'static,
1726 >,
1727 >,
1728 }
1729
1730 impl FakeAgentConnection {
1731 fn new() -> Self {
1732 Self {
1733 auth_methods: Vec::new(),
1734 on_user_message: None,
1735 sessions: Arc::default(),
1736 }
1737 }
1738
1739 #[expect(unused)]
1740 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1741 self.auth_methods = auth_methods;
1742 self
1743 }
1744
1745 fn on_user_message(
1746 mut self,
1747 handler: impl Fn(
1748 acp::PromptRequest,
1749 WeakEntity<AcpThread>,
1750 AsyncApp,
1751 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1752 + 'static,
1753 ) -> Self {
1754 self.on_user_message.replace(Rc::new(handler));
1755 self
1756 }
1757 }
1758
1759 impl AgentConnection for FakeAgentConnection {
1760 fn auth_methods(&self) -> &[acp::AuthMethod] {
1761 &self.auth_methods
1762 }
1763
1764 fn new_thread(
1765 self: Rc<Self>,
1766 project: Entity<Project>,
1767 _cwd: &Path,
1768 cx: &mut gpui::AsyncApp,
1769 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1770 let session_id = acp::SessionId(
1771 rand::thread_rng()
1772 .sample_iter(&rand::distributions::Alphanumeric)
1773 .take(7)
1774 .map(char::from)
1775 .collect::<String>()
1776 .into(),
1777 );
1778 let thread = cx
1779 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1780 .unwrap();
1781 self.sessions.lock().insert(session_id, thread.downgrade());
1782 Task::ready(Ok(thread))
1783 }
1784
1785 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1786 if self.auth_methods().iter().any(|m| m.id == method) {
1787 Task::ready(Ok(()))
1788 } else {
1789 Task::ready(Err(anyhow!("Invalid Auth Method")))
1790 }
1791 }
1792
1793 fn prompt(
1794 &self,
1795 params: acp::PromptRequest,
1796 cx: &mut App,
1797 ) -> Task<gpui::Result<acp::PromptResponse>> {
1798 let sessions = self.sessions.lock();
1799 let thread = sessions.get(¶ms.session_id).unwrap();
1800 if let Some(handler) = &self.on_user_message {
1801 let handler = handler.clone();
1802 let thread = thread.clone();
1803 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1804 } else {
1805 Task::ready(Ok(acp::PromptResponse {
1806 stop_reason: acp::StopReason::EndTurn,
1807 }))
1808 }
1809 }
1810
1811 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1812 let sessions = self.sessions.lock();
1813 let thread = sessions.get(&session_id).unwrap().clone();
1814
1815 cx.spawn(async move |cx| {
1816 thread
1817 .update(cx, |thread, cx| thread.cancel(cx))
1818 .unwrap()
1819 .await
1820 })
1821 .detach();
1822 }
1823 }
1824}