1mod connection;
2mod diff;
3mod terminal;
4
5pub use connection::*;
6pub use diff::*;
7pub use terminal::*;
8
9use action_log::ActionLog;
10use agent_client_protocol as acp;
11use anyhow::{Context as _, Result};
12use editor::Bias;
13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
14use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
15use itertools::Itertools;
16use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::process::ExitStatus;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
/// A message authored by the user, stored as a single content block.
#[derive(Debug)]
pub struct UserMessage {
    /// The message content; further streamed user chunks are appended to it.
    pub content: ContentBlock,
}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
53 }
54}
55
/// A borrowed file path that is rendered as an `@file:` mention link (see its `Display` impl).
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);
58
59impl<'a> MentionPath<'a> {
60 const PREFIX: &'static str = "@file:";
61
62 pub fn new(path: &'a Path) -> Self {
63 MentionPath(path)
64 }
65
66 pub fn try_parse(url: &'a str) -> Option<Self> {
67 let path = url.strip_prefix(Self::PREFIX)?;
68 Some(MentionPath(Path::new(path)))
69 }
70
71 pub fn path(&self) -> &Path {
72 self.0
73 }
74}
75
76impl Display for MentionPath<'_> {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(
79 f,
80 "[@{}]({}{})",
81 self.0.file_name().unwrap_or_default().display(),
82 Self::PREFIX,
83 self.0.display()
84 )
85 }
86}
87
/// A message produced by the agent, made of ordered message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; a streamed chunk of the same kind as the last one is
    /// appended to it rather than starting a new chunk.
    pub chunks: Vec<AssistantMessageChunk>,
}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
/// One run of agent output: either visible message text or a "thought".
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text.
    Message { block: ContentBlock },
    /// Reasoning text; rendered wrapped in `<thinking>` tags when exported to markdown.
    Thought { block: ContentBlock },
}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
/// A single entry in an agent thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message sent by the user.
    UserMessage(UserMessage),
    /// Output streamed back by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
145 if let AgentThreadEntry::ToolCall(call) = self {
146 itertools::Either::Left(call.diffs())
147 } else {
148 itertools::Either::Right(std::iter::empty())
149 }
150 }
151
152 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
153 if let AgentThreadEntry::ToolCall(call) = self {
154 itertools::Either::Left(call.terminals())
155 } else {
156 itertools::Either::Right(std::iter::empty())
157 }
158 }
159
160 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
161 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
162 Some(locations)
163 } else {
164 None
165 }
166 }
167}
168
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Agent-assigned identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// Markdown rendering of the tool call's title.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content (text blocks, diffs, terminals) produced by the call.
    pub content: Vec<ToolCallContent>,
    /// Authorization / execution state of the call.
    pub status: ToolCallStatus,
    /// Locations reported by the agent; the last one is used as the agent's
    /// project location when the call is upserted.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Raw input JSON as reported by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output JSON as reported by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
