1mod connection;
2mod diff;
3
4pub use connection::*;
5pub use diff::*;
6
7use agent_client_protocol as acp;
8use anyhow::{Context as _, Result};
9use assistant_tool::ActionLog;
10use editor::Bias;
11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
12use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
13use itertools::Itertools;
14use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
15use markdown::Markdown;
16use project::{AgentLocation, Project};
17use std::collections::HashMap;
18use std::error::Error;
19use std::fmt::Formatter;
20use std::process::ExitStatus;
21use std::rc::Rc;
22use std::{
23 fmt::Display,
24 mem,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use ui::App;
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub content: ContentBlock,
34}
35
36impl UserMessage {
37 pub fn from_acp(
38 message: impl IntoIterator<Item = acp::ContentBlock>,
39 language_registry: Arc<LanguageRegistry>,
40 cx: &mut App,
41 ) -> Self {
42 let mut content = ContentBlock::Empty;
43 for chunk in message {
44 content.append(chunk, &language_registry, cx)
45 }
46 Self { content: content }
47 }
48
49 fn to_markdown(&self, cx: &App) -> String {
50 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
51 }
52}
53
/// A borrowed path that can be rendered as, or parsed from, an `@file:` mention link.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats the mention as a markdown link: `[@<file name>](@file:<full path>)`.
    /// A path without a final component yields an empty file name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let MentionPath(path) = self;
        let file_name = path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            path.display()
        )
    }
}
85
86#[derive(Debug, PartialEq)]
87pub struct AssistantMessage {
88 pub chunks: Vec<AssistantMessageChunk>,
89}
90
91impl AssistantMessage {
92 pub fn to_markdown(&self, cx: &App) -> String {
93 format!(
94 "## Assistant\n\n{}\n\n",
95 self.chunks
96 .iter()
97 .map(|chunk| chunk.to_markdown(cx))
98 .join("\n\n")
99 )
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub enum AssistantMessageChunk {
105 Message { block: ContentBlock },
106 Thought { block: ContentBlock },
107}
108
109impl AssistantMessageChunk {
110 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 Self::Message {
112 block: ContentBlock::new(chunk.into(), language_registry, cx),
113 }
114 }
115
116 fn to_markdown(&self, cx: &App) -> String {
117 match self {
118 Self::Message { block } => block.to_markdown(cx).to_string(),
119 Self::Thought { block } => {
120 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
121 }
122 }
123 }
124}
125
126#[derive(Debug)]
127pub enum AgentThreadEntry {
128 UserMessage(UserMessage),
129 AssistantMessage(AssistantMessage),
130 ToolCall(ToolCall),
131}
132
133impl AgentThreadEntry {
134 fn to_markdown(&self, cx: &App) -> String {
135 match self {
136 Self::UserMessage(message) => message.to_markdown(cx),
137 Self::AssistantMessage(message) => message.to_markdown(cx),
138 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
151 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
152 Some(locations)
153 } else {
154 None
155 }
156 }
157}
158
159#[derive(Debug)]
160pub struct ToolCall {
161 pub id: acp::ToolCallId,
162 pub label: Entity<Markdown>,
163 pub kind: acp::ToolKind,
164 pub content: Vec<ToolCallContent>,
165 pub status: ToolCallStatus,
166 pub locations: Vec<acp::ToolCallLocation>,
167 pub raw_input: Option<serde_json::Value>,
168 pub raw_output: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 raw_output: tool_call.raw_output,
198 }
199 }
200
201 fn update_fields(
202 &mut self,
203 fields: acp::ToolCallUpdateFields,
204 language_registry: Arc<LanguageRegistry>,
205 cx: &mut App,
206 ) {
207 let acp::ToolCallUpdateFields {
208 kind,
209 status,
210 title,
211 content,
212 locations,
213 raw_input,
214 raw_output,
215 } = fields;
216
217 if let Some(kind) = kind {
218 self.kind = kind;
219 }
220
221 if let Some(status) = status {
222 self.status = ToolCallStatus::Allowed { status };
223 }
224
225 if let Some(title) = title {
226 self.label.update(cx, |label, cx| {
227 label.replace(title, cx);
228 });
229 }
230
231 if let Some(content) = content {
232 self.content = content
233 .into_iter()
234 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
235 .collect();
236 }
237
238 if let Some(locations) = locations {
239 self.locations = locations;
240 }
241
242 if let Some(raw_input) = raw_input {
243 self.raw_input = Some(raw_input);
244 }
245
246 if let Some(raw_output) = raw_output {
247 self.raw_output = Some(raw_output);
248 }
249 }
250
251 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
252 self.content.iter().filter_map(|content| match content {
253 ToolCallContent::ContentBlock { .. } => None,
254 ToolCallContent::Diff { diff } => Some(diff),
255 })
256 }
257
258 fn to_markdown(&self, cx: &App) -> String {
259 let mut markdown = format!(
260 "**Tool Call: {}**\nStatus: {}\n\n",
261 self.label.read(cx).source(),
262 self.status
263 );
264 for content in &self.content {
265 markdown.push_str(content.to_markdown(cx).as_str());
266 markdown.push_str("\n\n");
267 }
268 markdown
269 }
270}
271
272#[derive(Debug)]
273pub enum ToolCallStatus {
274 WaitingForConfirmation {
275 options: Vec<acp::PermissionOption>,
276 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
277 },
278 Allowed {
279 status: acp::ToolCallStatus,
280 },
281 Rejected,
282 Canceled,
283}
284
285impl Display for ToolCallStatus {
286 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287 write!(
288 f,
289 "{}",
290 match self {
291 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
292 ToolCallStatus::Allowed { status } => match status {
293 acp::ToolCallStatus::Pending => "Pending",
294 acp::ToolCallStatus::InProgress => "In Progress",
295 acp::ToolCallStatus::Completed => "Completed",
296 acp::ToolCallStatus::Failed => "Failed",
297 },
298 ToolCallStatus::Rejected => "Rejected",
299 ToolCallStatus::Canceled => "Canceled",
300 }
301 )
302 }
303}
304
305#[derive(Debug, PartialEq, Clone)]
306pub enum ContentBlock {
307 Empty,
308 Markdown { markdown: Entity<Markdown> },
309}
310
311impl ContentBlock {
312 pub fn new(
313 block: acp::ContentBlock,
314 language_registry: &Arc<LanguageRegistry>,
315 cx: &mut App,
316 ) -> Self {
317 let mut this = Self::Empty;
318 this.append(block, language_registry, cx);
319 this
320 }
321
322 pub fn new_combined(
323 blocks: impl IntoIterator<Item = acp::ContentBlock>,
324 language_registry: Arc<LanguageRegistry>,
325 cx: &mut App,
326 ) -> Self {
327 let mut this = Self::Empty;
328 for block in blocks {
329 this.append(block, &language_registry, cx);
330 }
331 this
332 }
333
334 pub fn append(
335 &mut self,
336 block: acp::ContentBlock,
337 language_registry: &Arc<LanguageRegistry>,
338 cx: &mut App,
339 ) {
340 let new_content = match block {
341 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
342 acp::ContentBlock::ResourceLink(resource_link) => {
343 if let Some(path) = resource_link.uri.strip_prefix("file://") {
344 format!("{}", MentionPath(path.as_ref()))
345 } else {
346 resource_link.uri.clone()
347 }
348 }
349 acp::ContentBlock::Image(_)
350 | acp::ContentBlock::Audio(_)
351 | acp::ContentBlock::Resource(_) => String::new(),
352 };
353
354 match self {
355 ContentBlock::Empty => {
356 *self = ContentBlock::Markdown {
357 markdown: cx.new(|cx| {
358 Markdown::new(
359 new_content.into(),
360 Some(language_registry.clone()),
361 None,
362 cx,
363 )
364 }),
365 };
366 }
367 ContentBlock::Markdown { markdown } => {
368 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
369 }
370 }
371 }
372
373 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
374 match self {
375 ContentBlock::Empty => "",
376 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
377 }
378 }
379
380 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
381 match self {
382 ContentBlock::Empty => None,
383 ContentBlock::Markdown { markdown } => Some(markdown),
384 }
385 }
386}
387
388#[derive(Debug)]
389pub enum ToolCallContent {
390 ContentBlock { content: ContentBlock },
391 Diff { diff: Entity<Diff> },
392}
393
394impl ToolCallContent {
395 pub fn from_acp(
396 content: acp::ToolCallContent,
397 language_registry: Arc<LanguageRegistry>,
398 cx: &mut App,
399 ) -> Self {
400 match content {
401 acp::ToolCallContent::Content { content } => Self::ContentBlock {
402 content: ContentBlock::new(content, &language_registry, cx),
403 },
404 acp::ToolCallContent::Diff { diff } => Self::Diff {
405 diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
406 },
407 }
408 }
409
410 pub fn to_markdown(&self, cx: &App) -> String {
411 match self {
412 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
413 Self::Diff { diff } => diff.read(cx).to_markdown(cx),
414 }
415 }
416}
417
418#[derive(Debug, PartialEq)]
419pub enum ToolCallUpdate {
420 UpdateFields(acp::ToolCallUpdate),
421 UpdateDiff(ToolCallUpdateDiff),
422}
423
424impl ToolCallUpdate {
425 fn id(&self) -> &acp::ToolCallId {
426 match self {
427 Self::UpdateFields(update) => &update.id,
428 Self::UpdateDiff(diff) => &diff.id,
429 }
430 }
431}
432
433impl From<acp::ToolCallUpdate> for ToolCallUpdate {
434 fn from(update: acp::ToolCallUpdate) -> Self {
435 Self::UpdateFields(update)
436 }
437}
438
439impl From<ToolCallUpdateDiff> for ToolCallUpdate {
440 fn from(diff: ToolCallUpdateDiff) -> Self {
441 Self::UpdateDiff(diff)
442 }
443}
444
445#[derive(Debug, PartialEq)]
446pub struct ToolCallUpdateDiff {
447 pub id: acp::ToolCallId,
448 pub diff: Entity<Diff>,
449}
450
451#[derive(Debug, Default)]
452pub struct Plan {
453 pub entries: Vec<PlanEntry>,
454}
455
456#[derive(Debug)]
457pub struct PlanStats<'a> {
458 pub in_progress_entry: Option<&'a PlanEntry>,
459 pub pending: u32,
460 pub completed: u32,
461}
462
463impl Plan {
464 pub fn is_empty(&self) -> bool {
465 self.entries.is_empty()
466 }
467
468 pub fn stats(&self) -> PlanStats<'_> {
469 let mut stats = PlanStats {
470 in_progress_entry: None,
471 pending: 0,
472 completed: 0,
473 };
474
475 for entry in &self.entries {
476 match &entry.status {
477 acp::PlanEntryStatus::Pending => {
478 stats.pending += 1;
479 }
480 acp::PlanEntryStatus::InProgress => {
481 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
482 }
483 acp::PlanEntryStatus::Completed => {
484 stats.completed += 1;
485 }
486 }
487 }
488
489 stats
490 }
491}
492
493#[derive(Debug)]
494pub struct PlanEntry {
495 pub content: Entity<Markdown>,
496 pub priority: acp::PlanEntryPriority,
497 pub status: acp::PlanEntryStatus,
498}
499
500impl PlanEntry {
501 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
502 Self {
503 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
504 priority: entry.priority,
505 status: entry.status,
506 }
507 }
508}
509
510pub struct AcpThread {
511 title: SharedString,
512 entries: Vec<AgentThreadEntry>,
513 plan: Plan,
514 project: Entity<Project>,
515 action_log: Entity<ActionLog>,
516 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
517 send_task: Option<Task<()>>,
518 connection: Rc<dyn AgentConnection>,
519 session_id: acp::SessionId,
520}
521
522pub enum AcpThreadEvent {
523 NewEntry,
524 EntryUpdated(usize),
525 ToolAuthorizationRequired,
526 Stopped,
527 Error,
528 ServerExited(ExitStatus),
529}
530
531impl EventEmitter<AcpThreadEvent> for AcpThread {}
532
/// Coarse activity state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and a tool call awaits confirmation.
    WaitingForToolConfirmation,
    /// A prompt is in flight and nothing is awaiting confirmation.
    Generating,
}
539
540#[derive(Debug, Clone)]
541pub enum LoadError {
542 Unsupported {
543 error_message: SharedString,
544 upgrade_message: SharedString,
545 upgrade_command: String,
546 },
547 Exited(i32),
548 Other(SharedString),
549}
550
551impl Display for LoadError {
552 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
553 match self {
554 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
555 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
556 LoadError::Other(msg) => write!(f, "{}", msg),
557 }
558 }
559}
560
561impl Error for LoadError {}
562
563impl AcpThread {
564 pub fn new(
565 title: impl Into<SharedString>,
566 connection: Rc<dyn AgentConnection>,
567 project: Entity<Project>,
568 session_id: acp::SessionId,
569 cx: &mut Context<Self>,
570 ) -> Self {
571 let action_log = cx.new(|_| ActionLog::new(project.clone()));
572
573 Self {
574 action_log,
575 shared_buffers: Default::default(),
576 entries: Default::default(),
577 plan: Default::default(),
578 title: title.into(),
579 project,
580 send_task: None,
581 connection,
582 session_id,
583 }
584 }
585
586 pub fn action_log(&self) -> &Entity<ActionLog> {
587 &self.action_log
588 }
589
590 pub fn project(&self) -> &Entity<Project> {
591 &self.project
592 }
593
594 pub fn title(&self) -> SharedString {
595 self.title.clone()
596 }
597
598 pub fn entries(&self) -> &[AgentThreadEntry] {
599 &self.entries
600 }
601
602 pub fn session_id(&self) -> &acp::SessionId {
603 &self.session_id
604 }
605
606 pub fn status(&self) -> ThreadStatus {
607 if self.send_task.is_some() {
608 if self.waiting_for_tool_confirmation() {
609 ThreadStatus::WaitingForToolConfirmation
610 } else {
611 ThreadStatus::Generating
612 }
613 } else {
614 ThreadStatus::Idle
615 }
616 }
617
618 pub fn has_pending_edit_tool_calls(&self) -> bool {
619 for entry in self.entries.iter().rev() {
620 match entry {
621 AgentThreadEntry::UserMessage(_) => return false,
622 AgentThreadEntry::ToolCall(
623 call @ ToolCall {
624 status:
625 ToolCallStatus::Allowed {
626 status:
627 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
628 },
629 ..
630 },
631 ) if call.diffs().next().is_some() => {
632 return true;
633 }
634 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
635 }
636 }
637
638 false
639 }
640
641 pub fn used_tools_since_last_user_message(&self) -> bool {
642 for entry in self.entries.iter().rev() {
643 match entry {
644 AgentThreadEntry::UserMessage(..) => return false,
645 AgentThreadEntry::AssistantMessage(..) => continue,
646 AgentThreadEntry::ToolCall(..) => return true,
647 }
648 }
649
650 false
651 }
652
653 pub fn handle_session_update(
654 &mut self,
655 update: acp::SessionUpdate,
656 cx: &mut Context<Self>,
657 ) -> Result<()> {
658 match update {
659 acp::SessionUpdate::UserMessageChunk { content } => {
660 self.push_user_content_block(content, cx);
661 }
662 acp::SessionUpdate::AgentMessageChunk { content } => {
663 self.push_assistant_content_block(content, false, cx);
664 }
665 acp::SessionUpdate::AgentThoughtChunk { content } => {
666 self.push_assistant_content_block(content, true, cx);
667 }
668 acp::SessionUpdate::ToolCall(tool_call) => {
669 self.upsert_tool_call(tool_call, cx);
670 }
671 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
672 self.update_tool_call(tool_call_update, cx)?;
673 }
674 acp::SessionUpdate::Plan(plan) => {
675 self.update_plan(plan, cx);
676 }
677 }
678 Ok(())
679 }
680
681 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
682 let language_registry = self.project.read(cx).languages().clone();
683 let entries_len = self.entries.len();
684
685 if let Some(last_entry) = self.entries.last_mut()
686 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
687 {
688 content.append(chunk, &language_registry, cx);
689 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
690 } else {
691 let content = ContentBlock::new(chunk, &language_registry, cx);
692 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
693 }
694 }
695
696 pub fn push_assistant_content_block(
697 &mut self,
698 chunk: acp::ContentBlock,
699 is_thought: bool,
700 cx: &mut Context<Self>,
701 ) {
702 let language_registry = self.project.read(cx).languages().clone();
703 let entries_len = self.entries.len();
704 if let Some(last_entry) = self.entries.last_mut()
705 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
706 {
707 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
708 match (chunks.last_mut(), is_thought) {
709 (Some(AssistantMessageChunk::Message { block }), false)
710 | (Some(AssistantMessageChunk::Thought { block }), true) => {
711 block.append(chunk, &language_registry, cx)
712 }
713 _ => {
714 let block = ContentBlock::new(chunk, &language_registry, cx);
715 if is_thought {
716 chunks.push(AssistantMessageChunk::Thought { block })
717 } else {
718 chunks.push(AssistantMessageChunk::Message { block })
719 }
720 }
721 }
722 } else {
723 let block = ContentBlock::new(chunk, &language_registry, cx);
724 let chunk = if is_thought {
725 AssistantMessageChunk::Thought { block }
726 } else {
727 AssistantMessageChunk::Message { block }
728 };
729
730 self.push_entry(
731 AgentThreadEntry::AssistantMessage(AssistantMessage {
732 chunks: vec![chunk],
733 }),
734 cx,
735 );
736 }
737 }
738
739 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
740 self.entries.push(entry);
741 cx.emit(AcpThreadEvent::NewEntry);
742 }
743
744 pub fn update_tool_call(
745 &mut self,
746 update: impl Into<ToolCallUpdate>,
747 cx: &mut Context<Self>,
748 ) -> Result<()> {
749 let update = update.into();
750 let languages = self.project.read(cx).languages().clone();
751
752 let (ix, current_call) = self
753 .tool_call_mut(update.id())
754 .context("Tool call not found")?;
755 match update {
756 ToolCallUpdate::UpdateFields(update) => {
757 current_call.update_fields(update.fields, languages, cx);
758 }
759 ToolCallUpdate::UpdateDiff(update) => {
760 current_call.content.clear();
761 current_call
762 .content
763 .push(ToolCallContent::Diff { diff: update.diff });
764 }
765 }
766
767 cx.emit(AcpThreadEvent::EntryUpdated(ix));
768
769 Ok(())
770 }
771
772 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
773 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
774 let status = ToolCallStatus::Allowed {
775 status: tool_call.status,
776 };
777 self.upsert_tool_call_inner(tool_call, status, cx)
778 }
779
780 pub fn upsert_tool_call_inner(
781 &mut self,
782 tool_call: acp::ToolCall,
783 status: ToolCallStatus,
784 cx: &mut Context<Self>,
785 ) {
786 let language_registry = self.project.read(cx).languages().clone();
787 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
788
789 let location = call.locations.last().cloned();
790
791 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
792 *current_call = call;
793
794 cx.emit(AcpThreadEvent::EntryUpdated(ix));
795 } else {
796 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
797 }
798
799 if let Some(location) = location {
800 self.set_project_location(location, cx)
801 }
802 }
803
804 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
805 // The tool call we are looking for is typically the last one, or very close to the end.
806 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
807 self.entries
808 .iter_mut()
809 .enumerate()
810 .rev()
811 .find_map(|(index, tool_call)| {
812 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
813 && &tool_call.id == id
814 {
815 Some((index, tool_call))
816 } else {
817 None
818 }
819 })
820 }
821
822 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
823 self.project.update(cx, |project, cx| {
824 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
825 return;
826 };
827 let buffer = project.open_buffer(path, cx);
828 cx.spawn(async move |project, cx| {
829 let buffer = buffer.await?;
830
831 project.update(cx, |project, cx| {
832 let position = if let Some(line) = location.line {
833 let snapshot = buffer.read(cx).snapshot();
834 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
835 snapshot.anchor_before(point)
836 } else {
837 Anchor::MIN
838 };
839
840 project.set_agent_location(
841 Some(AgentLocation {
842 buffer: buffer.downgrade(),
843 position,
844 }),
845 cx,
846 );
847 })
848 })
849 .detach_and_log_err(cx);
850 });
851 }
852
853 pub fn request_tool_call_authorization(
854 &mut self,
855 tool_call: acp::ToolCall,
856 options: Vec<acp::PermissionOption>,
857 cx: &mut Context<Self>,
858 ) -> oneshot::Receiver<acp::PermissionOptionId> {
859 let (tx, rx) = oneshot::channel();
860
861 let status = ToolCallStatus::WaitingForConfirmation {
862 options,
863 respond_tx: tx,
864 };
865
866 self.upsert_tool_call_inner(tool_call, status, cx);
867 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
868 rx
869 }
870
871 pub fn authorize_tool_call(
872 &mut self,
873 id: acp::ToolCallId,
874 option_id: acp::PermissionOptionId,
875 option_kind: acp::PermissionOptionKind,
876 cx: &mut Context<Self>,
877 ) {
878 let Some((ix, call)) = self.tool_call_mut(&id) else {
879 return;
880 };
881
882 let new_status = match option_kind {
883 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
884 ToolCallStatus::Rejected
885 }
886 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
887 ToolCallStatus::Allowed {
888 status: acp::ToolCallStatus::InProgress,
889 }
890 }
891 };
892
893 let curr_status = mem::replace(&mut call.status, new_status);
894
895 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
896 respond_tx.send(option_id).log_err();
897 } else if cfg!(debug_assertions) {
898 panic!("tried to authorize an already authorized tool call");
899 }
900
901 cx.emit(AcpThreadEvent::EntryUpdated(ix));
902 }
903
904 /// Returns true if the last turn is awaiting tool authorization
905 pub fn waiting_for_tool_confirmation(&self) -> bool {
906 for entry in self.entries.iter().rev() {
907 match &entry {
908 AgentThreadEntry::ToolCall(call) => match call.status {
909 ToolCallStatus::WaitingForConfirmation { .. } => return true,
910 ToolCallStatus::Allowed { .. }
911 | ToolCallStatus::Rejected
912 | ToolCallStatus::Canceled => continue,
913 },
914 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
915 // Reached the beginning of the turn
916 return false;
917 }
918 }
919 }
920 false
921 }
922
923 pub fn plan(&self) -> &Plan {
924 &self.plan
925 }
926
927 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
928 let new_entries_len = request.entries.len();
929 let mut new_entries = request.entries.into_iter();
930
931 // Reuse existing markdown to prevent flickering
932 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
933 let PlanEntry {
934 content,
935 priority,
936 status,
937 } = old;
938 content.update(cx, |old, cx| {
939 old.replace(new.content, cx);
940 });
941 *priority = new.priority;
942 *status = new.status;
943 }
944 for new in new_entries {
945 self.plan.entries.push(PlanEntry::from_acp(new, cx))
946 }
947 self.plan.entries.truncate(new_entries_len);
948
949 cx.notify();
950 }
951
952 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
953 self.plan
954 .entries
955 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
956 cx.notify();
957 }
958
959 #[cfg(any(test, feature = "test-support"))]
960 pub fn send_raw(
961 &mut self,
962 message: &str,
963 cx: &mut Context<Self>,
964 ) -> BoxFuture<'static, Result<()>> {
965 self.send(
966 vec![acp::ContentBlock::Text(acp::TextContent {
967 text: message.to_string(),
968 annotations: None,
969 })],
970 cx,
971 )
972 }
973
974 pub fn send(
975 &mut self,
976 message: Vec<acp::ContentBlock>,
977 cx: &mut Context<Self>,
978 ) -> BoxFuture<'static, Result<()>> {
979 let block = ContentBlock::new_combined(
980 message.clone(),
981 self.project.read(cx).languages().clone(),
982 cx,
983 );
984 self.push_entry(
985 AgentThreadEntry::UserMessage(UserMessage { content: block }),
986 cx,
987 );
988 self.clear_completed_plan_entries(cx);
989
990 let (tx, rx) = oneshot::channel();
991 let cancel_task = self.cancel(cx);
992
993 self.send_task = Some(cx.spawn(async move |this, cx| {
994 async {
995 cancel_task.await;
996
997 let result = this
998 .update(cx, |this, cx| {
999 this.connection.prompt(
1000 acp::PromptRequest {
1001 prompt: message,
1002 session_id: this.session_id.clone(),
1003 },
1004 cx,
1005 )
1006 })?
1007 .await;
1008
1009 tx.send(result).log_err();
1010
1011 anyhow::Ok(())
1012 }
1013 .await
1014 .log_err();
1015 }));
1016
1017 cx.spawn(async move |this, cx| match rx.await {
1018 Ok(Err(e)) => {
1019 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1020 .log_err();
1021 Err(e)?
1022 }
1023 result => {
1024 let cancelled = matches!(
1025 result,
1026 Ok(Ok(acp::PromptResponse {
1027 stop_reason: acp::StopReason::Cancelled
1028 }))
1029 );
1030
1031 // We only take the task if the current prompt wasn't cancelled.
1032 //
1033 // This prompt may have been cancelled because another one was sent
1034 // while it was still generating. In these cases, dropping `send_task`
1035 // would cause the next generation to be cancelled.
1036 if !cancelled {
1037 this.update(cx, |this, _cx| this.send_task.take()).ok();
1038 }
1039
1040 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1041 .log_err();
1042 Ok(())
1043 }
1044 })
1045 .boxed()
1046 }
1047
1048 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1049 let Some(send_task) = self.send_task.take() else {
1050 return Task::ready(());
1051 };
1052
1053 for entry in self.entries.iter_mut() {
1054 if let AgentThreadEntry::ToolCall(call) = entry {
1055 let cancel = matches!(
1056 call.status,
1057 ToolCallStatus::WaitingForConfirmation { .. }
1058 | ToolCallStatus::Allowed {
1059 status: acp::ToolCallStatus::InProgress
1060 }
1061 );
1062
1063 if cancel {
1064 call.status = ToolCallStatus::Canceled;
1065 }
1066 }
1067 }
1068
1069 self.connection.cancel(&self.session_id, cx);
1070
1071 // Wait for the send task to complete
1072 cx.foreground_executor().spawn(send_task)
1073 }
1074
1075 pub fn read_text_file(
1076 &self,
1077 path: PathBuf,
1078 line: Option<u32>,
1079 limit: Option<u32>,
1080 reuse_shared_snapshot: bool,
1081 cx: &mut Context<Self>,
1082 ) -> Task<Result<String>> {
1083 let project = self.project.clone();
1084 let action_log = self.action_log.clone();
1085 cx.spawn(async move |this, cx| {
1086 let load = project.update(cx, |project, cx| {
1087 let path = project
1088 .project_path_for_absolute_path(&path, cx)
1089 .context("invalid path")?;
1090 anyhow::Ok(project.open_buffer(path, cx))
1091 });
1092 let buffer = load??.await?;
1093
1094 let snapshot = if reuse_shared_snapshot {
1095 this.read_with(cx, |this, _| {
1096 this.shared_buffers.get(&buffer.clone()).cloned()
1097 })
1098 .log_err()
1099 .flatten()
1100 } else {
1101 None
1102 };
1103
1104 let snapshot = if let Some(snapshot) = snapshot {
1105 snapshot
1106 } else {
1107 action_log.update(cx, |action_log, cx| {
1108 action_log.buffer_read(buffer.clone(), cx);
1109 })?;
1110 project.update(cx, |project, cx| {
1111 let position = buffer
1112 .read(cx)
1113 .snapshot()
1114 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1115 project.set_agent_location(
1116 Some(AgentLocation {
1117 buffer: buffer.downgrade(),
1118 position,
1119 }),
1120 cx,
1121 );
1122 })?;
1123
1124 buffer.update(cx, |buffer, _| buffer.snapshot())?
1125 };
1126
1127 this.update(cx, |this, _| {
1128 let text = snapshot.text();
1129 this.shared_buffers.insert(buffer.clone(), snapshot);
1130 if line.is_none() && limit.is_none() {
1131 return Ok(text);
1132 }
1133 let limit = limit.unwrap_or(u32::MAX) as usize;
1134 let Some(line) = line else {
1135 return Ok(text.lines().take(limit).collect::<String>());
1136 };
1137
1138 let count = text.lines().count();
1139 if count < line as usize {
1140 anyhow::bail!("There are only {} lines", count);
1141 }
1142 Ok(text
1143 .lines()
1144 .skip(line as usize + 1)
1145 .take(limit)
1146 .collect::<String>())
1147 })?
1148 })
1149 }
1150
1151 pub fn write_text_file(
1152 &self,
1153 path: PathBuf,
1154 content: String,
1155 cx: &mut Context<Self>,
1156 ) -> Task<Result<()>> {
1157 let project = self.project.clone();
1158 let action_log = self.action_log.clone();
1159 cx.spawn(async move |this, cx| {
1160 let load = project.update(cx, |project, cx| {
1161 let path = project
1162 .project_path_for_absolute_path(&path, cx)
1163 .context("invalid path")?;
1164 anyhow::Ok(project.open_buffer(path, cx))
1165 });
1166 let buffer = load??.await?;
1167 let snapshot = this.update(cx, |this, cx| {
1168 this.shared_buffers
1169 .get(&buffer)
1170 .cloned()
1171 .unwrap_or_else(|| buffer.read(cx).snapshot())
1172 })?;
1173 let edits = cx
1174 .background_executor()
1175 .spawn(async move {
1176 let old_text = snapshot.text();
1177 text_diff(old_text.as_str(), &content)
1178 .into_iter()
1179 .map(|(range, replacement)| {
1180 (
1181 snapshot.anchor_after(range.start)
1182 ..snapshot.anchor_before(range.end),
1183 replacement,
1184 )
1185 })
1186 .collect::<Vec<_>>()
1187 })
1188 .await;
1189 cx.update(|cx| {
1190 project.update(cx, |project, cx| {
1191 project.set_agent_location(
1192 Some(AgentLocation {
1193 buffer: buffer.downgrade(),
1194 position: edits
1195 .last()
1196 .map(|(range, _)| range.end)
1197 .unwrap_or(Anchor::MIN),
1198 }),
1199 cx,
1200 );
1201 });
1202
1203 action_log.update(cx, |action_log, cx| {
1204 action_log.buffer_read(buffer.clone(), cx);
1205 });
1206 buffer.update(cx, |buffer, cx| {
1207 buffer.edit(edits, None, cx);
1208 });
1209 action_log.update(cx, |action_log, cx| {
1210 action_log.buffer_edited(buffer.clone(), cx);
1211 });
1212 })?;
1213 project
1214 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1215 .await
1216 })
1217 }
1218
1219 pub fn to_markdown(&self, cx: &App) -> String {
1220 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1221 }
1222
1223 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1224 cx.emit(AcpThreadEvent::ServerExited(status));
1225 }
1226}
1227
1228#[cfg(test)]
1229mod tests {
1230 use super::*;
1231 use anyhow::anyhow;
1232 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1233 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1234 use indoc::indoc;
1235 use project::FakeFs;
1236 use rand::Rng as _;
1237 use serde_json::json;
1238 use settings::SettingsStore;
1239 use smol::stream::StreamExt as _;
1240 use std::{cell::RefCell, rc::Rc, time::Duration};
1241
1242 use util::path;
1243
1244 fn init_test(cx: &mut TestAppContext) {
1245 env_logger::try_init().ok();
1246 cx.update(|cx| {
1247 let settings_store = SettingsStore::test(cx);
1248 cx.set_global(settings_store);
1249 Project::init_settings(cx);
1250 language::init(cx);
1251 });
1252 }
1253
1254 #[gpui::test]
1255 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1256 init_test(cx);
1257
1258 let fs = FakeFs::new(cx.executor());
1259 let project = Project::test(fs, [], cx).await;
1260 let connection = Rc::new(FakeAgentConnection::new());
1261 let thread = cx
1262 .spawn(async move |mut cx| {
1263 connection
1264 .new_thread(project, Path::new(path!("/test")), &mut cx)
1265 .await
1266 })
1267 .await
1268 .unwrap();
1269
1270 // Test creating a new user message
1271 thread.update(cx, |thread, cx| {
1272 thread.push_user_content_block(
1273 acp::ContentBlock::Text(acp::TextContent {
1274 annotations: None,
1275 text: "Hello, ".to_string(),
1276 }),
1277 cx,
1278 );
1279 });
1280
1281 thread.update(cx, |thread, cx| {
1282 assert_eq!(thread.entries.len(), 1);
1283 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1284 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1285 } else {
1286 panic!("Expected UserMessage");
1287 }
1288 });
1289
1290 // Test appending to existing user message
1291 thread.update(cx, |thread, cx| {
1292 thread.push_user_content_block(
1293 acp::ContentBlock::Text(acp::TextContent {
1294 annotations: None,
1295 text: "world!".to_string(),
1296 }),
1297 cx,
1298 );
1299 });
1300
1301 thread.update(cx, |thread, cx| {
1302 assert_eq!(thread.entries.len(), 1);
1303 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1304 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1305 } else {
1306 panic!("Expected UserMessage");
1307 }
1308 });
1309
1310 // Test creating new user message after assistant message
1311 thread.update(cx, |thread, cx| {
1312 thread.push_assistant_content_block(
1313 acp::ContentBlock::Text(acp::TextContent {
1314 annotations: None,
1315 text: "Assistant response".to_string(),
1316 }),
1317 false,
1318 cx,
1319 );
1320 });
1321
1322 thread.update(cx, |thread, cx| {
1323 thread.push_user_content_block(
1324 acp::ContentBlock::Text(acp::TextContent {
1325 annotations: None,
1326 text: "New user message".to_string(),
1327 }),
1328 cx,
1329 );
1330 });
1331
1332 thread.update(cx, |thread, cx| {
1333 assert_eq!(thread.entries.len(), 3);
1334 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1335 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1336 } else {
1337 panic!("Expected UserMessage at index 2");
1338 }
1339 });
1340 }
1341
1342 #[gpui::test]
1343 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1344 init_test(cx);
1345
1346 let fs = FakeFs::new(cx.executor());
1347 let project = Project::test(fs, [], cx).await;
1348 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1349 |_, thread, mut cx| {
1350 async move {
1351 thread.update(&mut cx, |thread, cx| {
1352 thread
1353 .handle_session_update(
1354 acp::SessionUpdate::AgentThoughtChunk {
1355 content: "Thinking ".into(),
1356 },
1357 cx,
1358 )
1359 .unwrap();
1360 thread
1361 .handle_session_update(
1362 acp::SessionUpdate::AgentThoughtChunk {
1363 content: "hard!".into(),
1364 },
1365 cx,
1366 )
1367 .unwrap();
1368 })?;
1369 Ok(acp::PromptResponse {
1370 stop_reason: acp::StopReason::EndTurn,
1371 })
1372 }
1373 .boxed_local()
1374 },
1375 ));
1376
1377 let thread = cx
1378 .spawn(async move |mut cx| {
1379 connection
1380 .new_thread(project, Path::new(path!("/test")), &mut cx)
1381 .await
1382 })
1383 .await
1384 .unwrap();
1385
1386 thread
1387 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1388 .await
1389 .unwrap();
1390
1391 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1392 assert_eq!(
1393 output,
1394 indoc! {r#"
1395 ## User
1396
1397 Hello from Zed!
1398
1399 ## Assistant
1400
1401 <thinking>
1402 Thinking hard!
1403 </thinking>
1404
1405 "#}
1406 );
1407 }
1408
// Exercises an agent write that races with a user edit to the same buffer.
//
// Sequence driven by the test:
//   1. The fake agent reads `/tmp/foo` and signals the test through a oneshot.
//   2. The test (acting as the user) inserts "zero\n" at the start of the buffer.
//   3. The agent writes new content that appends "four\nfive\n".
//
// Both the in-memory buffer and the file on disk are expected to contain the
// user's insertion as well as the agent's appended lines.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    // Signals the test body once the agent has finished reading the file.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    // The sender is consumed on send, so it is kept in an `Option` that the
    // handler takes from; the `Rc<RefCell<..>>` lets the `Fn` handler share it.
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        move |_, thread, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                // Step 1: the agent reads the file's original contents.
                let content = thread
                    .update(&mut cx, |thread, cx| {
                        thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                    })
                    .unwrap()
                    .await
                    .unwrap();
                assert_eq!(content, "one\ntwo\nthree\n");
                // Let the test body perform the concurrent user edit.
                read_file_tx.take().unwrap().send(()).unwrap();
                // Step 3: the agent writes content based on what it read earlier.
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.write_text_file(
                            path!("/tmp/foo").into(),
                            "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            cx,
                        )
                    })
                    .unwrap()
                    .await
                    .unwrap();
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        },
    ));

    // Open the same file as a buffer so the test can edit it as the user.
    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    let thread = cx
        .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
        .await
        .unwrap();

    // Start the prompt but do not await it yet; the handler runs concurrently.
    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    // Wait until the agent has read the file, then make the user edit (step 2).
    read_file_rx.await.ok();
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    // Let the agent's pending write (and the save) run to completion.
    cx.run_until_parked();
    // The buffer must contain both the user's and the agent's changes.
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    // The same merged content must have been saved to disk.
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
1487
1488 #[gpui::test]
1489 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1490 init_test(cx);
1491
1492 let fs = FakeFs::new(cx.executor());
1493 let project = Project::test(fs, [], cx).await;
1494 let id = acp::ToolCallId("test".into());
1495
1496 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1497 let id = id.clone();
1498 move |_, thread, mut cx| {
1499 let id = id.clone();
1500 async move {
1501 thread
1502 .update(&mut cx, |thread, cx| {
1503 thread.handle_session_update(
1504 acp::SessionUpdate::ToolCall(acp::ToolCall {
1505 id: id.clone(),
1506 title: "Label".into(),
1507 kind: acp::ToolKind::Fetch,
1508 status: acp::ToolCallStatus::InProgress,
1509 content: vec![],
1510 locations: vec![],
1511 raw_input: None,
1512 raw_output: None,
1513 }),
1514 cx,
1515 )
1516 })
1517 .unwrap()
1518 .unwrap();
1519 Ok(acp::PromptResponse {
1520 stop_reason: acp::StopReason::EndTurn,
1521 })
1522 }
1523 .boxed_local()
1524 }
1525 }));
1526
1527 let thread = cx
1528 .spawn(async move |mut cx| {
1529 connection
1530 .new_thread(project, Path::new(path!("/test")), &mut cx)
1531 .await
1532 })
1533 .await
1534 .unwrap();
1535
1536 let request = thread.update(cx, |thread, cx| {
1537 thread.send_raw("Fetch https://example.com", cx)
1538 });
1539
1540 run_until_first_tool_call(&thread, cx).await;
1541
1542 thread.read_with(cx, |thread, _| {
1543 assert!(matches!(
1544 thread.entries[1],
1545 AgentThreadEntry::ToolCall(ToolCall {
1546 status: ToolCallStatus::Allowed {
1547 status: acp::ToolCallStatus::InProgress,
1548 ..
1549 },
1550 ..
1551 })
1552 ));
1553 });
1554
1555 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1556
1557 thread.read_with(cx, |thread, _| {
1558 assert!(matches!(
1559 &thread.entries[1],
1560 AgentThreadEntry::ToolCall(ToolCall {
1561 status: ToolCallStatus::Canceled,
1562 ..
1563 })
1564 ));
1565 });
1566
1567 thread
1568 .update(cx, |thread, cx| {
1569 thread.handle_session_update(
1570 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1571 id,
1572 fields: acp::ToolCallUpdateFields {
1573 status: Some(acp::ToolCallStatus::Completed),
1574 ..Default::default()
1575 },
1576 }),
1577 cx,
1578 )
1579 })
1580 .unwrap();
1581
1582 request.await.unwrap();
1583
1584 thread.read_with(cx, |thread, _| {
1585 assert!(matches!(
1586 thread.entries[1],
1587 AgentThreadEntry::ToolCall(ToolCall {
1588 status: ToolCallStatus::Allowed {
1589 status: acp::ToolCallStatus::Completed,
1590 ..
1591 },
1592 ..
1593 })
1594 ));
1595 });
1596 }
1597
1598 #[gpui::test]
1599 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1600 init_test(cx);
1601 let fs = FakeFs::new(cx.background_executor.clone());
1602 fs.insert_tree(path!("/test"), json!({})).await;
1603 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1604
1605 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1606 move |_, thread, mut cx| {
1607 async move {
1608 thread
1609 .update(&mut cx, |thread, cx| {
1610 thread.handle_session_update(
1611 acp::SessionUpdate::ToolCall(acp::ToolCall {
1612 id: acp::ToolCallId("test".into()),
1613 title: "Label".into(),
1614 kind: acp::ToolKind::Edit,
1615 status: acp::ToolCallStatus::Completed,
1616 content: vec![acp::ToolCallContent::Diff {
1617 diff: acp::Diff {
1618 path: "/test/test.txt".into(),
1619 old_text: None,
1620 new_text: "foo".into(),
1621 },
1622 }],
1623 locations: vec![],
1624 raw_input: None,
1625 raw_output: None,
1626 }),
1627 cx,
1628 )
1629 })
1630 .unwrap()
1631 .unwrap();
1632 Ok(acp::PromptResponse {
1633 stop_reason: acp::StopReason::EndTurn,
1634 })
1635 }
1636 .boxed_local()
1637 }
1638 }));
1639
1640 let thread = connection
1641 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1642 .await
1643 .unwrap();
1644 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1645 .await
1646 .unwrap();
1647
1648 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1649 }
1650
1651 async fn run_until_first_tool_call(
1652 thread: &Entity<AcpThread>,
1653 cx: &mut TestAppContext,
1654 ) -> usize {
1655 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1656
1657 let subscription = cx.update(|cx| {
1658 cx.subscribe(thread, move |thread, _, cx| {
1659 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1660 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1661 return tx.try_send(ix).unwrap();
1662 }
1663 }
1664 })
1665 });
1666
1667 select! {
1668 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1669 panic!("Timeout waiting for tool call")
1670 }
1671 ix = rx.next().fuse() => {
1672 drop(subscription);
1673 ix.unwrap()
1674 }
1675 }
1676 }
1677
1678 #[derive(Clone, Default)]
1679 struct FakeAgentConnection {
1680 auth_methods: Vec<acp::AuthMethod>,
1681 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1682 on_user_message: Option<
1683 Rc<
1684 dyn Fn(
1685 acp::PromptRequest,
1686 WeakEntity<AcpThread>,
1687 AsyncApp,
1688 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1689 + 'static,
1690 >,
1691 >,
1692 }
1693
1694 impl FakeAgentConnection {
1695 fn new() -> Self {
1696 Self {
1697 auth_methods: Vec::new(),
1698 on_user_message: None,
1699 sessions: Arc::default(),
1700 }
1701 }
1702
1703 #[expect(unused)]
1704 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1705 self.auth_methods = auth_methods;
1706 self
1707 }
1708
1709 fn on_user_message(
1710 mut self,
1711 handler: impl Fn(
1712 acp::PromptRequest,
1713 WeakEntity<AcpThread>,
1714 AsyncApp,
1715 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1716 + 'static,
1717 ) -> Self {
1718 self.on_user_message.replace(Rc::new(handler));
1719 self
1720 }
1721 }
1722
1723 impl AgentConnection for FakeAgentConnection {
1724 fn auth_methods(&self) -> &[acp::AuthMethod] {
1725 &self.auth_methods
1726 }
1727
1728 fn new_thread(
1729 self: Rc<Self>,
1730 project: Entity<Project>,
1731 _cwd: &Path,
1732 cx: &mut gpui::AsyncApp,
1733 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1734 let session_id = acp::SessionId(
1735 rand::thread_rng()
1736 .sample_iter(&rand::distributions::Alphanumeric)
1737 .take(7)
1738 .map(char::from)
1739 .collect::<String>()
1740 .into(),
1741 );
1742 let thread = cx
1743 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1744 .unwrap();
1745 self.sessions.lock().insert(session_id, thread.downgrade());
1746 Task::ready(Ok(thread))
1747 }
1748
1749 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1750 if self.auth_methods().iter().any(|m| m.id == method) {
1751 Task::ready(Ok(()))
1752 } else {
1753 Task::ready(Err(anyhow!("Invalid Auth Method")))
1754 }
1755 }
1756
1757 fn prompt(
1758 &self,
1759 params: acp::PromptRequest,
1760 cx: &mut App,
1761 ) -> Task<gpui::Result<acp::PromptResponse>> {
1762 let sessions = self.sessions.lock();
1763 let thread = sessions.get(¶ms.session_id).unwrap();
1764 if let Some(handler) = &self.on_user_message {
1765 let handler = handler.clone();
1766 let thread = thread.clone();
1767 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1768 } else {
1769 Task::ready(Ok(acp::PromptResponse {
1770 stop_reason: acp::StopReason::EndTurn,
1771 }))
1772 }
1773 }
1774
1775 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1776 let sessions = self.sessions.lock();
1777 let thread = sessions.get(&session_id).unwrap().clone();
1778
1779 cx.spawn(async move |cx| {
1780 thread
1781 .update(cx, |thread, cx| thread.cancel(cx))
1782 .unwrap()
1783 .await
1784 })
1785 .detach();
1786 }
1787 }
1788}