1mod connection;
2mod diff;
3mod terminal;
4
5pub use connection::*;
6pub use diff::*;
7pub use terminal::*;
8
9use action_log::ActionLog;
10use agent_client_protocol as acp;
11use anyhow::{Context as _, Result};
12use editor::Bias;
13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
14use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
15use itertools::Itertools;
16use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::process::ExitStatus;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub content: ContentBlock,
36}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
53 }
54}
55
/// A borrowed file path that renders as a markdown "mention" link.
///
/// The link target is the path prefixed with `@file:`, and
/// [`MentionPath::try_parse`] recognizes exactly that form.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        let MentionPath(inner) = self;
        inner
    }
}

impl Display for MentionPath<'_> {
    /// Formats as `[@<file name>](@file:<full path>)`. The label is empty when
    /// the path has no final component.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = self.0.file_name().unwrap_or_default();
        let target = self.0.display();
        write!(f, "[@{}]({}{})", label.display(), Self::PREFIX, target)
    }
}
87
88#[derive(Debug, PartialEq)]
89pub struct AssistantMessage {
90 pub chunks: Vec<AssistantMessageChunk>,
91}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub enum AssistantMessageChunk {
107 Message { block: ContentBlock },
108 Thought { block: ContentBlock },
109}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
128#[derive(Debug)]
129pub enum AgentThreadEntry {
130 UserMessage(UserMessage),
131 AssistantMessage(AssistantMessage),
132 ToolCall(ToolCall),
133}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
145 if let AgentThreadEntry::ToolCall(call) = self {
146 itertools::Either::Left(call.diffs())
147 } else {
148 itertools::Either::Right(std::iter::empty())
149 }
150 }
151
152 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
153 if let AgentThreadEntry::ToolCall(call) = self {
154 itertools::Either::Left(call.terminals())
155 } else {
156 itertools::Either::Right(std::iter::empty())
157 }
158 }
159
160 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
161 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
162 Some(locations)
163 } else {
164 None
165 }
166 }
167}
168
169#[derive(Debug)]
170pub struct ToolCall {
171 pub id: acp::ToolCallId,
172 pub label: Entity<Markdown>,
173 pub kind: acp::ToolKind,
174 pub content: Vec<ToolCallContent>,
175 pub status: ToolCallStatus,
176 pub locations: Vec<acp::ToolCallLocation>,
177 pub raw_input: Option<serde_json::Value>,
178 pub raw_output: Option<serde_json::Value>,
179}
180
181impl ToolCall {
182 fn from_acp(
183 tool_call: acp::ToolCall,
184 status: ToolCallStatus,
185 language_registry: Arc<LanguageRegistry>,
186 cx: &mut App,
187 ) -> Self {
188 Self {
189 id: tool_call.id,
190 label: cx.new(|cx| {
191 Markdown::new(
192 tool_call.title.into(),
193 Some(language_registry.clone()),
194 None,
195 cx,
196 )
197 }),
198 kind: tool_call.kind,
199 content: tool_call
200 .content
201 .into_iter()
202 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
203 .collect(),
204 locations: tool_call.locations,
205 status,
206 raw_input: tool_call.raw_input,
207 raw_output: tool_call.raw_output,
208 }
209 }
210
211 fn update_fields(
212 &mut self,
213 fields: acp::ToolCallUpdateFields,
214 language_registry: Arc<LanguageRegistry>,
215 cx: &mut App,
216 ) {
217 let acp::ToolCallUpdateFields {
218 kind,
219 status,
220 title,
221 content,
222 locations,
223 raw_input,
224 raw_output,
225 } = fields;
226
227 if let Some(kind) = kind {
228 self.kind = kind;
229 }
230
231 if let Some(status) = status {
232 self.status = ToolCallStatus::Allowed { status };
233 }
234
235 if let Some(title) = title {
236 self.label.update(cx, |label, cx| {
237 label.replace(title, cx);
238 });
239 }
240
241 if let Some(content) = content {
242 self.content = content
243 .into_iter()
244 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
245 .collect();
246 }
247
248 if let Some(locations) = locations {
249 self.locations = locations;
250 }
251
252 if let Some(raw_input) = raw_input {
253 self.raw_input = Some(raw_input);
254 }
255
256 if let Some(raw_output) = raw_output {
257 self.raw_output = Some(raw_output);
258 }
259 }
260
261 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
262 self.content.iter().filter_map(|content| match content {
263 ToolCallContent::Diff(diff) => Some(diff),
264 ToolCallContent::ContentBlock(_) => None,
265 ToolCallContent::Terminal(_) => None,
266 })
267 }
268
269 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
270 self.content.iter().filter_map(|content| match content {
271 ToolCallContent::Terminal(terminal) => Some(terminal),
272 ToolCallContent::ContentBlock(_) => None,
273 ToolCallContent::Diff(_) => None,
274 })
275 }
276
277 fn to_markdown(&self, cx: &App) -> String {
278 let mut markdown = format!(
279 "**Tool Call: {}**\nStatus: {}\n\n",
280 self.label.read(cx).source(),
281 self.status
282 );
283 for content in &self.content {
284 markdown.push_str(content.to_markdown(cx).as_str());
285 markdown.push_str("\n\n");
286 }
287 markdown
288 }
289}
290
291#[derive(Debug)]
292pub enum ToolCallStatus {
293 WaitingForConfirmation {
294 options: Vec<acp::PermissionOption>,
295 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
296 },
297 Allowed {
298 status: acp::ToolCallStatus,
299 },
300 Rejected,
301 Canceled,
302}
303
304impl Display for ToolCallStatus {
305 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
306 write!(
307 f,
308 "{}",
309 match self {
310 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
311 ToolCallStatus::Allowed { status } => match status {
312 acp::ToolCallStatus::Pending => "Pending",
313 acp::ToolCallStatus::InProgress => "In Progress",
314 acp::ToolCallStatus::Completed => "Completed",
315 acp::ToolCallStatus::Failed => "Failed",
316 },
317 ToolCallStatus::Rejected => "Rejected",
318 ToolCallStatus::Canceled => "Canceled",
319 }
320 )
321 }
322}
323
324#[derive(Debug, PartialEq, Clone)]
325pub enum ContentBlock {
326 Empty,
327 Markdown { markdown: Entity<Markdown> },
328}
329
330impl ContentBlock {
331 pub fn new(
332 block: acp::ContentBlock,
333 language_registry: &Arc<LanguageRegistry>,
334 cx: &mut App,
335 ) -> Self {
336 let mut this = Self::Empty;
337 this.append(block, language_registry, cx);
338 this
339 }
340
341 pub fn new_combined(
342 blocks: impl IntoIterator<Item = acp::ContentBlock>,
343 language_registry: Arc<LanguageRegistry>,
344 cx: &mut App,
345 ) -> Self {
346 let mut this = Self::Empty;
347 for block in blocks {
348 this.append(block, &language_registry, cx);
349 }
350 this
351 }
352
353 pub fn append(
354 &mut self,
355 block: acp::ContentBlock,
356 language_registry: &Arc<LanguageRegistry>,
357 cx: &mut App,
358 ) {
359 let new_content = match block {
360 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
361 acp::ContentBlock::ResourceLink(resource_link) => {
362 if let Some(path) = resource_link.uri.strip_prefix("file://") {
363 format!("{}", MentionPath(path.as_ref()))
364 } else {
365 resource_link.uri.clone()
366 }
367 }
368 acp::ContentBlock::Image(_)
369 | acp::ContentBlock::Audio(_)
370 | acp::ContentBlock::Resource(_) => String::new(),
371 };
372
373 match self {
374 ContentBlock::Empty => {
375 *self = ContentBlock::Markdown {
376 markdown: cx.new(|cx| {
377 Markdown::new(
378 new_content.into(),
379 Some(language_registry.clone()),
380 None,
381 cx,
382 )
383 }),
384 };
385 }
386 ContentBlock::Markdown { markdown } => {
387 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
388 }
389 }
390 }
391
392 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
393 match self {
394 ContentBlock::Empty => "",
395 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
396 }
397 }
398
399 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
400 match self {
401 ContentBlock::Empty => None,
402 ContentBlock::Markdown { markdown } => Some(markdown),
403 }
404 }
405}
406
407#[derive(Debug)]
408pub enum ToolCallContent {
409 ContentBlock(ContentBlock),
410 Diff(Entity<Diff>),
411 Terminal(Entity<Terminal>),
412}
413
414impl ToolCallContent {
415 pub fn from_acp(
416 content: acp::ToolCallContent,
417 language_registry: Arc<LanguageRegistry>,
418 cx: &mut App,
419 ) -> Self {
420 match content {
421 acp::ToolCallContent::Content { content } => {
422 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
423 }
424 acp::ToolCallContent::Diff { diff } => {
425 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
426 }
427 }
428 }
429
430 pub fn to_markdown(&self, cx: &App) -> String {
431 match self {
432 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
433 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
434 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
435 }
436 }
437}
438
439#[derive(Debug, PartialEq)]
440pub enum ToolCallUpdate {
441 UpdateFields(acp::ToolCallUpdate),
442 UpdateDiff(ToolCallUpdateDiff),
443 UpdateTerminal(ToolCallUpdateTerminal),
444}
445
446impl ToolCallUpdate {
447 fn id(&self) -> &acp::ToolCallId {
448 match self {
449 Self::UpdateFields(update) => &update.id,
450 Self::UpdateDiff(diff) => &diff.id,
451 Self::UpdateTerminal(terminal) => &terminal.id,
452 }
453 }
454}
455
456impl From<acp::ToolCallUpdate> for ToolCallUpdate {
457 fn from(update: acp::ToolCallUpdate) -> Self {
458 Self::UpdateFields(update)
459 }
460}
461
462impl From<ToolCallUpdateDiff> for ToolCallUpdate {
463 fn from(diff: ToolCallUpdateDiff) -> Self {
464 Self::UpdateDiff(diff)
465 }
466}
467
468#[derive(Debug, PartialEq)]
469pub struct ToolCallUpdateDiff {
470 pub id: acp::ToolCallId,
471 pub diff: Entity<Diff>,
472}
473
474impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
475 fn from(terminal: ToolCallUpdateTerminal) -> Self {
476 Self::UpdateTerminal(terminal)
477 }
478}
479
480#[derive(Debug, PartialEq)]
481pub struct ToolCallUpdateTerminal {
482 pub id: acp::ToolCallId,
483 pub terminal: Entity<Terminal>,
484}
485
486#[derive(Debug, Default)]
487pub struct Plan {
488 pub entries: Vec<PlanEntry>,
489}
490
491#[derive(Debug)]
492pub struct PlanStats<'a> {
493 pub in_progress_entry: Option<&'a PlanEntry>,
494 pub pending: u32,
495 pub completed: u32,
496}
497
498impl Plan {
499 pub fn is_empty(&self) -> bool {
500 self.entries.is_empty()
501 }
502
503 pub fn stats(&self) -> PlanStats<'_> {
504 let mut stats = PlanStats {
505 in_progress_entry: None,
506 pending: 0,
507 completed: 0,
508 };
509
510 for entry in &self.entries {
511 match &entry.status {
512 acp::PlanEntryStatus::Pending => {
513 stats.pending += 1;
514 }
515 acp::PlanEntryStatus::InProgress => {
516 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
517 }
518 acp::PlanEntryStatus::Completed => {
519 stats.completed += 1;
520 }
521 }
522 }
523
524 stats
525 }
526}
527
528#[derive(Debug)]
529pub struct PlanEntry {
530 pub content: Entity<Markdown>,
531 pub priority: acp::PlanEntryPriority,
532 pub status: acp::PlanEntryStatus,
533}
534
535impl PlanEntry {
536 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
537 Self {
538 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
539 priority: entry.priority,
540 status: entry.status,
541 }
542 }
543}
544
545pub struct AcpThread {
546 title: SharedString,
547 entries: Vec<AgentThreadEntry>,
548 plan: Plan,
549 project: Entity<Project>,
550 action_log: Entity<ActionLog>,
551 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
552 send_task: Option<Task<()>>,
553 connection: Rc<dyn AgentConnection>,
554 session_id: acp::SessionId,
555}
556
557pub enum AcpThreadEvent {
558 NewEntry,
559 EntryUpdated(usize),
560 ToolAuthorizationRequired,
561 Stopped,
562 Error,
563 ServerExited(ExitStatus),
564}
565
566impl EventEmitter<AcpThreadEvent> for AcpThread {}
567
568#[derive(PartialEq, Eq)]
569pub enum ThreadStatus {
570 Idle,
571 WaitingForToolConfirmation,
572 Generating,
573}
574
575#[derive(Debug, Clone)]
576pub enum LoadError {
577 Unsupported {
578 error_message: SharedString,
579 upgrade_message: SharedString,
580 upgrade_command: String,
581 },
582 Exited(i32),
583 Other(SharedString),
584}
585
586impl Display for LoadError {
587 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
588 match self {
589 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
590 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
591 LoadError::Other(msg) => write!(f, "{}", msg),
592 }
593 }
594}
595
596impl Error for LoadError {}
597
598impl AcpThread {
599 pub fn new(
600 title: impl Into<SharedString>,
601 connection: Rc<dyn AgentConnection>,
602 project: Entity<Project>,
603 session_id: acp::SessionId,
604 cx: &mut Context<Self>,
605 ) -> Self {
606 let action_log = cx.new(|_| ActionLog::new(project.clone()));
607
608 Self {
609 action_log,
610 shared_buffers: Default::default(),
611 entries: Default::default(),
612 plan: Default::default(),
613 title: title.into(),
614 project,
615 send_task: None,
616 connection,
617 session_id,
618 }
619 }
620
621 pub fn action_log(&self) -> &Entity<ActionLog> {
622 &self.action_log
623 }
624
625 pub fn project(&self) -> &Entity<Project> {
626 &self.project
627 }
628
629 pub fn title(&self) -> SharedString {
630 self.title.clone()
631 }
632
633 pub fn entries(&self) -> &[AgentThreadEntry] {
634 &self.entries
635 }
636
637 pub fn session_id(&self) -> &acp::SessionId {
638 &self.session_id
639 }
640
641 pub fn status(&self) -> ThreadStatus {
642 if self.send_task.is_some() {
643 if self.waiting_for_tool_confirmation() {
644 ThreadStatus::WaitingForToolConfirmation
645 } else {
646 ThreadStatus::Generating
647 }
648 } else {
649 ThreadStatus::Idle
650 }
651 }
652
653 pub fn has_pending_edit_tool_calls(&self) -> bool {
654 for entry in self.entries.iter().rev() {
655 match entry {
656 AgentThreadEntry::UserMessage(_) => return false,
657 AgentThreadEntry::ToolCall(
658 call @ ToolCall {
659 status:
660 ToolCallStatus::Allowed {
661 status:
662 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
663 },
664 ..
665 },
666 ) if call.diffs().next().is_some() => {
667 return true;
668 }
669 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
670 }
671 }
672
673 false
674 }
675
676 pub fn used_tools_since_last_user_message(&self) -> bool {
677 for entry in self.entries.iter().rev() {
678 match entry {
679 AgentThreadEntry::UserMessage(..) => return false,
680 AgentThreadEntry::AssistantMessage(..) => continue,
681 AgentThreadEntry::ToolCall(..) => return true,
682 }
683 }
684
685 false
686 }
687
688 pub fn handle_session_update(
689 &mut self,
690 update: acp::SessionUpdate,
691 cx: &mut Context<Self>,
692 ) -> Result<()> {
693 match update {
694 acp::SessionUpdate::UserMessageChunk { content } => {
695 self.push_user_content_block(content, cx);
696 }
697 acp::SessionUpdate::AgentMessageChunk { content } => {
698 self.push_assistant_content_block(content, false, cx);
699 }
700 acp::SessionUpdate::AgentThoughtChunk { content } => {
701 self.push_assistant_content_block(content, true, cx);
702 }
703 acp::SessionUpdate::ToolCall(tool_call) => {
704 self.upsert_tool_call(tool_call, cx);
705 }
706 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
707 self.update_tool_call(tool_call_update, cx)?;
708 }
709 acp::SessionUpdate::Plan(plan) => {
710 self.update_plan(plan, cx);
711 }
712 }
713 Ok(())
714 }
715
716 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
717 let language_registry = self.project.read(cx).languages().clone();
718 let entries_len = self.entries.len();
719
720 if let Some(last_entry) = self.entries.last_mut()
721 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
722 {
723 content.append(chunk, &language_registry, cx);
724 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
725 } else {
726 let content = ContentBlock::new(chunk, &language_registry, cx);
727 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
728 }
729 }
730
731 pub fn push_assistant_content_block(
732 &mut self,
733 chunk: acp::ContentBlock,
734 is_thought: bool,
735 cx: &mut Context<Self>,
736 ) {
737 let language_registry = self.project.read(cx).languages().clone();
738 let entries_len = self.entries.len();
739 if let Some(last_entry) = self.entries.last_mut()
740 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
741 {
742 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
743 match (chunks.last_mut(), is_thought) {
744 (Some(AssistantMessageChunk::Message { block }), false)
745 | (Some(AssistantMessageChunk::Thought { block }), true) => {
746 block.append(chunk, &language_registry, cx)
747 }
748 _ => {
749 let block = ContentBlock::new(chunk, &language_registry, cx);
750 if is_thought {
751 chunks.push(AssistantMessageChunk::Thought { block })
752 } else {
753 chunks.push(AssistantMessageChunk::Message { block })
754 }
755 }
756 }
757 } else {
758 let block = ContentBlock::new(chunk, &language_registry, cx);
759 let chunk = if is_thought {
760 AssistantMessageChunk::Thought { block }
761 } else {
762 AssistantMessageChunk::Message { block }
763 };
764
765 self.push_entry(
766 AgentThreadEntry::AssistantMessage(AssistantMessage {
767 chunks: vec![chunk],
768 }),
769 cx,
770 );
771 }
772 }
773
774 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
775 self.entries.push(entry);
776 cx.emit(AcpThreadEvent::NewEntry);
777 }
778
779 pub fn update_tool_call(
780 &mut self,
781 update: impl Into<ToolCallUpdate>,
782 cx: &mut Context<Self>,
783 ) -> Result<()> {
784 let update = update.into();
785 let languages = self.project.read(cx).languages().clone();
786
787 let (ix, current_call) = self
788 .tool_call_mut(update.id())
789 .context("Tool call not found")?;
790 match update {
791 ToolCallUpdate::UpdateFields(update) => {
792 current_call.update_fields(update.fields, languages, cx);
793 }
794 ToolCallUpdate::UpdateDiff(update) => {
795 current_call.content.clear();
796 current_call
797 .content
798 .push(ToolCallContent::Diff(update.diff));
799 }
800 ToolCallUpdate::UpdateTerminal(update) => {
801 current_call.content.clear();
802 current_call
803 .content
804 .push(ToolCallContent::Terminal(update.terminal));
805 }
806 }
807
808 cx.emit(AcpThreadEvent::EntryUpdated(ix));
809
810 Ok(())
811 }
812
813 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
814 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
815 let status = ToolCallStatus::Allowed {
816 status: tool_call.status,
817 };
818 self.upsert_tool_call_inner(tool_call, status, cx)
819 }
820
821 pub fn upsert_tool_call_inner(
822 &mut self,
823 tool_call: acp::ToolCall,
824 status: ToolCallStatus,
825 cx: &mut Context<Self>,
826 ) {
827 let language_registry = self.project.read(cx).languages().clone();
828 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
829
830 let location = call.locations.last().cloned();
831
832 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
833 *current_call = call;
834
835 cx.emit(AcpThreadEvent::EntryUpdated(ix));
836 } else {
837 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
838 }
839
840 if let Some(location) = location {
841 self.set_project_location(location, cx)
842 }
843 }
844
845 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
846 // The tool call we are looking for is typically the last one, or very close to the end.
847 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
848 self.entries
849 .iter_mut()
850 .enumerate()
851 .rev()
852 .find_map(|(index, tool_call)| {
853 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
854 && &tool_call.id == id
855 {
856 Some((index, tool_call))
857 } else {
858 None
859 }
860 })
861 }
862
863 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
864 self.project.update(cx, |project, cx| {
865 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
866 return;
867 };
868 let buffer = project.open_buffer(path, cx);
869 cx.spawn(async move |project, cx| {
870 let buffer = buffer.await?;
871
872 project.update(cx, |project, cx| {
873 let position = if let Some(line) = location.line {
874 let snapshot = buffer.read(cx).snapshot();
875 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
876 snapshot.anchor_before(point)
877 } else {
878 Anchor::MIN
879 };
880
881 project.set_agent_location(
882 Some(AgentLocation {
883 buffer: buffer.downgrade(),
884 position,
885 }),
886 cx,
887 );
888 })
889 })
890 .detach_and_log_err(cx);
891 });
892 }
893
894 pub fn request_tool_call_authorization(
895 &mut self,
896 tool_call: acp::ToolCall,
897 options: Vec<acp::PermissionOption>,
898 cx: &mut Context<Self>,
899 ) -> oneshot::Receiver<acp::PermissionOptionId> {
900 let (tx, rx) = oneshot::channel();
901
902 let status = ToolCallStatus::WaitingForConfirmation {
903 options,
904 respond_tx: tx,
905 };
906
907 self.upsert_tool_call_inner(tool_call, status, cx);
908 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
909 rx
910 }
911
912 pub fn authorize_tool_call(
913 &mut self,
914 id: acp::ToolCallId,
915 option_id: acp::PermissionOptionId,
916 option_kind: acp::PermissionOptionKind,
917 cx: &mut Context<Self>,
918 ) {
919 let Some((ix, call)) = self.tool_call_mut(&id) else {
920 return;
921 };
922
923 let new_status = match option_kind {
924 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
925 ToolCallStatus::Rejected
926 }
927 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
928 ToolCallStatus::Allowed {
929 status: acp::ToolCallStatus::InProgress,
930 }
931 }
932 };
933
934 let curr_status = mem::replace(&mut call.status, new_status);
935
936 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
937 respond_tx.send(option_id).log_err();
938 } else if cfg!(debug_assertions) {
939 panic!("tried to authorize an already authorized tool call");
940 }
941
942 cx.emit(AcpThreadEvent::EntryUpdated(ix));
943 }
944
945 /// Returns true if the last turn is awaiting tool authorization
946 pub fn waiting_for_tool_confirmation(&self) -> bool {
947 for entry in self.entries.iter().rev() {
948 match &entry {
949 AgentThreadEntry::ToolCall(call) => match call.status {
950 ToolCallStatus::WaitingForConfirmation { .. } => return true,
951 ToolCallStatus::Allowed { .. }
952 | ToolCallStatus::Rejected
953 | ToolCallStatus::Canceled => continue,
954 },
955 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
956 // Reached the beginning of the turn
957 return false;
958 }
959 }
960 }
961 false
962 }
963
964 pub fn plan(&self) -> &Plan {
965 &self.plan
966 }
967
968 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
969 let new_entries_len = request.entries.len();
970 let mut new_entries = request.entries.into_iter();
971
972 // Reuse existing markdown to prevent flickering
973 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
974 let PlanEntry {
975 content,
976 priority,
977 status,
978 } = old;
979 content.update(cx, |old, cx| {
980 old.replace(new.content, cx);
981 });
982 *priority = new.priority;
983 *status = new.status;
984 }
985 for new in new_entries {
986 self.plan.entries.push(PlanEntry::from_acp(new, cx))
987 }
988 self.plan.entries.truncate(new_entries_len);
989
990 cx.notify();
991 }
992
993 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
994 self.plan
995 .entries
996 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
997 cx.notify();
998 }
999
1000 #[cfg(any(test, feature = "test-support"))]
1001 pub fn send_raw(
1002 &mut self,
1003 message: &str,
1004 cx: &mut Context<Self>,
1005 ) -> BoxFuture<'static, Result<()>> {
1006 self.send(
1007 vec![acp::ContentBlock::Text(acp::TextContent {
1008 text: message.to_string(),
1009 annotations: None,
1010 })],
1011 cx,
1012 )
1013 }
1014
1015 pub fn send(
1016 &mut self,
1017 message: Vec<acp::ContentBlock>,
1018 cx: &mut Context<Self>,
1019 ) -> BoxFuture<'static, Result<()>> {
1020 let block = ContentBlock::new_combined(
1021 message.clone(),
1022 self.project.read(cx).languages().clone(),
1023 cx,
1024 );
1025 self.push_entry(
1026 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1027 cx,
1028 );
1029 self.clear_completed_plan_entries(cx);
1030
1031 let (tx, rx) = oneshot::channel();
1032 let cancel_task = self.cancel(cx);
1033
1034 self.send_task = Some(cx.spawn(async move |this, cx| {
1035 async {
1036 cancel_task.await;
1037
1038 let result = this
1039 .update(cx, |this, cx| {
1040 this.connection.prompt(
1041 acp::PromptRequest {
1042 prompt: message,
1043 session_id: this.session_id.clone(),
1044 },
1045 cx,
1046 )
1047 })?
1048 .await;
1049
1050 tx.send(result).log_err();
1051
1052 anyhow::Ok(())
1053 }
1054 .await
1055 .log_err();
1056 }));
1057
1058 cx.spawn(async move |this, cx| match rx.await {
1059 Ok(Err(e)) => {
1060 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1061 .log_err();
1062 Err(e)?
1063 }
1064 result => {
1065 let cancelled = matches!(
1066 result,
1067 Ok(Ok(acp::PromptResponse {
1068 stop_reason: acp::StopReason::Cancelled
1069 }))
1070 );
1071
1072 // We only take the task if the current prompt wasn't cancelled.
1073 //
1074 // This prompt may have been cancelled because another one was sent
1075 // while it was still generating. In these cases, dropping `send_task`
1076 // would cause the next generation to be cancelled.
1077 if !cancelled {
1078 this.update(cx, |this, _cx| this.send_task.take()).ok();
1079 }
1080
1081 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1082 .log_err();
1083 Ok(())
1084 }
1085 })
1086 .boxed()
1087 }
1088
1089 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1090 let Some(send_task) = self.send_task.take() else {
1091 return Task::ready(());
1092 };
1093
1094 for entry in self.entries.iter_mut() {
1095 if let AgentThreadEntry::ToolCall(call) = entry {
1096 let cancel = matches!(
1097 call.status,
1098 ToolCallStatus::WaitingForConfirmation { .. }
1099 | ToolCallStatus::Allowed {
1100 status: acp::ToolCallStatus::InProgress
1101 }
1102 );
1103
1104 if cancel {
1105 call.status = ToolCallStatus::Canceled;
1106 }
1107 }
1108 }
1109
1110 self.connection.cancel(&self.session_id, cx);
1111
1112 // Wait for the send task to complete
1113 cx.foreground_executor().spawn(send_task)
1114 }
1115
1116 pub fn read_text_file(
1117 &self,
1118 path: PathBuf,
1119 line: Option<u32>,
1120 limit: Option<u32>,
1121 reuse_shared_snapshot: bool,
1122 cx: &mut Context<Self>,
1123 ) -> Task<Result<String>> {
1124 let project = self.project.clone();
1125 let action_log = self.action_log.clone();
1126 cx.spawn(async move |this, cx| {
1127 let load = project.update(cx, |project, cx| {
1128 let path = project
1129 .project_path_for_absolute_path(&path, cx)
1130 .context("invalid path")?;
1131 anyhow::Ok(project.open_buffer(path, cx))
1132 });
1133 let buffer = load??.await?;
1134
1135 let snapshot = if reuse_shared_snapshot {
1136 this.read_with(cx, |this, _| {
1137 this.shared_buffers.get(&buffer.clone()).cloned()
1138 })
1139 .log_err()
1140 .flatten()
1141 } else {
1142 None
1143 };
1144
1145 let snapshot = if let Some(snapshot) = snapshot {
1146 snapshot
1147 } else {
1148 action_log.update(cx, |action_log, cx| {
1149 action_log.buffer_read(buffer.clone(), cx);
1150 })?;
1151 project.update(cx, |project, cx| {
1152 let position = buffer
1153 .read(cx)
1154 .snapshot()
1155 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1156 project.set_agent_location(
1157 Some(AgentLocation {
1158 buffer: buffer.downgrade(),
1159 position,
1160 }),
1161 cx,
1162 );
1163 })?;
1164
1165 buffer.update(cx, |buffer, _| buffer.snapshot())?
1166 };
1167
1168 this.update(cx, |this, _| {
1169 let text = snapshot.text();
1170 this.shared_buffers.insert(buffer.clone(), snapshot);
1171 if line.is_none() && limit.is_none() {
1172 return Ok(text);
1173 }
1174 let limit = limit.unwrap_or(u32::MAX) as usize;
1175 let Some(line) = line else {
1176 return Ok(text.lines().take(limit).collect::<String>());
1177 };
1178
1179 let count = text.lines().count();
1180 if count < line as usize {
1181 anyhow::bail!("There are only {} lines", count);
1182 }
1183 Ok(text
1184 .lines()
1185 .skip(line as usize + 1)
1186 .take(limit)
1187 .collect::<String>())
1188 })?
1189 })
1190 }
1191
1192 pub fn write_text_file(
1193 &self,
1194 path: PathBuf,
1195 content: String,
1196 cx: &mut Context<Self>,
1197 ) -> Task<Result<()>> {
1198 let project = self.project.clone();
1199 let action_log = self.action_log.clone();
1200 cx.spawn(async move |this, cx| {
1201 let load = project.update(cx, |project, cx| {
1202 let path = project
1203 .project_path_for_absolute_path(&path, cx)
1204 .context("invalid path")?;
1205 anyhow::Ok(project.open_buffer(path, cx))
1206 });
1207 let buffer = load??.await?;
1208 let snapshot = this.update(cx, |this, cx| {
1209 this.shared_buffers
1210 .get(&buffer)
1211 .cloned()
1212 .unwrap_or_else(|| buffer.read(cx).snapshot())
1213 })?;
1214 let edits = cx
1215 .background_executor()
1216 .spawn(async move {
1217 let old_text = snapshot.text();
1218 text_diff(old_text.as_str(), &content)
1219 .into_iter()
1220 .map(|(range, replacement)| {
1221 (
1222 snapshot.anchor_after(range.start)
1223 ..snapshot.anchor_before(range.end),
1224 replacement,
1225 )
1226 })
1227 .collect::<Vec<_>>()
1228 })
1229 .await;
1230 cx.update(|cx| {
1231 project.update(cx, |project, cx| {
1232 project.set_agent_location(
1233 Some(AgentLocation {
1234 buffer: buffer.downgrade(),
1235 position: edits
1236 .last()
1237 .map(|(range, _)| range.end)
1238 .unwrap_or(Anchor::MIN),
1239 }),
1240 cx,
1241 );
1242 });
1243
1244 action_log.update(cx, |action_log, cx| {
1245 action_log.buffer_read(buffer.clone(), cx);
1246 });
1247 buffer.update(cx, |buffer, cx| {
1248 buffer.edit(edits, None, cx);
1249 });
1250 action_log.update(cx, |action_log, cx| {
1251 action_log.buffer_edited(buffer.clone(), cx);
1252 });
1253 })?;
1254 project
1255 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1256 .await
1257 })
1258 }
1259
1260 pub fn to_markdown(&self, cx: &App) -> String {
1261 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1262 }
1263
1264 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1265 cx.emit(AcpThreadEvent::ServerExited(status));
1266 }
1267}
1268
1269#[cfg(test)]
1270mod tests {
1271 use super::*;
1272 use anyhow::anyhow;
1273 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1274 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1275 use indoc::indoc;
1276 use project::FakeFs;
1277 use rand::Rng as _;
1278 use serde_json::json;
1279 use settings::SettingsStore;
1280 use smol::stream::StreamExt as _;
1281 use std::{cell::RefCell, rc::Rc, time::Duration};
1282
1283 use util::path;
1284
1285 fn init_test(cx: &mut TestAppContext) {
1286 env_logger::try_init().ok();
1287 cx.update(|cx| {
1288 let settings_store = SettingsStore::test(cx);
1289 cx.set_global(settings_store);
1290 Project::init_settings(cx);
1291 language::init(cx);
1292 });
1293 }
1294
1295 #[gpui::test]
1296 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1297 init_test(cx);
1298
1299 let fs = FakeFs::new(cx.executor());
1300 let project = Project::test(fs, [], cx).await;
1301 let connection = Rc::new(FakeAgentConnection::new());
1302 let thread = cx
1303 .spawn(async move |mut cx| {
1304 connection
1305 .new_thread(project, Path::new(path!("/test")), &mut cx)
1306 .await
1307 })
1308 .await
1309 .unwrap();
1310
1311 // Test creating a new user message
1312 thread.update(cx, |thread, cx| {
1313 thread.push_user_content_block(
1314 acp::ContentBlock::Text(acp::TextContent {
1315 annotations: None,
1316 text: "Hello, ".to_string(),
1317 }),
1318 cx,
1319 );
1320 });
1321
1322 thread.update(cx, |thread, cx| {
1323 assert_eq!(thread.entries.len(), 1);
1324 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1325 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1326 } else {
1327 panic!("Expected UserMessage");
1328 }
1329 });
1330
1331 // Test appending to existing user message
1332 thread.update(cx, |thread, cx| {
1333 thread.push_user_content_block(
1334 acp::ContentBlock::Text(acp::TextContent {
1335 annotations: None,
1336 text: "world!".to_string(),
1337 }),
1338 cx,
1339 );
1340 });
1341
1342 thread.update(cx, |thread, cx| {
1343 assert_eq!(thread.entries.len(), 1);
1344 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1345 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1346 } else {
1347 panic!("Expected UserMessage");
1348 }
1349 });
1350
1351 // Test creating new user message after assistant message
1352 thread.update(cx, |thread, cx| {
1353 thread.push_assistant_content_block(
1354 acp::ContentBlock::Text(acp::TextContent {
1355 annotations: None,
1356 text: "Assistant response".to_string(),
1357 }),
1358 false,
1359 cx,
1360 );
1361 });
1362
1363 thread.update(cx, |thread, cx| {
1364 thread.push_user_content_block(
1365 acp::ContentBlock::Text(acp::TextContent {
1366 annotations: None,
1367 text: "New user message".to_string(),
1368 }),
1369 cx,
1370 );
1371 });
1372
1373 thread.update(cx, |thread, cx| {
1374 assert_eq!(thread.entries.len(), 3);
1375 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1376 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1377 } else {
1378 panic!("Expected UserMessage at index 2");
1379 }
1380 });
1381 }
1382
1383 #[gpui::test]
1384 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1385 init_test(cx);
1386
1387 let fs = FakeFs::new(cx.executor());
1388 let project = Project::test(fs, [], cx).await;
1389 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1390 |_, thread, mut cx| {
1391 async move {
1392 thread.update(&mut cx, |thread, cx| {
1393 thread
1394 .handle_session_update(
1395 acp::SessionUpdate::AgentThoughtChunk {
1396 content: "Thinking ".into(),
1397 },
1398 cx,
1399 )
1400 .unwrap();
1401 thread
1402 .handle_session_update(
1403 acp::SessionUpdate::AgentThoughtChunk {
1404 content: "hard!".into(),
1405 },
1406 cx,
1407 )
1408 .unwrap();
1409 })?;
1410 Ok(acp::PromptResponse {
1411 stop_reason: acp::StopReason::EndTurn,
1412 })
1413 }
1414 .boxed_local()
1415 },
1416 ));
1417
1418 let thread = cx
1419 .spawn(async move |mut cx| {
1420 connection
1421 .new_thread(project, Path::new(path!("/test")), &mut cx)
1422 .await
1423 })
1424 .await
1425 .unwrap();
1426
1427 thread
1428 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1429 .await
1430 .unwrap();
1431
1432 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1433 assert_eq!(
1434 output,
1435 indoc! {r#"
1436 ## User
1437
1438 Hello from Zed!
1439
1440 ## Assistant
1441
1442 <thinking>
1443 Thinking hard!
1444 </thinking>
1445
1446 "#}
1447 );
1448 }
1449
    /// Checks that an agent write based on an earlier read is merged with a
    /// user edit made after that read, instead of overwriting it.
    ///
    /// The fake agent reads `/tmp/foo`, signals `read_file_tx`, then writes an
    /// extended version of the file. Once the read is signalled, the test inserts
    /// "zero\n" at the start of the same buffer. Both the in-memory buffer and
    /// the file on disk must end up containing the user's line and the agent's
    /// appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the agent after its read completes, so the user edit below
        // happens after the agent's read snapshot was taken.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is an `Fn`, so the one-shot sender is moved out through
        // interior mutability the first (and only) time it runs.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Written content is based on the read above, i.e. without the
                    // user's "zero\n" line.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file as the user would, so the test edits the same buffer the
        // agent reads and writes.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the "user" edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // `write_text_file` saves the buffer, so disk must reflect both edits too.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1528
    /// A tool call that was canceled on the client side can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    ///
    /// Flow:
    /// 1. The agent reports an in-progress tool call.
    /// 2. The test cancels the thread and checks the entry became `Canceled`.
    /// 3. The test delivers a `Completed` status update and checks the entry is
    ///    `Allowed { status: Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // `id` is cloned into the outer `Fn` closure and again per invocation,
        // because the handler may be called more than once.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 holds the user message from `send_raw`; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent must override the canceled state.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1638