180
181impl ToolCall {
182 fn from_acp(
183 tool_call: acp::ToolCall,
184 status: ToolCallStatus,
185 language_registry: Arc<LanguageRegistry>,
186 cx: &mut App,
187 ) -> Self {
188 Self {
189 id: tool_call.id,
190 label: cx.new(|cx| {
191 Markdown::new(
192 tool_call.title.into(),
193 Some(language_registry.clone()),
194 None,
195 cx,
196 )
197 }),
198 kind: tool_call.kind,
199 content: tool_call
200 .content
201 .into_iter()
202 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
203 .collect(),
204 locations: tool_call.locations,
205 status,
206 raw_input: tool_call.raw_input,
207 raw_output: tool_call.raw_output,
208 }
209 }
210
211 fn update_fields(
212 &mut self,
213 fields: acp::ToolCallUpdateFields,
214 language_registry: Arc<LanguageRegistry>,
215 cx: &mut App,
216 ) {
217 let acp::ToolCallUpdateFields {
218 kind,
219 status,
220 title,
221 content,
222 locations,
223 raw_input,
224 raw_output,
225 } = fields;
226
227 if let Some(kind) = kind {
228 self.kind = kind;
229 }
230
231 if let Some(status) = status {
232 self.status = ToolCallStatus::Allowed { status };
233 }
234
235 if let Some(title) = title {
236 self.label.update(cx, |label, cx| {
237 label.replace(title, cx);
238 });
239 }
240
241 if let Some(content) = content {
242 self.content = content
243 .into_iter()
244 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
245 .collect();
246 }
247
248 if let Some(locations) = locations {
249 self.locations = locations;
250 }
251
252 if let Some(raw_input) = raw_input {
253 self.raw_input = Some(raw_input);
254 }
255
256 if let Some(raw_output) = raw_output {
257 if self.content.is_empty() {
258 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
259 {
260 self.content
261 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
262 markdown,
263 }));
264 }
265 }
266 self.raw_output = Some(raw_output);
267 }
268 }
269
270 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
271 self.content.iter().filter_map(|content| match content {
272 ToolCallContent::Diff(diff) => Some(diff),
273 ToolCallContent::ContentBlock(_) => None,
274 ToolCallContent::Terminal(_) => None,
275 })
276 }
277
278 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
279 self.content.iter().filter_map(|content| match content {
280 ToolCallContent::Terminal(terminal) => Some(terminal),
281 ToolCallContent::ContentBlock(_) => None,
282 ToolCallContent::Diff(_) => None,
283 })
284 }
285
286 fn to_markdown(&self, cx: &App) -> String {
287 let mut markdown = format!(
288 "**Tool Call: {}**\nStatus: {}\n\n",
289 self.label.read(cx).source(),
290 self.status
291 );
292 for content in &self.content {
293 markdown.push_str(content.to_markdown(cx).as_str());
294 markdown.push_str("\n\n");
295 }
296 markdown
297 }
298}
299
/// Authorization and execution state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The user must pick one of `options`; the chosen option id is sent on `respond_tx`.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The call may run; `status` is the execution state reported via ACP.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user chose a reject option.
    Rejected,
    /// The call was canceled because its turn was canceled.
    Canceled,
}
312
313impl Display for ToolCallStatus {
314 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
315 write!(
316 f,
317 "{}",
318 match self {
319 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
320 ToolCallStatus::Allowed { status } => match status {
321 acp::ToolCallStatus::Pending => "Pending",
322 acp::ToolCallStatus::InProgress => "In Progress",
323 acp::ToolCallStatus::Completed => "Completed",
324 acp::ToolCallStatus::Failed => "Failed",
325 },
326 ToolCallStatus::Rejected => "Rejected",
327 ToolCallStatus::Canceled => "Canceled",
328 }
329 )
330 }
331}
332
/// Renderable text content, created lazily on the first appended chunk.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Accumulated markdown source.
    Markdown { markdown: Entity<Markdown> },
}
338
339impl ContentBlock {
340 pub fn new(
341 block: acp::ContentBlock,
342 language_registry: &Arc<LanguageRegistry>,
343 cx: &mut App,
344 ) -> Self {
345 let mut this = Self::Empty;
346 this.append(block, language_registry, cx);
347 this
348 }
349
350 pub fn new_combined(
351 blocks: impl IntoIterator<Item = acp::ContentBlock>,
352 language_registry: Arc<LanguageRegistry>,
353 cx: &mut App,
354 ) -> Self {
355 let mut this = Self::Empty;
356 for block in blocks {
357 this.append(block, &language_registry, cx);
358 }
359 this
360 }
361
362 pub fn append(
363 &mut self,
364 block: acp::ContentBlock,
365 language_registry: &Arc<LanguageRegistry>,
366 cx: &mut App,
367 ) {
368 let new_content = match block {
369 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
370 acp::ContentBlock::ResourceLink(resource_link) => {
371 if let Some(path) = resource_link.uri.strip_prefix("file://") {
372 format!("{}", MentionPath(path.as_ref()))
373 } else {
374 resource_link.uri.clone()
375 }
376 }
377 acp::ContentBlock::Image(_)
378 | acp::ContentBlock::Audio(_)
379 | acp::ContentBlock::Resource(_) => String::new(),
380 };
381
382 match self {
383 ContentBlock::Empty => {
384 *self = ContentBlock::Markdown {
385 markdown: cx.new(|cx| {
386 Markdown::new(
387 new_content.into(),
388 Some(language_registry.clone()),
389 None,
390 cx,
391 )
392 }),
393 };
394 }
395 ContentBlock::Markdown { markdown } => {
396 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
397 }
398 }
399 }
400
401 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
402 match self {
403 ContentBlock::Empty => "",
404 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
405 }
406 }
407
408 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
409 match self {
410 ContentBlock::Empty => None,
411 ContentBlock::Markdown { markdown } => Some(markdown),
412 }
413 }
414}
415
/// One item of output shown for a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text/markdown content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal.
    Terminal(Entity<Terminal>),
}
422
423impl ToolCallContent {
424 pub fn from_acp(
425 content: acp::ToolCallContent,
426 language_registry: Arc<LanguageRegistry>,
427 cx: &mut App,
428 ) -> Self {
429 match content {
430 acp::ToolCallContent::Content { content } => {
431 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
432 }
433 acp::ToolCallContent::Diff { diff } => {
434 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
435 }
436 }
437 }
438
439 pub fn to_markdown(&self, cx: &App) -> String {
440 match self {
441 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
442 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
443 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
444 }
445 }
446}
447
/// An update targeting an existing tool call by id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Partial field update received via ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
454
455impl ToolCallUpdate {
456 fn id(&self) -> &acp::ToolCallId {
457 match self {
458 Self::UpdateFields(update) => &update.id,
459 Self::UpdateDiff(diff) => &diff.id,
460 Self::UpdateTerminal(terminal) => &terminal.id,
461 }
462 }
463}
464
465impl From<acp::ToolCallUpdate> for ToolCallUpdate {
466 fn from(update: acp::ToolCallUpdate) -> Self {
467 Self::UpdateFields(update)
468 }
469}
470
471impl From<ToolCallUpdateDiff> for ToolCallUpdate {
472 fn from(diff: ToolCallUpdateDiff) -> Self {
473 Self::UpdateDiff(diff)
474 }
475}
476
/// Replaces the content of tool call `id` with the single diff `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
482
483impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
484 fn from(terminal: ToolCallUpdateTerminal) -> Self {
485 Self::UpdateTerminal(terminal)
486 }
487}
488
/// Replaces the content of tool call `id` with the single terminal `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
494
/// The agent's current plan, as an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
499
/// Summary counts for a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry (in plan order) whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
506
507impl Plan {
508 pub fn is_empty(&self) -> bool {
509 self.entries.is_empty()
510 }
511
512 pub fn stats(&self) -> PlanStats<'_> {
513 let mut stats = PlanStats {
514 in_progress_entry: None,
515 pending: 0,
516 completed: 0,
517 };
518
519 for entry in &self.entries {
520 match &entry.status {
521 acp::PlanEntryStatus::Pending => {
522 stats.pending += 1;
523 }
524 acp::PlanEntryStatus::InProgress => {
525 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
526 }
527 acp::PlanEntryStatus::Completed => {
528 stats.completed += 1;
529 }
530 }
531 }
532
533 stats
534 }
535}
536
/// A single plan step.
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown rendering of the step text.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
543
544impl PlanEntry {
545 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
546 Self {
547 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
548 priority: entry.priority,
549 status: entry.status,
550 }
551 }
552}
553
/// State of a single agent conversation (an ACP session) within a project.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries in chronological order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's most recently reported plan.
    plan: Plan,
    project: Entity<Project>,
    /// Notified of buffer reads and edits performed on the agent's behalf.
    action_log: Entity<ActionLog>,
    /// Snapshot of each buffer as last read through `read_text_file`; `write_text_file`
    /// diffs new content against these when present.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight prompt task; `None` when idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
565
/// Events emitted by an [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was pushed onto the end of the transcript.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user to authorize it.
    ToolAuthorizationRequired,
    /// The current prompt finished (successfully or by cancellation).
    Stopped,
    /// The current prompt failed.
    Error,
    /// The agent server process exited with this status.
    ServerExited(ExitStatus),
}
574
// Lets `AcpThread` emit `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
576
/// Coarse status of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and the current turn has a tool call awaiting authorization.
    WaitingForToolConfirmation,
    /// A prompt is in flight.
    Generating,
}
583
/// Reasons an agent server could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent is not supported. The upgrade fields presumably describe how
    /// to upgrade it — NOTE(review): confirm how callers present them.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The server exited with this status code.
    Exited(i32),
    /// Any other failure, with a display message.
    Other(SharedString),
}
594
595impl Display for LoadError {
596 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
597 match self {
598 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
599 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
600 LoadError::Other(msg) => write!(f, "{}", msg),
601 }
602 }
603}
604
// `LoadError` is a standard error (message via `Display`), so it can be boxed or wrapped by `anyhow`.
impl Error for LoadError {}
606
607impl AcpThread {
608 pub fn new(
609 title: impl Into<SharedString>,
610 connection: Rc<dyn AgentConnection>,
611 project: Entity<Project>,
612 session_id: acp::SessionId,
613 cx: &mut Context<Self>,
614 ) -> Self {
615 let action_log = cx.new(|_| ActionLog::new(project.clone()));
616
617 Self {
618 action_log,
619 shared_buffers: Default::default(),
620 entries: Default::default(),
621 plan: Default::default(),
622 title: title.into(),
623 project,
624 send_task: None,
625 connection,
626 session_id,
627 }
628 }
629
630 pub fn action_log(&self) -> &Entity<ActionLog> {
631 &self.action_log
632 }
633
634 pub fn project(&self) -> &Entity<Project> {
635 &self.project
636 }
637
638 pub fn title(&self) -> SharedString {
639 self.title.clone()
640 }
641
642 pub fn entries(&self) -> &[AgentThreadEntry] {
643 &self.entries
644 }
645
646 pub fn session_id(&self) -> &acp::SessionId {
647 &self.session_id
648 }
649
650 pub fn status(&self) -> ThreadStatus {
651 if self.send_task.is_some() {
652 if self.waiting_for_tool_confirmation() {
653 ThreadStatus::WaitingForToolConfirmation
654 } else {
655 ThreadStatus::Generating
656 }
657 } else {
658 ThreadStatus::Idle
659 }
660 }
661
662 pub fn has_pending_edit_tool_calls(&self) -> bool {
663 for entry in self.entries.iter().rev() {
664 match entry {
665 AgentThreadEntry::UserMessage(_) => return false,
666 AgentThreadEntry::ToolCall(
667 call @ ToolCall {
668 status:
669 ToolCallStatus::Allowed {
670 status:
671 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
672 },
673 ..
674 },
675 ) if call.diffs().next().is_some() => {
676 return true;
677 }
678 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
679 }
680 }
681
682 false
683 }
684
685 pub fn used_tools_since_last_user_message(&self) -> bool {
686 for entry in self.entries.iter().rev() {
687 match entry {
688 AgentThreadEntry::UserMessage(..) => return false,
689 AgentThreadEntry::AssistantMessage(..) => continue,
690 AgentThreadEntry::ToolCall(..) => return true,
691 }
692 }
693
694 false
695 }
696
697 pub fn handle_session_update(
698 &mut self,
699 update: acp::SessionUpdate,
700 cx: &mut Context<Self>,
701 ) -> Result<()> {
702 match update {
703 acp::SessionUpdate::UserMessageChunk { content } => {
704 self.push_user_content_block(content, cx);
705 }
706 acp::SessionUpdate::AgentMessageChunk { content } => {
707 self.push_assistant_content_block(content, false, cx);
708 }
709 acp::SessionUpdate::AgentThoughtChunk { content } => {
710 self.push_assistant_content_block(content, true, cx);
711 }
712 acp::SessionUpdate::ToolCall(tool_call) => {
713 self.upsert_tool_call(tool_call, cx);
714 }
715 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
716 self.update_tool_call(tool_call_update, cx)?;
717 }
718 acp::SessionUpdate::Plan(plan) => {
719 self.update_plan(plan, cx);
720 }
721 }
722 Ok(())
723 }
724
725 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
726 let language_registry = self.project.read(cx).languages().clone();
727 let entries_len = self.entries.len();
728
729 if let Some(last_entry) = self.entries.last_mut()
730 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
731 {
732 content.append(chunk, &language_registry, cx);
733 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
734 } else {
735 let content = ContentBlock::new(chunk, &language_registry, cx);
736 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
737 }
738 }
739
740 pub fn push_assistant_content_block(
741 &mut self,
742 chunk: acp::ContentBlock,
743 is_thought: bool,
744 cx: &mut Context<Self>,
745 ) {
746 let language_registry = self.project.read(cx).languages().clone();
747 let entries_len = self.entries.len();
748 if let Some(last_entry) = self.entries.last_mut()
749 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
750 {
751 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
752 match (chunks.last_mut(), is_thought) {
753 (Some(AssistantMessageChunk::Message { block }), false)
754 | (Some(AssistantMessageChunk::Thought { block }), true) => {
755 block.append(chunk, &language_registry, cx)
756 }
757 _ => {
758 let block = ContentBlock::new(chunk, &language_registry, cx);
759 if is_thought {
760 chunks.push(AssistantMessageChunk::Thought { block })
761 } else {
762 chunks.push(AssistantMessageChunk::Message { block })
763 }
764 }
765 }
766 } else {
767 let block = ContentBlock::new(chunk, &language_registry, cx);
768 let chunk = if is_thought {
769 AssistantMessageChunk::Thought { block }
770 } else {
771 AssistantMessageChunk::Message { block }
772 };
773
774 self.push_entry(
775 AgentThreadEntry::AssistantMessage(AssistantMessage {
776 chunks: vec![chunk],
777 }),
778 cx,
779 );
780 }
781 }
782
783 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
784 self.entries.push(entry);
785 cx.emit(AcpThreadEvent::NewEntry);
786 }
787
788 pub fn update_tool_call(
789 &mut self,
790 update: impl Into<ToolCallUpdate>,
791 cx: &mut Context<Self>,
792 ) -> Result<()> {
793 let update = update.into();
794 let languages = self.project.read(cx).languages().clone();
795
796 let (ix, current_call) = self
797 .tool_call_mut(update.id())
798 .context("Tool call not found")?;
799 match update {
800 ToolCallUpdate::UpdateFields(update) => {
801 current_call.update_fields(update.fields, languages, cx);
802 }
803 ToolCallUpdate::UpdateDiff(update) => {
804 current_call.content.clear();
805 current_call
806 .content
807 .push(ToolCallContent::Diff(update.diff));
808 }
809 ToolCallUpdate::UpdateTerminal(update) => {
810 current_call.content.clear();
811 current_call
812 .content
813 .push(ToolCallContent::Terminal(update.terminal));
814 }
815 }
816
817 cx.emit(AcpThreadEvent::EntryUpdated(ix));
818
819 Ok(())
820 }
821
822 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
823 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
824 let status = ToolCallStatus::Allowed {
825 status: tool_call.status,
826 };
827 self.upsert_tool_call_inner(tool_call, status, cx)
828 }
829
830 pub fn upsert_tool_call_inner(
831 &mut self,
832 tool_call: acp::ToolCall,
833 status: ToolCallStatus,
834 cx: &mut Context<Self>,
835 ) {
836 let language_registry = self.project.read(cx).languages().clone();
837 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
838
839 let location = call.locations.last().cloned();
840
841 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
842 *current_call = call;
843
844 cx.emit(AcpThreadEvent::EntryUpdated(ix));
845 } else {
846 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
847 }
848
849 if let Some(location) = location {
850 self.set_project_location(location, cx)
851 }
852 }
853
854 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
855 // The tool call we are looking for is typically the last one, or very close to the end.
856 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
857 self.entries
858 .iter_mut()
859 .enumerate()
860 .rev()
861 .find_map(|(index, tool_call)| {
862 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
863 && &tool_call.id == id
864 {
865 Some((index, tool_call))
866 } else {
867 None
868 }
869 })
870 }
871
872 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
873 self.project.update(cx, |project, cx| {
874 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
875 return;
876 };
877 let buffer = project.open_buffer(path, cx);
878 cx.spawn(async move |project, cx| {
879 let buffer = buffer.await?;
880
881 project.update(cx, |project, cx| {
882 let position = if let Some(line) = location.line {
883 let snapshot = buffer.read(cx).snapshot();
884 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
885 snapshot.anchor_before(point)
886 } else {
887 Anchor::MIN
888 };
889
890 project.set_agent_location(
891 Some(AgentLocation {
892 buffer: buffer.downgrade(),
893 position,
894 }),
895 cx,
896 );
897 })
898 })
899 .detach_and_log_err(cx);
900 });
901 }
902
903 pub fn request_tool_call_authorization(
904 &mut self,
905 tool_call: acp::ToolCall,
906 options: Vec<acp::PermissionOption>,
907 cx: &mut Context<Self>,
908 ) -> oneshot::Receiver<acp::PermissionOptionId> {
909 let (tx, rx) = oneshot::channel();
910
911 let status = ToolCallStatus::WaitingForConfirmation {
912 options,
913 respond_tx: tx,
914 };
915
916 self.upsert_tool_call_inner(tool_call, status, cx);
917 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
918 rx
919 }
920
921 pub fn authorize_tool_call(
922 &mut self,
923 id: acp::ToolCallId,
924 option_id: acp::PermissionOptionId,
925 option_kind: acp::PermissionOptionKind,
926 cx: &mut Context<Self>,
927 ) {
928 let Some((ix, call)) = self.tool_call_mut(&id) else {
929 return;
930 };
931
932 let new_status = match option_kind {
933 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
934 ToolCallStatus::Rejected
935 }
936 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
937 ToolCallStatus::Allowed {
938 status: acp::ToolCallStatus::InProgress,
939 }
940 }
941 };
942
943 let curr_status = mem::replace(&mut call.status, new_status);
944
945 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
946 respond_tx.send(option_id).log_err();
947 } else if cfg!(debug_assertions) {
948 panic!("tried to authorize an already authorized tool call");
949 }
950
951 cx.emit(AcpThreadEvent::EntryUpdated(ix));
952 }
953
954 /// Returns true if the last turn is awaiting tool authorization
955 pub fn waiting_for_tool_confirmation(&self) -> bool {
956 for entry in self.entries.iter().rev() {
957 match &entry {
958 AgentThreadEntry::ToolCall(call) => match call.status {
959 ToolCallStatus::WaitingForConfirmation { .. } => return true,
960 ToolCallStatus::Allowed { .. }
961 | ToolCallStatus::Rejected
962 | ToolCallStatus::Canceled => continue,
963 },
964 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
965 // Reached the beginning of the turn
966 return false;
967 }
968 }
969 }
970 false
971 }
972
973 pub fn plan(&self) -> &Plan {
974 &self.plan
975 }
976
977 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
978 let new_entries_len = request.entries.len();
979 let mut new_entries = request.entries.into_iter();
980
981 // Reuse existing markdown to prevent flickering
982 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
983 let PlanEntry {
984 content,
985 priority,
986 status,
987 } = old;
988 content.update(cx, |old, cx| {
989 old.replace(new.content, cx);
990 });
991 *priority = new.priority;
992 *status = new.status;
993 }
994 for new in new_entries {
995 self.plan.entries.push(PlanEntry::from_acp(new, cx))
996 }
997 self.plan.entries.truncate(new_entries_len);
998
999 cx.notify();
1000 }
1001
1002 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1003 self.plan
1004 .entries
1005 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1006 cx.notify();
1007 }
1008
1009 #[cfg(any(test, feature = "test-support"))]
1010 pub fn send_raw(
1011 &mut self,
1012 message: &str,
1013 cx: &mut Context<Self>,
1014 ) -> BoxFuture<'static, Result<()>> {
1015 self.send(
1016 vec![acp::ContentBlock::Text(acp::TextContent {
1017 text: message.to_string(),
1018 annotations: None,
1019 })],
1020 cx,
1021 )
1022 }
1023
1024 pub fn send(
1025 &mut self,
1026 message: Vec<acp::ContentBlock>,
1027 cx: &mut Context<Self>,
1028 ) -> BoxFuture<'static, Result<()>> {
1029 let block = ContentBlock::new_combined(
1030 message.clone(),
1031 self.project.read(cx).languages().clone(),
1032 cx,
1033 );
1034 self.push_entry(
1035 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1036 cx,
1037 );
1038 self.clear_completed_plan_entries(cx);
1039
1040 let (tx, rx) = oneshot::channel();
1041 let cancel_task = self.cancel(cx);
1042
1043 self.send_task = Some(cx.spawn(async move |this, cx| {
1044 async {
1045 cancel_task.await;
1046
1047 let result = this
1048 .update(cx, |this, cx| {
1049 this.connection.prompt(
1050 acp::PromptRequest {
1051 prompt: message,
1052 session_id: this.session_id.clone(),
1053 },
1054 cx,
1055 )
1056 })?
1057 .await;
1058
1059 tx.send(result).log_err();
1060
1061 anyhow::Ok(())
1062 }
1063 .await
1064 .log_err();
1065 }));
1066
1067 cx.spawn(async move |this, cx| match rx.await {
1068 Ok(Err(e)) => {
1069 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1070 .log_err();
1071 Err(e)?
1072 }
1073 result => {
1074 let cancelled = matches!(
1075 result,
1076 Ok(Ok(acp::PromptResponse {
1077 stop_reason: acp::StopReason::Cancelled
1078 }))
1079 );
1080
1081 // We only take the task if the current prompt wasn't cancelled.
1082 //
1083 // This prompt may have been cancelled because another one was sent
1084 // while it was still generating. In these cases, dropping `send_task`
1085 // would cause the next generation to be cancelled.
1086 if !cancelled {
1087 this.update(cx, |this, _cx| this.send_task.take()).ok();
1088 }
1089
1090 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1091 .log_err();
1092 Ok(())
1093 }
1094 })
1095 .boxed()
1096 }
1097
1098 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1099 let Some(send_task) = self.send_task.take() else {
1100 return Task::ready(());
1101 };
1102
1103 for entry in self.entries.iter_mut() {
1104 if let AgentThreadEntry::ToolCall(call) = entry {
1105 let cancel = matches!(
1106 call.status,
1107 ToolCallStatus::WaitingForConfirmation { .. }
1108 | ToolCallStatus::Allowed {
1109 status: acp::ToolCallStatus::InProgress
1110 }
1111 );
1112
1113 if cancel {
1114 call.status = ToolCallStatus::Canceled;
1115 }
1116 }
1117 }
1118
1119 self.connection.cancel(&self.session_id, cx);
1120
1121 // Wait for the send task to complete
1122 cx.foreground_executor().spawn(send_task)
1123 }
1124
1125 pub fn read_text_file(
1126 &self,
1127 path: PathBuf,
1128 line: Option<u32>,
1129 limit: Option<u32>,
1130 reuse_shared_snapshot: bool,
1131 cx: &mut Context<Self>,
1132 ) -> Task<Result<String>> {
1133 let project = self.project.clone();
1134 let action_log = self.action_log.clone();
1135 cx.spawn(async move |this, cx| {
1136 let load = project.update(cx, |project, cx| {
1137 let path = project
1138 .project_path_for_absolute_path(&path, cx)
1139 .context("invalid path")?;
1140 anyhow::Ok(project.open_buffer(path, cx))
1141 });
1142 let buffer = load??.await?;
1143
1144 let snapshot = if reuse_shared_snapshot {
1145 this.read_with(cx, |this, _| {
1146 this.shared_buffers.get(&buffer.clone()).cloned()
1147 })
1148 .log_err()
1149 .flatten()
1150 } else {
1151 None
1152 };
1153
1154 let snapshot = if let Some(snapshot) = snapshot {
1155 snapshot
1156 } else {
1157 action_log.update(cx, |action_log, cx| {
1158 action_log.buffer_read(buffer.clone(), cx);
1159 })?;
1160 project.update(cx, |project, cx| {
1161 let position = buffer
1162 .read(cx)
1163 .snapshot()
1164 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1165 project.set_agent_location(
1166 Some(AgentLocation {
1167 buffer: buffer.downgrade(),
1168 position,
1169 }),
1170 cx,
1171 );
1172 })?;
1173
1174 buffer.update(cx, |buffer, _| buffer.snapshot())?
1175 };
1176
1177 this.update(cx, |this, _| {
1178 let text = snapshot.text();
1179 this.shared_buffers.insert(buffer.clone(), snapshot);
1180 if line.is_none() && limit.is_none() {
1181 return Ok(text);
1182 }
1183 let limit = limit.unwrap_or(u32::MAX) as usize;
1184 let Some(line) = line else {
1185 return Ok(text.lines().take(limit).collect::<String>());
1186 };
1187
1188 let count = text.lines().count();
1189 if count < line as usize {
1190 anyhow::bail!("There are only {} lines", count);
1191 }
1192 Ok(text
1193 .lines()
1194 .skip(line as usize + 1)
1195 .take(limit)
1196 .collect::<String>())
1197 })?
1198 })
1199 }
1200
1201 pub fn write_text_file(
1202 &self,
1203 path: PathBuf,
1204 content: String,
1205 cx: &mut Context<Self>,
1206 ) -> Task<Result<()>> {
1207 let project = self.project.clone();
1208 let action_log = self.action_log.clone();
1209 cx.spawn(async move |this, cx| {
1210 let load = project.update(cx, |project, cx| {
1211 let path = project
1212 .project_path_for_absolute_path(&path, cx)
1213 .context("invalid path")?;
1214 anyhow::Ok(project.open_buffer(path, cx))
1215 });
1216 let buffer = load??.await?;
1217 let snapshot = this.update(cx, |this, cx| {
1218 this.shared_buffers
1219 .get(&buffer)
1220 .cloned()
1221 .unwrap_or_else(|| buffer.read(cx).snapshot())
1222 })?;
1223 let edits = cx
1224 .background_executor()
1225 .spawn(async move {
1226 let old_text = snapshot.text();
1227 text_diff(old_text.as_str(), &content)
1228 .into_iter()
1229 .map(|(range, replacement)| {
1230 (
1231 snapshot.anchor_after(range.start)
1232 ..snapshot.anchor_before(range.end),
1233 replacement,
1234 )
1235 })
1236 .collect::<Vec<_>>()
1237 })
1238 .await;
1239 cx.update(|cx| {
1240 project.update(cx, |project, cx| {
1241 project.set_agent_location(
1242 Some(AgentLocation {
1243 buffer: buffer.downgrade(),
1244 position: edits
1245 .last()
1246 .map(|(range, _)| range.end)
1247 .unwrap_or(Anchor::MIN),
1248 }),
1249 cx,
1250 );
1251 });
1252
1253 action_log.update(cx, |action_log, cx| {
1254 action_log.buffer_read(buffer.clone(), cx);
1255 });
1256 buffer.update(cx, |buffer, cx| {
1257 buffer.edit(edits, None, cx);
1258 });
1259 action_log.update(cx, |action_log, cx| {
1260 action_log.buffer_edited(buffer.clone(), cx);
1261 });
1262 })?;
1263 project
1264 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1265 .await
1266 })
1267 }
1268
1269 pub fn to_markdown(&self, cx: &App) -> String {
1270 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1271 }
1272
1273 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1274 cx.emit(AcpThreadEvent::ServerExited(status));
1275 }
1276}
1277
1278fn markdown_for_raw_output(
1279 raw_output: &serde_json::Value,
1280 language_registry: &Arc<LanguageRegistry>,
1281 cx: &mut App,
1282) -> Option<Entity<Markdown>> {
1283 match raw_output {
1284 serde_json::Value::Null => None,
1285 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1286 Markdown::new(
1287 value.to_string().into(),
1288 Some(language_registry.clone()),
1289 None,
1290 cx,
1291 )
1292 })),
1293 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1294 Markdown::new(
1295 value.to_string().into(),
1296 Some(language_registry.clone()),
1297 None,
1298 cx,
1299 )
1300 })),
1301 serde_json::Value::String(value) => Some(cx.new(|cx| {
1302 Markdown::new(
1303 value.clone().into(),
1304 Some(language_registry.clone()),
1305 None,
1306 cx,
1307 )
1308 })),
1309 value => Some(cx.new(|cx| {
1310 Markdown::new(
1311 format!("```json\n{}\n```", value).into(),
1312 Some(language_registry.clone()),
1313 None,
1314 cx,
1315 )
1316 })),
1317 }
1318}
1319
// Tests for `AcpThread`, driven through the in-memory `FakeAgentConnection`.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, rc::Rc, time::Duration};

    use util::path;

    /// Common setup: installs a test settings store and initializes project
    /// settings and the language subsystem.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user content blocks are merged into one user message. A user
    /// block pushed after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive agent thought chunks are concatenated into a single
    /// `<thinking>` section of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// An agent write based on an earlier read must not clobber an edit the
    /// user made to the buffer in between. The merged result ends up both in
    /// the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Signal the test body so it can edit the buffer before the write.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call canceled locally is still moved to `Completed` when the agent
    /// later reports completion for it.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call that arrives already `Completed` leaves no pending edit
    /// tool calls on the thread.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            // Use `path!` like every other path in this
                                            // module so the file lies inside the
                                            // `/test` worktree on all platforms.
                                            path: path!("/test/test.txt").into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits, for at most 10 seconds, until an event from `thread` finds a tool
    /// call entry. Returns the index of the first tool call entry at that point.
    ///
    /// # Panics
    ///
    /// Panics on timeout.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Every later event sends again. Once an index is buffered the
                    // extra sends are redundant, so a full (or closed) channel is
                    // ignored instead of panicking.
                    let _ = tx.try_send(ix);
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory `AgentConnection` used by the tests above.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Methods reported by `auth_methods` and accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by session id. `prompt` and
        /// `cancel` look threads up here.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional handler run for every prompt. Without one, prompts end the
        /// turn immediately.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs `handler` to produce the response for every prompt.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread under a random 7-character alphanumeric session id
        /// and registers it in `sessions`.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Routes the prompt to the installed handler. Without a handler, returns
        /// `EndTurn` immediately. Panics if the session id is unknown.
        fn prompt(
            &self,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        /// Cancels the session's thread in a detached task. Panics if the session
        /// id is unknown.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            // `session_id` is already a reference; borrowing it again is needless.
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}