1639 #[gpui::test]
1640 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1641 init_test(cx);
1642 let fs = FakeFs::new(cx.background_executor.clone());
1643 fs.insert_tree(path!("/test"), json!({})).await;
1644 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1645
1646 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1647 move |_, thread, mut cx| {
1648 async move {
1649 thread
1650 .update(&mut cx, |thread, cx| {
1651 thread.handle_session_update(
1652 acp::SessionUpdate::ToolCall(acp::ToolCall {
1653 id: acp::ToolCallId("test".into()),
1654 title: "Label".into(),
1655 kind: acp::ToolKind::Edit,
1656 status: acp::ToolCallStatus::Completed,
1657 content: vec![acp::ToolCallContent::Diff {
1658 diff: acp::Diff {
1659 path: "/test/test.txt".into(),
1660 old_text: None,
1661 new_text: "foo".into(),
1662 },
1663 }],
1664 locations: vec![],
1665 raw_input: None,
1666 raw_output: None,
1667 }),
1668 cx,
1669 )
1670 })
1671 .unwrap()
1672 .unwrap();
1673 Ok(acp::PromptResponse {
1674 stop_reason: acp::StopReason::EndTurn,
1675 })
1676 }
1677 .boxed_local()
1678 }
1679 }));
1680
1681 let thread = connection
1682 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1683 .await
1684 .unwrap();
1685 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1686 .await
1687 .unwrap();
1688
1689 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1690 }
1691
1692 async fn run_until_first_tool_call(
1693 thread: &Entity<AcpThread>,
1694 cx: &mut TestAppContext,
1695 ) -> usize {
1696 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1697
1698 let subscription = cx.update(|cx| {
1699 cx.subscribe(thread, move |thread, _, cx| {
1700 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1701 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1702 return tx.try_send(ix).unwrap();
1703 }
1704 }
1705 })
1706 });
1707
1708 select! {
1709 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1710 panic!("Timeout waiting for tool call")
1711 }
1712 ix = rx.next().fuse() => {
1713 drop(subscription);
1714 ix.unwrap()
1715 }
1716 }
1717 }
1718
1719 #[derive(Clone, Default)]
1720 struct FakeAgentConnection {
1721 auth_methods: Vec<acp::AuthMethod>,
1722 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1723 on_user_message: Option<
1724 Rc<
1725 dyn Fn(
1726 acp::PromptRequest,
1727 WeakEntity<AcpThread>,
1728 AsyncApp,
1729 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1730 + 'static,
1731 >,
1732 >,
1733 }
1734
1735 impl FakeAgentConnection {
1736 fn new() -> Self {
1737 Self {
1738 auth_methods: Vec::new(),
1739 on_user_message: None,
1740 sessions: Arc::default(),
1741 }
1742 }
1743
1744 #[expect(unused)]
1745 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1746 self.auth_methods = auth_methods;
1747 self
1748 }
1749
1750 fn on_user_message(
1751 mut self,
1752 handler: impl Fn(
1753 acp::PromptRequest,
1754 WeakEntity<AcpThread>,
1755 AsyncApp,
1756 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1757 + 'static,
1758 ) -> Self {
1759 self.on_user_message.replace(Rc::new(handler));
1760 self
1761 }
1762 }
1763
1764 impl AgentConnection for FakeAgentConnection {
1765 fn auth_methods(&self) -> &[acp::AuthMethod] {
1766 &self.auth_methods
1767 }
1768
1769 fn new_thread(
1770 self: Rc<Self>,
1771 project: Entity<Project>,
1772 _cwd: &Path,
1773 cx: &mut gpui::AsyncApp,
1774 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1775 let session_id = acp::SessionId(
1776 rand::thread_rng()
1777 .sample_iter(&rand::distributions::Alphanumeric)
1778 .take(7)
1779 .map(char::from)
1780 .collect::<String>()
1781 .into(),
1782 );
1783 let thread = cx
1784 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1785 .unwrap();
1786 self.sessions.lock().insert(session_id, thread.downgrade());
1787 Task::ready(Ok(thread))
1788 }
1789
1790 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1791 if self.auth_methods().iter().any(|m| m.id == method) {
1792 Task::ready(Ok(()))
1793 } else {
1794 Task::ready(Err(anyhow!("Invalid Auth Method")))
1795 }
1796 }
1797
1798 fn prompt(
1799 &self,
1800 params: acp::PromptRequest,
1801 cx: &mut App,
1802 ) -> Task<gpui::Result<acp::PromptResponse>> {
1803 let sessions = self.sessions.lock();
1804 let thread = sessions.get(¶ms.session_id).unwrap();
1805 if let Some(handler) = &self.on_user_message {
1806 let handler = handler.clone();
1807 let thread = thread.clone();
1808 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1809 } else {
1810 Task::ready(Ok(acp::PromptResponse {
1811 stop_reason: acp::StopReason::EndTurn,
1812 }))
1813 }
1814 }
1815
1816 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1817 let sessions = self.sessions.lock();
1818 let thread = sessions.get(&session_id).unwrap().clone();
1819
1820 cx.spawn(async move |cx| {
1821 thread
1822 .update(cx, |thread, cx| thread.cancel(cx))
1823 .unwrap()
1824 .await
1825 })
1826 .detach();
1827 }
1828 }
1829}