1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol::{self as acp};
13use anyhow::{Context as _, Result};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::Formatter;
24use std::process::ExitStatus;
25use std::rc::Rc;
26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
27use ui::App;
28use util::ResultExt;
29
30#[derive(Debug)]
31pub struct UserMessage {
32 pub content: ContentBlock,
33}
34
35impl UserMessage {
36 pub fn from_acp(
37 message: impl IntoIterator<Item = acp::ContentBlock>,
38 language_registry: Arc<LanguageRegistry>,
39 cx: &mut App,
40 ) -> Self {
41 let mut content = ContentBlock::Empty;
42 for chunk in message {
43 content.append(chunk, &language_registry, cx)
44 }
45 Self { content: content }
46 }
47
48 fn to_markdown(&self, cx: &App) -> String {
49 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
50 }
51}
52
53#[derive(Debug, PartialEq)]
54pub struct AssistantMessage {
55 pub chunks: Vec<AssistantMessageChunk>,
56}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
70#[derive(Debug, PartialEq)]
71pub enum AssistantMessageChunk {
72 Message { block: ContentBlock },
73 Thought { block: ContentBlock },
74}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
94pub enum AgentThreadEntry {
95 UserMessage(UserMessage),
96 AssistantMessage(AssistantMessage),
97 ToolCall(ToolCall),
98}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
126 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
127 Some(locations)
128 } else {
129 None
130 }
131 }
132}
133
134#[derive(Debug)]
135pub struct ToolCall {
136 pub id: acp::ToolCallId,
137 pub label: Entity<Markdown>,
138 pub kind: acp::ToolKind,
139 pub content: Vec<ToolCallContent>,
140 pub status: ToolCallStatus,
141 pub locations: Vec<acp::ToolCallLocation>,
142 pub raw_input: Option<serde_json::Value>,
143 pub raw_output: Option<serde_json::Value>,
144}
145
146impl ToolCall {
147 fn from_acp(
148 tool_call: acp::ToolCall,
149 status: ToolCallStatus,
150 language_registry: Arc<LanguageRegistry>,
151 cx: &mut App,
152 ) -> Self {
153 Self {
154 id: tool_call.id,
155 label: cx.new(|cx| {
156 Markdown::new(
157 tool_call.title.into(),
158 Some(language_registry.clone()),
159 None,
160 cx,
161 )
162 }),
163 kind: tool_call.kind,
164 content: tool_call
165 .content
166 .into_iter()
167 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
168 .collect(),
169 locations: tool_call.locations,
170 status,
171 raw_input: tool_call.raw_input,
172 raw_output: tool_call.raw_output,
173 }
174 }
175
176 fn update_fields(
177 &mut self,
178 fields: acp::ToolCallUpdateFields,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) {
182 let acp::ToolCallUpdateFields {
183 kind,
184 status,
185 title,
186 content,
187 locations,
188 raw_input,
189 raw_output,
190 } = fields;
191
192 if let Some(kind) = kind {
193 self.kind = kind;
194 }
195
196 if let Some(status) = status {
197 self.status = ToolCallStatus::Allowed { status };
198 }
199
200 if let Some(title) = title {
201 self.label.update(cx, |label, cx| {
202 label.replace(title, cx);
203 });
204 }
205
206 if let Some(content) = content {
207 self.content = content
208 .into_iter()
209 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
210 .collect();
211 }
212
213 if let Some(locations) = locations {
214 self.locations = locations;
215 }
216
217 if let Some(raw_input) = raw_input {
218 self.raw_input = Some(raw_input);
219 }
220
221 if let Some(raw_output) = raw_output {
222 if self.content.is_empty() {
223 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
224 {
225 self.content
226 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
227 markdown,
228 }));
229 }
230 }
231 self.raw_output = Some(raw_output);
232 }
233 }
234
235 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
236 self.content.iter().filter_map(|content| match content {
237 ToolCallContent::Diff(diff) => Some(diff),
238 ToolCallContent::ContentBlock(_) => None,
239 ToolCallContent::Terminal(_) => None,
240 })
241 }
242
243 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::Terminal(terminal) => Some(terminal),
246 ToolCallContent::ContentBlock(_) => None,
247 ToolCallContent::Diff(_) => None,
248 })
249 }
250
251 fn to_markdown(&self, cx: &App) -> String {
252 let mut markdown = format!(
253 "**Tool Call: {}**\nStatus: {}\n\n",
254 self.label.read(cx).source(),
255 self.status
256 );
257 for content in &self.content {
258 markdown.push_str(content.to_markdown(cx).as_str());
259 markdown.push_str("\n\n");
260 }
261 markdown
262 }
263}
264
265#[derive(Debug)]
266pub enum ToolCallStatus {
267 WaitingForConfirmation {
268 options: Vec<acp::PermissionOption>,
269 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
270 },
271 Allowed {
272 status: acp::ToolCallStatus,
273 },
274 Rejected,
275 Canceled,
276}
277
278impl Display for ToolCallStatus {
279 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280 write!(
281 f,
282 "{}",
283 match self {
284 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
285 ToolCallStatus::Allowed { status } => match status {
286 acp::ToolCallStatus::Pending => "Pending",
287 acp::ToolCallStatus::InProgress => "In Progress",
288 acp::ToolCallStatus::Completed => "Completed",
289 acp::ToolCallStatus::Failed => "Failed",
290 },
291 ToolCallStatus::Rejected => "Rejected",
292 ToolCallStatus::Canceled => "Canceled",
293 }
294 )
295 }
296}
297
298#[derive(Debug, PartialEq, Clone)]
299pub enum ContentBlock {
300 Empty,
301 Markdown { markdown: Entity<Markdown> },
302 ResourceLink { resource_link: acp::ResourceLink },
303}
304
305impl ContentBlock {
306 pub fn new(
307 block: acp::ContentBlock,
308 language_registry: &Arc<LanguageRegistry>,
309 cx: &mut App,
310 ) -> Self {
311 let mut this = Self::Empty;
312 this.append(block, language_registry, cx);
313 this
314 }
315
316 pub fn new_combined(
317 blocks: impl IntoIterator<Item = acp::ContentBlock>,
318 language_registry: Arc<LanguageRegistry>,
319 cx: &mut App,
320 ) -> Self {
321 let mut this = Self::Empty;
322 for block in blocks {
323 this.append(block, &language_registry, cx);
324 }
325 this
326 }
327
328 pub fn append(
329 &mut self,
330 block: acp::ContentBlock,
331 language_registry: &Arc<LanguageRegistry>,
332 cx: &mut App,
333 ) {
334 if matches!(self, ContentBlock::Empty) {
335 if let acp::ContentBlock::ResourceLink(resource_link) = block {
336 *self = ContentBlock::ResourceLink { resource_link };
337 return;
338 }
339 }
340
341 let new_content = self.extract_content_from_block(block);
342
343 match self {
344 ContentBlock::Empty => {
345 *self = Self::create_markdown_block(new_content, language_registry, cx);
346 }
347 ContentBlock::Markdown { markdown } => {
348 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
349 }
350 ContentBlock::ResourceLink { resource_link } => {
351 let existing_content = Self::resource_link_to_content(&resource_link.uri);
352 let combined = format!("{}\n{}", existing_content, new_content);
353
354 *self = Self::create_markdown_block(combined, language_registry, cx);
355 }
356 }
357 }
358
359 fn resource_link_to_content(uri: &str) -> String {
360 if let Some(uri) = MentionUri::parse(&uri).log_err() {
361 uri.to_link()
362 } else {
363 uri.to_string().clone()
364 }
365 }
366
367 fn create_markdown_block(
368 content: String,
369 language_registry: &Arc<LanguageRegistry>,
370 cx: &mut App,
371 ) -> ContentBlock {
372 ContentBlock::Markdown {
373 markdown: cx
374 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
375 }
376 }
377
378 fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
379 match block {
380 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
381 acp::ContentBlock::ResourceLink(resource_link) => {
382 Self::resource_link_to_content(&resource_link.uri)
383 }
384 acp::ContentBlock::Resource(acp::EmbeddedResource {
385 resource:
386 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
387 uri,
388 ..
389 }),
390 ..
391 }) => Self::resource_link_to_content(&uri),
392 acp::ContentBlock::Image(_)
393 | acp::ContentBlock::Audio(_)
394 | acp::ContentBlock::Resource(_) => String::new(),
395 }
396 }
397
398 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
399 match self {
400 ContentBlock::Empty => "",
401 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
402 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
403 }
404 }
405
406 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
407 match self {
408 ContentBlock::Empty => None,
409 ContentBlock::Markdown { markdown } => Some(markdown),
410 ContentBlock::ResourceLink { .. } => None,
411 }
412 }
413
414 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
415 match self {
416 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
417 _ => None,
418 }
419 }
420}
421
422#[derive(Debug)]
423pub enum ToolCallContent {
424 ContentBlock(ContentBlock),
425 Diff(Entity<Diff>),
426 Terminal(Entity<Terminal>),
427}
428
429impl ToolCallContent {
430 pub fn from_acp(
431 content: acp::ToolCallContent,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 match content {
436 acp::ToolCallContent::Content { content } => {
437 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
438 }
439 acp::ToolCallContent::Diff { diff } => {
440 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
441 }
442 }
443 }
444
445 pub fn to_markdown(&self, cx: &App) -> String {
446 match self {
447 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
448 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
449 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
450 }
451 }
452}
453
454#[derive(Debug, PartialEq)]
455pub enum ToolCallUpdate {
456 UpdateFields(acp::ToolCallUpdate),
457 UpdateDiff(ToolCallUpdateDiff),
458 UpdateTerminal(ToolCallUpdateTerminal),
459}
460
461impl ToolCallUpdate {
462 fn id(&self) -> &acp::ToolCallId {
463 match self {
464 Self::UpdateFields(update) => &update.id,
465 Self::UpdateDiff(diff) => &diff.id,
466 Self::UpdateTerminal(terminal) => &terminal.id,
467 }
468 }
469}
470
471impl From<acp::ToolCallUpdate> for ToolCallUpdate {
472 fn from(update: acp::ToolCallUpdate) -> Self {
473 Self::UpdateFields(update)
474 }
475}
476
477impl From<ToolCallUpdateDiff> for ToolCallUpdate {
478 fn from(diff: ToolCallUpdateDiff) -> Self {
479 Self::UpdateDiff(diff)
480 }
481}
482
483#[derive(Debug, PartialEq)]
484pub struct ToolCallUpdateDiff {
485 pub id: acp::ToolCallId,
486 pub diff: Entity<Diff>,
487}
488
489impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
490 fn from(terminal: ToolCallUpdateTerminal) -> Self {
491 Self::UpdateTerminal(terminal)
492 }
493}
494
495#[derive(Debug, PartialEq)]
496pub struct ToolCallUpdateTerminal {
497 pub id: acp::ToolCallId,
498 pub terminal: Entity<Terminal>,
499}
500
501#[derive(Debug, Default)]
502pub struct Plan {
503 pub entries: Vec<PlanEntry>,
504}
505
506#[derive(Debug)]
507pub struct PlanStats<'a> {
508 pub in_progress_entry: Option<&'a PlanEntry>,
509 pub pending: u32,
510 pub completed: u32,
511}
512
513impl Plan {
514 pub fn is_empty(&self) -> bool {
515 self.entries.is_empty()
516 }
517
518 pub fn stats(&self) -> PlanStats<'_> {
519 let mut stats = PlanStats {
520 in_progress_entry: None,
521 pending: 0,
522 completed: 0,
523 };
524
525 for entry in &self.entries {
526 match &entry.status {
527 acp::PlanEntryStatus::Pending => {
528 stats.pending += 1;
529 }
530 acp::PlanEntryStatus::InProgress => {
531 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
532 }
533 acp::PlanEntryStatus::Completed => {
534 stats.completed += 1;
535 }
536 }
537 }
538
539 stats
540 }
541}
542
543#[derive(Debug)]
544pub struct PlanEntry {
545 pub content: Entity<Markdown>,
546 pub priority: acp::PlanEntryPriority,
547 pub status: acp::PlanEntryStatus,
548}
549
550impl PlanEntry {
551 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
552 Self {
553 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
554 priority: entry.priority,
555 status: entry.status,
556 }
557 }
558}
559
560pub struct AcpThread {
561 title: SharedString,
562 entries: Vec<AgentThreadEntry>,
563 plan: Plan,
564 project: Entity<Project>,
565 action_log: Entity<ActionLog>,
566 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
567 send_task: Option<Task<()>>,
568 connection: Rc<dyn AgentConnection>,
569 session_id: acp::SessionId,
570}
571
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user to grant or deny permission.
    ToolAuthorizationRequired,
    /// The current prompt ended without a connection error (this includes cancellation).
    Stopped,
    /// The current prompt failed.
    Error,
    /// The agent server exited with the given status.
    ServerExited(ExitStatus),
}
580
581impl EventEmitter<AcpThreadEvent> for AcpThread {}
582
/// Coarse state of an [`AcpThread`], derived from whether a prompt is in flight.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is running.
    Idle,
    /// A prompt is running but blocked on a tool call awaiting user confirmation.
    WaitingForToolConfirmation,
    /// A prompt is running.
    Generating,
}
589
590#[derive(Debug, Clone)]
591pub enum LoadError {
592 Unsupported {
593 error_message: SharedString,
594 upgrade_message: SharedString,
595 upgrade_command: String,
596 },
597 Exited(i32),
598 Other(SharedString),
599}
600
601impl Display for LoadError {
602 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
603 match self {
604 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
605 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
606 LoadError::Other(msg) => write!(f, "{}", msg),
607 }
608 }
609}
610
611impl Error for LoadError {}
612
613impl AcpThread {
614 pub fn new(
615 title: impl Into<SharedString>,
616 connection: Rc<dyn AgentConnection>,
617 project: Entity<Project>,
618 session_id: acp::SessionId,
619 cx: &mut Context<Self>,
620 ) -> Self {
621 let action_log = cx.new(|_| ActionLog::new(project.clone()));
622
623 Self {
624 action_log,
625 shared_buffers: Default::default(),
626 entries: Default::default(),
627 plan: Default::default(),
628 title: title.into(),
629 project,
630 send_task: None,
631 connection,
632 session_id,
633 }
634 }
635
636 pub fn action_log(&self) -> &Entity<ActionLog> {
637 &self.action_log
638 }
639
640 pub fn project(&self) -> &Entity<Project> {
641 &self.project
642 }
643
644 pub fn title(&self) -> SharedString {
645 self.title.clone()
646 }
647
648 pub fn entries(&self) -> &[AgentThreadEntry] {
649 &self.entries
650 }
651
652 pub fn session_id(&self) -> &acp::SessionId {
653 &self.session_id
654 }
655
656 pub fn status(&self) -> ThreadStatus {
657 if self.send_task.is_some() {
658 if self.waiting_for_tool_confirmation() {
659 ThreadStatus::WaitingForToolConfirmation
660 } else {
661 ThreadStatus::Generating
662 }
663 } else {
664 ThreadStatus::Idle
665 }
666 }
667
668 pub fn has_pending_edit_tool_calls(&self) -> bool {
669 for entry in self.entries.iter().rev() {
670 match entry {
671 AgentThreadEntry::UserMessage(_) => return false,
672 AgentThreadEntry::ToolCall(
673 call @ ToolCall {
674 status:
675 ToolCallStatus::Allowed {
676 status:
677 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
678 },
679 ..
680 },
681 ) if call.diffs().next().is_some() => {
682 return true;
683 }
684 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
685 }
686 }
687
688 false
689 }
690
691 pub fn used_tools_since_last_user_message(&self) -> bool {
692 for entry in self.entries.iter().rev() {
693 match entry {
694 AgentThreadEntry::UserMessage(..) => return false,
695 AgentThreadEntry::AssistantMessage(..) => continue,
696 AgentThreadEntry::ToolCall(..) => return true,
697 }
698 }
699
700 false
701 }
702
703 pub fn handle_session_update(
704 &mut self,
705 update: acp::SessionUpdate,
706 cx: &mut Context<Self>,
707 ) -> Result<()> {
708 match update {
709 acp::SessionUpdate::UserMessageChunk { content } => {
710 self.push_user_content_block(content, cx);
711 }
712 acp::SessionUpdate::AgentMessageChunk { content } => {
713 self.push_assistant_content_block(content, false, cx);
714 }
715 acp::SessionUpdate::AgentThoughtChunk { content } => {
716 self.push_assistant_content_block(content, true, cx);
717 }
718 acp::SessionUpdate::ToolCall(tool_call) => {
719 self.upsert_tool_call(tool_call, cx);
720 }
721 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
722 self.update_tool_call(tool_call_update, cx)?;
723 }
724 acp::SessionUpdate::Plan(plan) => {
725 self.update_plan(plan, cx);
726 }
727 }
728 Ok(())
729 }
730
731 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
732 let language_registry = self.project.read(cx).languages().clone();
733 let entries_len = self.entries.len();
734
735 if let Some(last_entry) = self.entries.last_mut()
736 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
737 {
738 content.append(chunk, &language_registry, cx);
739 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
740 } else {
741 let content = ContentBlock::new(chunk, &language_registry, cx);
742 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
743 }
744 }
745
746 pub fn push_assistant_content_block(
747 &mut self,
748 chunk: acp::ContentBlock,
749 is_thought: bool,
750 cx: &mut Context<Self>,
751 ) {
752 let language_registry = self.project.read(cx).languages().clone();
753 let entries_len = self.entries.len();
754 if let Some(last_entry) = self.entries.last_mut()
755 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
756 {
757 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
758 match (chunks.last_mut(), is_thought) {
759 (Some(AssistantMessageChunk::Message { block }), false)
760 | (Some(AssistantMessageChunk::Thought { block }), true) => {
761 block.append(chunk, &language_registry, cx)
762 }
763 _ => {
764 let block = ContentBlock::new(chunk, &language_registry, cx);
765 if is_thought {
766 chunks.push(AssistantMessageChunk::Thought { block })
767 } else {
768 chunks.push(AssistantMessageChunk::Message { block })
769 }
770 }
771 }
772 } else {
773 let block = ContentBlock::new(chunk, &language_registry, cx);
774 let chunk = if is_thought {
775 AssistantMessageChunk::Thought { block }
776 } else {
777 AssistantMessageChunk::Message { block }
778 };
779
780 self.push_entry(
781 AgentThreadEntry::AssistantMessage(AssistantMessage {
782 chunks: vec![chunk],
783 }),
784 cx,
785 );
786 }
787 }
788
789 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
790 self.entries.push(entry);
791 cx.emit(AcpThreadEvent::NewEntry);
792 }
793
794 pub fn update_tool_call(
795 &mut self,
796 update: impl Into<ToolCallUpdate>,
797 cx: &mut Context<Self>,
798 ) -> Result<()> {
799 let update = update.into();
800 let languages = self.project.read(cx).languages().clone();
801
802 let (ix, current_call) = self
803 .tool_call_mut(update.id())
804 .context("Tool call not found")?;
805 match update {
806 ToolCallUpdate::UpdateFields(update) => {
807 current_call.update_fields(update.fields, languages, cx);
808 }
809 ToolCallUpdate::UpdateDiff(update) => {
810 current_call.content.clear();
811 current_call
812 .content
813 .push(ToolCallContent::Diff(update.diff));
814 }
815 ToolCallUpdate::UpdateTerminal(update) => {
816 current_call.content.clear();
817 current_call
818 .content
819 .push(ToolCallContent::Terminal(update.terminal));
820 }
821 }
822
823 cx.emit(AcpThreadEvent::EntryUpdated(ix));
824
825 Ok(())
826 }
827
828 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
829 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
830 let status = ToolCallStatus::Allowed {
831 status: tool_call.status,
832 };
833 self.upsert_tool_call_inner(tool_call, status, cx)
834 }
835
836 pub fn upsert_tool_call_inner(
837 &mut self,
838 tool_call: acp::ToolCall,
839 status: ToolCallStatus,
840 cx: &mut Context<Self>,
841 ) {
842 let language_registry = self.project.read(cx).languages().clone();
843 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
844
845 let location = call.locations.last().cloned();
846
847 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
848 *current_call = call;
849
850 cx.emit(AcpThreadEvent::EntryUpdated(ix));
851 } else {
852 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
853 }
854
855 if let Some(location) = location {
856 self.set_project_location(location, cx)
857 }
858 }
859
860 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
861 // The tool call we are looking for is typically the last one, or very close to the end.
862 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
863 self.entries
864 .iter_mut()
865 .enumerate()
866 .rev()
867 .find_map(|(index, tool_call)| {
868 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
869 && &tool_call.id == id
870 {
871 Some((index, tool_call))
872 } else {
873 None
874 }
875 })
876 }
877
878 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
879 self.project.update(cx, |project, cx| {
880 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
881 return;
882 };
883 let buffer = project.open_buffer(path, cx);
884 cx.spawn(async move |project, cx| {
885 let buffer = buffer.await?;
886
887 project.update(cx, |project, cx| {
888 let position = if let Some(line) = location.line {
889 let snapshot = buffer.read(cx).snapshot();
890 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
891 snapshot.anchor_before(point)
892 } else {
893 Anchor::MIN
894 };
895
896 project.set_agent_location(
897 Some(AgentLocation {
898 buffer: buffer.downgrade(),
899 position,
900 }),
901 cx,
902 );
903 })
904 })
905 .detach_and_log_err(cx);
906 });
907 }
908
909 pub fn request_tool_call_authorization(
910 &mut self,
911 tool_call: acp::ToolCall,
912 options: Vec<acp::PermissionOption>,
913 cx: &mut Context<Self>,
914 ) -> oneshot::Receiver<acp::PermissionOptionId> {
915 let (tx, rx) = oneshot::channel();
916
917 let status = ToolCallStatus::WaitingForConfirmation {
918 options,
919 respond_tx: tx,
920 };
921
922 self.upsert_tool_call_inner(tool_call, status, cx);
923 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
924 rx
925 }
926
927 pub fn authorize_tool_call(
928 &mut self,
929 id: acp::ToolCallId,
930 option_id: acp::PermissionOptionId,
931 option_kind: acp::PermissionOptionKind,
932 cx: &mut Context<Self>,
933 ) {
934 let Some((ix, call)) = self.tool_call_mut(&id) else {
935 return;
936 };
937
938 let new_status = match option_kind {
939 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
940 ToolCallStatus::Rejected
941 }
942 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
943 ToolCallStatus::Allowed {
944 status: acp::ToolCallStatus::InProgress,
945 }
946 }
947 };
948
949 let curr_status = mem::replace(&mut call.status, new_status);
950
951 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
952 respond_tx.send(option_id).log_err();
953 } else if cfg!(debug_assertions) {
954 panic!("tried to authorize an already authorized tool call");
955 }
956
957 cx.emit(AcpThreadEvent::EntryUpdated(ix));
958 }
959
960 /// Returns true if the last turn is awaiting tool authorization
961 pub fn waiting_for_tool_confirmation(&self) -> bool {
962 for entry in self.entries.iter().rev() {
963 match &entry {
964 AgentThreadEntry::ToolCall(call) => match call.status {
965 ToolCallStatus::WaitingForConfirmation { .. } => return true,
966 ToolCallStatus::Allowed { .. }
967 | ToolCallStatus::Rejected
968 | ToolCallStatus::Canceled => continue,
969 },
970 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
971 // Reached the beginning of the turn
972 return false;
973 }
974 }
975 }
976 false
977 }
978
979 pub fn plan(&self) -> &Plan {
980 &self.plan
981 }
982
983 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
984 let new_entries_len = request.entries.len();
985 let mut new_entries = request.entries.into_iter();
986
987 // Reuse existing markdown to prevent flickering
988 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
989 let PlanEntry {
990 content,
991 priority,
992 status,
993 } = old;
994 content.update(cx, |old, cx| {
995 old.replace(new.content, cx);
996 });
997 *priority = new.priority;
998 *status = new.status;
999 }
1000 for new in new_entries {
1001 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1002 }
1003 self.plan.entries.truncate(new_entries_len);
1004
1005 cx.notify();
1006 }
1007
1008 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1009 self.plan
1010 .entries
1011 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1012 cx.notify();
1013 }
1014
1015 #[cfg(any(test, feature = "test-support"))]
1016 pub fn send_raw(
1017 &mut self,
1018 message: &str,
1019 cx: &mut Context<Self>,
1020 ) -> BoxFuture<'static, Result<()>> {
1021 self.send(
1022 vec![acp::ContentBlock::Text(acp::TextContent {
1023 text: message.to_string(),
1024 annotations: None,
1025 })],
1026 cx,
1027 )
1028 }
1029
1030 pub fn send(
1031 &mut self,
1032 message: Vec<acp::ContentBlock>,
1033 cx: &mut Context<Self>,
1034 ) -> BoxFuture<'static, Result<()>> {
1035 let block = ContentBlock::new_combined(
1036 message.clone(),
1037 self.project.read(cx).languages().clone(),
1038 cx,
1039 );
1040 self.push_entry(
1041 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1042 cx,
1043 );
1044 self.clear_completed_plan_entries(cx);
1045
1046 let (tx, rx) = oneshot::channel();
1047 let cancel_task = self.cancel(cx);
1048
1049 self.send_task = Some(cx.spawn(async move |this, cx| {
1050 async {
1051 cancel_task.await;
1052
1053 let result = this
1054 .update(cx, |this, cx| {
1055 this.connection.prompt(
1056 acp::PromptRequest {
1057 prompt: message,
1058 session_id: this.session_id.clone(),
1059 },
1060 cx,
1061 )
1062 })?
1063 .await;
1064
1065 tx.send(result).log_err();
1066
1067 anyhow::Ok(())
1068 }
1069 .await
1070 .log_err();
1071 }));
1072
1073 cx.spawn(async move |this, cx| match rx.await {
1074 Ok(Err(e)) => {
1075 this.update(cx, |this, cx| {
1076 this.send_task.take();
1077 cx.emit(AcpThreadEvent::Error)
1078 })
1079 .log_err();
1080 Err(e)?
1081 }
1082 result => {
1083 let cancelled = matches!(
1084 result,
1085 Ok(Ok(acp::PromptResponse {
1086 stop_reason: acp::StopReason::Cancelled
1087 }))
1088 );
1089
1090 // We only take the task if the current prompt wasn't cancelled.
1091 //
1092 // This prompt may have been cancelled because another one was sent
1093 // while it was still generating. In these cases, dropping `send_task`
1094 // would cause the next generation to be cancelled.
1095 if !cancelled {
1096 this.update(cx, |this, _cx| this.send_task.take()).ok();
1097 }
1098
1099 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1100 .log_err();
1101 Ok(())
1102 }
1103 })
1104 .boxed()
1105 }
1106
1107 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1108 let Some(send_task) = self.send_task.take() else {
1109 return Task::ready(());
1110 };
1111
1112 for entry in self.entries.iter_mut() {
1113 if let AgentThreadEntry::ToolCall(call) = entry {
1114 let cancel = matches!(
1115 call.status,
1116 ToolCallStatus::WaitingForConfirmation { .. }
1117 | ToolCallStatus::Allowed {
1118 status: acp::ToolCallStatus::InProgress
1119 }
1120 );
1121
1122 if cancel {
1123 call.status = ToolCallStatus::Canceled;
1124 }
1125 }
1126 }
1127
1128 self.connection.cancel(&self.session_id, cx);
1129
1130 // Wait for the send task to complete
1131 cx.foreground_executor().spawn(send_task)
1132 }
1133
1134 pub fn read_text_file(
1135 &self,
1136 path: PathBuf,
1137 line: Option<u32>,
1138 limit: Option<u32>,
1139 reuse_shared_snapshot: bool,
1140 cx: &mut Context<Self>,
1141 ) -> Task<Result<String>> {
1142 let project = self.project.clone();
1143 let action_log = self.action_log.clone();
1144 cx.spawn(async move |this, cx| {
1145 let load = project.update(cx, |project, cx| {
1146 let path = project
1147 .project_path_for_absolute_path(&path, cx)
1148 .context("invalid path")?;
1149 anyhow::Ok(project.open_buffer(path, cx))
1150 });
1151 let buffer = load??.await?;
1152
1153 let snapshot = if reuse_shared_snapshot {
1154 this.read_with(cx, |this, _| {
1155 this.shared_buffers.get(&buffer.clone()).cloned()
1156 })
1157 .log_err()
1158 .flatten()
1159 } else {
1160 None
1161 };
1162
1163 let snapshot = if let Some(snapshot) = snapshot {
1164 snapshot
1165 } else {
1166 action_log.update(cx, |action_log, cx| {
1167 action_log.buffer_read(buffer.clone(), cx);
1168 })?;
1169 project.update(cx, |project, cx| {
1170 let position = buffer
1171 .read(cx)
1172 .snapshot()
1173 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1174 project.set_agent_location(
1175 Some(AgentLocation {
1176 buffer: buffer.downgrade(),
1177 position,
1178 }),
1179 cx,
1180 );
1181 })?;
1182
1183 buffer.update(cx, |buffer, _| buffer.snapshot())?
1184 };
1185
1186 this.update(cx, |this, _| {
1187 let text = snapshot.text();
1188 this.shared_buffers.insert(buffer.clone(), snapshot);
1189 if line.is_none() && limit.is_none() {
1190 return Ok(text);
1191 }
1192 let limit = limit.unwrap_or(u32::MAX) as usize;
1193 let Some(line) = line else {
1194 return Ok(text.lines().take(limit).collect::<String>());
1195 };
1196
1197 let count = text.lines().count();
1198 if count < line as usize {
1199 anyhow::bail!("There are only {} lines", count);
1200 }
1201 Ok(text
1202 .lines()
1203 .skip(line as usize + 1)
1204 .take(limit)
1205 .collect::<String>())
1206 })?
1207 })
1208 }
1209
1210 pub fn write_text_file(
1211 &self,
1212 path: PathBuf,
1213 content: String,
1214 cx: &mut Context<Self>,
1215 ) -> Task<Result<()>> {
1216 let project = self.project.clone();
1217 let action_log = self.action_log.clone();
1218 cx.spawn(async move |this, cx| {
1219 let load = project.update(cx, |project, cx| {
1220 let path = project
1221 .project_path_for_absolute_path(&path, cx)
1222 .context("invalid path")?;
1223 anyhow::Ok(project.open_buffer(path, cx))
1224 });
1225 let buffer = load??.await?;
1226 let snapshot = this.update(cx, |this, cx| {
1227 this.shared_buffers
1228 .get(&buffer)
1229 .cloned()
1230 .unwrap_or_else(|| buffer.read(cx).snapshot())
1231 })?;
1232 let edits = cx
1233 .background_executor()
1234 .spawn(async move {
1235 let old_text = snapshot.text();
1236 text_diff(old_text.as_str(), &content)
1237 .into_iter()
1238 .map(|(range, replacement)| {
1239 (
1240 snapshot.anchor_after(range.start)
1241 ..snapshot.anchor_before(range.end),
1242 replacement,
1243 )
1244 })
1245 .collect::<Vec<_>>()
1246 })
1247 .await;
1248 cx.update(|cx| {
1249 project.update(cx, |project, cx| {
1250 project.set_agent_location(
1251 Some(AgentLocation {
1252 buffer: buffer.downgrade(),
1253 position: edits
1254 .last()
1255 .map(|(range, _)| range.end)
1256 .unwrap_or(Anchor::MIN),
1257 }),
1258 cx,
1259 );
1260 });
1261
1262 action_log.update(cx, |action_log, cx| {
1263 action_log.buffer_read(buffer.clone(), cx);
1264 });
1265 buffer.update(cx, |buffer, cx| {
1266 buffer.edit(edits, None, cx);
1267 });
1268 action_log.update(cx, |action_log, cx| {
1269 action_log.buffer_edited(buffer.clone(), cx);
1270 });
1271 })?;
1272 project
1273 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1274 .await
1275 })
1276 }
1277
1278 pub fn to_markdown(&self, cx: &App) -> String {
1279 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1280 }
1281
1282 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1283 cx.emit(AcpThreadEvent::ServerExited(status));
1284 }
1285}
1286
1287fn markdown_for_raw_output(
1288 raw_output: &serde_json::Value,
1289 language_registry: &Arc<LanguageRegistry>,
1290 cx: &mut App,
1291) -> Option<Entity<Markdown>> {
1292 match raw_output {
1293 serde_json::Value::Null => None,
1294 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1295 Markdown::new(
1296 value.to_string().into(),
1297 Some(language_registry.clone()),
1298 None,
1299 cx,
1300 )
1301 })),
1302 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1303 Markdown::new(
1304 value.to_string().into(),
1305 Some(language_registry.clone()),
1306 None,
1307 cx,
1308 )
1309 })),
1310 serde_json::Value::String(value) => Some(cx.new(|cx| {
1311 Markdown::new(
1312 value.clone().into(),
1313 Some(language_registry.clone()),
1314 None,
1315 cx,
1316 )
1317 })),
1318 value => Some(cx.new(|cx| {
1319 Markdown::new(
1320 format!("```json\n{}\n```", value).into(),
1321 Some(language_registry.clone()),
1322 None,
1323 cx,
1324 )
1325 })),
1326 }
1327}
1328
1329#[cfg(test)]
1330mod tests {
1331 use super::*;
1332 use anyhow::anyhow;
1333 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1334 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1335 use indoc::indoc;
1336 use project::FakeFs;
1337 use rand::Rng as _;
1338 use serde_json::json;
1339 use settings::SettingsStore;
1340 use smol::stream::StreamExt as _;
1341 use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1342
1343 use util::path;
1344
1345 fn init_test(cx: &mut TestAppContext) {
1346 env_logger::try_init().ok();
1347 cx.update(|cx| {
1348 let settings_store = SettingsStore::test(cx);
1349 cx.set_global(settings_store);
1350 Project::init_settings(cx);
1351 language::init(cx);
1352 });
1353 }
1354
1355 #[gpui::test]
1356 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1357 init_test(cx);
1358
1359 let fs = FakeFs::new(cx.executor());
1360 let project = Project::test(fs, [], cx).await;
1361 let connection = Rc::new(FakeAgentConnection::new());
1362 let thread = cx
1363 .spawn(async move |mut cx| {
1364 connection
1365 .new_thread(project, Path::new(path!("/test")), &mut cx)
1366 .await
1367 })
1368 .await
1369 .unwrap();
1370
1371 // Test creating a new user message
1372 thread.update(cx, |thread, cx| {
1373 thread.push_user_content_block(
1374 acp::ContentBlock::Text(acp::TextContent {
1375 annotations: None,
1376 text: "Hello, ".to_string(),
1377 }),
1378 cx,
1379 );
1380 });
1381
1382 thread.update(cx, |thread, cx| {
1383 assert_eq!(thread.entries.len(), 1);
1384 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1385 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1386 } else {
1387 panic!("Expected UserMessage");
1388 }
1389 });
1390
1391 // Test appending to existing user message
1392 thread.update(cx, |thread, cx| {
1393 thread.push_user_content_block(
1394 acp::ContentBlock::Text(acp::TextContent {
1395 annotations: None,
1396 text: "world!".to_string(),
1397 }),
1398 cx,
1399 );
1400 });
1401
1402 thread.update(cx, |thread, cx| {
1403 assert_eq!(thread.entries.len(), 1);
1404 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1405 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1406 } else {
1407 panic!("Expected UserMessage");
1408 }
1409 });
1410
1411 // Test creating new user message after assistant message
1412 thread.update(cx, |thread, cx| {
1413 thread.push_assistant_content_block(
1414 acp::ContentBlock::Text(acp::TextContent {
1415 annotations: None,
1416 text: "Assistant response".to_string(),
1417 }),
1418 false,
1419 cx,
1420 );
1421 });
1422
1423 thread.update(cx, |thread, cx| {
1424 thread.push_user_content_block(
1425 acp::ContentBlock::Text(acp::TextContent {
1426 annotations: None,
1427 text: "New user message".to_string(),
1428 }),
1429 cx,
1430 );
1431 });
1432
1433 thread.update(cx, |thread, cx| {
1434 assert_eq!(thread.entries.len(), 3);
1435 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1436 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1437 } else {
1438 panic!("Expected UserMessage at index 2");
1439 }
1440 });
1441 }
1442
1443 #[gpui::test]
1444 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1445 init_test(cx);
1446
1447 let fs = FakeFs::new(cx.executor());
1448 let project = Project::test(fs, [], cx).await;
1449 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1450 |_, thread, mut cx| {
1451 async move {
1452 thread.update(&mut cx, |thread, cx| {
1453 thread
1454 .handle_session_update(
1455 acp::SessionUpdate::AgentThoughtChunk {
1456 content: "Thinking ".into(),
1457 },
1458 cx,
1459 )
1460 .unwrap();
1461 thread
1462 .handle_session_update(
1463 acp::SessionUpdate::AgentThoughtChunk {
1464 content: "hard!".into(),
1465 },
1466 cx,
1467 )
1468 .unwrap();
1469 })?;
1470 Ok(acp::PromptResponse {
1471 stop_reason: acp::StopReason::EndTurn,
1472 })
1473 }
1474 .boxed_local()
1475 },
1476 ));
1477
1478 let thread = cx
1479 .spawn(async move |mut cx| {
1480 connection
1481 .new_thread(project, Path::new(path!("/test")), &mut cx)
1482 .await
1483 })
1484 .await
1485 .unwrap();
1486
1487 thread
1488 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1489 .await
1490 .unwrap();
1491
1492 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1493 assert_eq!(
1494 output,
1495 indoc! {r#"
1496 ## User
1497
1498 Hello from Zed!
1499
1500 ## Assistant
1501
1502 <thinking>
1503 Thinking hard!
1504 </thinking>
1505
1506 "#}
1507 );
1508 }
1509
1510 #[gpui::test]
1511 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1512 init_test(cx);
1513
1514 let fs = FakeFs::new(cx.executor());
1515 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1516 .await;
1517 let project = Project::test(fs.clone(), [], cx).await;
1518 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1519 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1520 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1521 move |_, thread, mut cx| {
1522 let read_file_tx = read_file_tx.clone();
1523 async move {
1524 let content = thread
1525 .update(&mut cx, |thread, cx| {
1526 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1527 })
1528 .unwrap()
1529 .await
1530 .unwrap();
1531 assert_eq!(content, "one\ntwo\nthree\n");
1532 read_file_tx.take().unwrap().send(()).unwrap();
1533 thread
1534 .update(&mut cx, |thread, cx| {
1535 thread.write_text_file(
1536 path!("/tmp/foo").into(),
1537 "one\ntwo\nthree\nfour\nfive\n".to_string(),
1538 cx,
1539 )
1540 })
1541 .unwrap()
1542 .await
1543 .unwrap();
1544 Ok(acp::PromptResponse {
1545 stop_reason: acp::StopReason::EndTurn,
1546 })
1547 }
1548 .boxed_local()
1549 },
1550 ));
1551
1552 let (worktree, pathbuf) = project
1553 .update(cx, |project, cx| {
1554 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1555 })
1556 .await
1557 .unwrap();
1558 let buffer = project
1559 .update(cx, |project, cx| {
1560 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1561 })
1562 .await
1563 .unwrap();
1564
1565 let thread = cx
1566 .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1567 .await
1568 .unwrap();
1569
1570 let request = thread.update(cx, |thread, cx| {
1571 thread.send_raw("Extend the count in /tmp/foo", cx)
1572 });
1573 read_file_rx.await.ok();
1574 buffer.update(cx, |buffer, cx| {
1575 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1576 });
1577 cx.run_until_parked();
1578 assert_eq!(
1579 buffer.read_with(cx, |buffer, _| buffer.text()),
1580 "zero\none\ntwo\nthree\nfour\nfive\n"
1581 );
1582 assert_eq!(
1583 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1584 "zero\none\ntwo\nthree\nfour\nfive\n"
1585 );
1586 request.await.unwrap();
1587 }
1588
1589 #[gpui::test]
1590 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1591 init_test(cx);
1592
1593 let fs = FakeFs::new(cx.executor());
1594 let project = Project::test(fs, [], cx).await;
1595 let id = acp::ToolCallId("test".into());
1596
1597 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1598 let id = id.clone();
1599 move |_, thread, mut cx| {
1600 let id = id.clone();
1601 async move {
1602 thread
1603 .update(&mut cx, |thread, cx| {
1604 thread.handle_session_update(
1605 acp::SessionUpdate::ToolCall(acp::ToolCall {
1606 id: id.clone(),
1607 title: "Label".into(),
1608 kind: acp::ToolKind::Fetch,
1609 status: acp::ToolCallStatus::InProgress,
1610 content: vec![],
1611 locations: vec![],
1612 raw_input: None,
1613 raw_output: None,
1614 }),
1615 cx,
1616 )
1617 })
1618 .unwrap()
1619 .unwrap();
1620 Ok(acp::PromptResponse {
1621 stop_reason: acp::StopReason::EndTurn,
1622 })
1623 }
1624 .boxed_local()
1625 }
1626 }));
1627
1628 let thread = cx
1629 .spawn(async move |mut cx| {
1630 connection
1631 .new_thread(project, Path::new(path!("/test")), &mut cx)
1632 .await
1633 })
1634 .await
1635 .unwrap();
1636
1637 let request = thread.update(cx, |thread, cx| {
1638 thread.send_raw("Fetch https://example.com", cx)
1639 });
1640
1641 run_until_first_tool_call(&thread, cx).await;
1642
1643 thread.read_with(cx, |thread, _| {
1644 assert!(matches!(
1645 thread.entries[1],
1646 AgentThreadEntry::ToolCall(ToolCall {
1647 status: ToolCallStatus::Allowed {
1648 status: acp::ToolCallStatus::InProgress,
1649 ..
1650 },
1651 ..
1652 })
1653 ));
1654 });
1655
1656 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1657
1658 thread.read_with(cx, |thread, _| {
1659 assert!(matches!(
1660 &thread.entries[1],
1661 AgentThreadEntry::ToolCall(ToolCall {
1662 status: ToolCallStatus::Canceled,
1663 ..
1664 })
1665 ));
1666 });
1667
1668 thread
1669 .update(cx, |thread, cx| {
1670 thread.handle_session_update(
1671 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1672 id,
1673 fields: acp::ToolCallUpdateFields {
1674 status: Some(acp::ToolCallStatus::Completed),
1675 ..Default::default()
1676 },
1677 }),
1678 cx,
1679 )
1680 })
1681 .unwrap();
1682
1683 request.await.unwrap();
1684
1685 thread.read_with(cx, |thread, _| {
1686 assert!(matches!(
1687 thread.entries[1],
1688 AgentThreadEntry::ToolCall(ToolCall {
1689 status: ToolCallStatus::Allowed {
1690 status: acp::ToolCallStatus::Completed,
1691 ..
1692 },
1693 ..
1694 })
1695 ));
1696 });
1697 }
1698
1699 #[gpui::test]
1700 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1701 init_test(cx);
1702 let fs = FakeFs::new(cx.background_executor.clone());
1703 fs.insert_tree(path!("/test"), json!({})).await;
1704 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1705
1706 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1707 move |_, thread, mut cx| {
1708 async move {
1709 thread
1710 .update(&mut cx, |thread, cx| {
1711 thread.handle_session_update(
1712 acp::SessionUpdate::ToolCall(acp::ToolCall {
1713 id: acp::ToolCallId("test".into()),
1714 title: "Label".into(),
1715 kind: acp::ToolKind::Edit,
1716 status: acp::ToolCallStatus::Completed,
1717 content: vec![acp::ToolCallContent::Diff {
1718 diff: acp::Diff {
1719 path: "/test/test.txt".into(),
1720 old_text: None,
1721 new_text: "foo".into(),
1722 },
1723 }],
1724 locations: vec![],
1725 raw_input: None,
1726 raw_output: None,
1727 }),
1728 cx,
1729 )
1730 })
1731 .unwrap()
1732 .unwrap();
1733 Ok(acp::PromptResponse {
1734 stop_reason: acp::StopReason::EndTurn,
1735 })
1736 }
1737 .boxed_local()
1738 }
1739 }));
1740
1741 let thread = connection
1742 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1743 .await
1744 .unwrap();
1745 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1746 .await
1747 .unwrap();
1748
1749 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1750 }
1751
1752 async fn run_until_first_tool_call(
1753 thread: &Entity<AcpThread>,
1754 cx: &mut TestAppContext,
1755 ) -> usize {
1756 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1757
1758 let subscription = cx.update(|cx| {
1759 cx.subscribe(thread, move |thread, _, cx| {
1760 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1761 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1762 return tx.try_send(ix).unwrap();
1763 }
1764 }
1765 })
1766 });
1767
1768 select! {
1769 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1770 panic!("Timeout waiting for tool call")
1771 }
1772 ix = rx.next().fuse() => {
1773 drop(subscription);
1774 ix.unwrap()
1775 }
1776 }
1777 }
1778
1779 #[derive(Clone, Default)]
1780 struct FakeAgentConnection {
1781 auth_methods: Vec<acp::AuthMethod>,
1782 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1783 on_user_message: Option<
1784 Rc<
1785 dyn Fn(
1786 acp::PromptRequest,
1787 WeakEntity<AcpThread>,
1788 AsyncApp,
1789 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1790 + 'static,
1791 >,
1792 >,
1793 }
1794
1795 impl FakeAgentConnection {
1796 fn new() -> Self {
1797 Self {
1798 auth_methods: Vec::new(),
1799 on_user_message: None,
1800 sessions: Arc::default(),
1801 }
1802 }
1803
1804 #[expect(unused)]
1805 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1806 self.auth_methods = auth_methods;
1807 self
1808 }
1809
1810 fn on_user_message(
1811 mut self,
1812 handler: impl Fn(
1813 acp::PromptRequest,
1814 WeakEntity<AcpThread>,
1815 AsyncApp,
1816 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1817 + 'static,
1818 ) -> Self {
1819 self.on_user_message.replace(Rc::new(handler));
1820 self
1821 }
1822 }
1823
1824 impl AgentConnection for FakeAgentConnection {
1825 fn auth_methods(&self) -> &[acp::AuthMethod] {
1826 &self.auth_methods
1827 }
1828
1829 fn new_thread(
1830 self: Rc<Self>,
1831 project: Entity<Project>,
1832 _cwd: &Path,
1833 cx: &mut gpui::AsyncApp,
1834 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1835 let session_id = acp::SessionId(
1836 rand::thread_rng()
1837 .sample_iter(&rand::distributions::Alphanumeric)
1838 .take(7)
1839 .map(char::from)
1840 .collect::<String>()
1841 .into(),
1842 );
1843 let thread = cx
1844 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1845 .unwrap();
1846 self.sessions.lock().insert(session_id, thread.downgrade());
1847 Task::ready(Ok(thread))
1848 }
1849
1850 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1851 if self.auth_methods().iter().any(|m| m.id == method) {
1852 Task::ready(Ok(()))
1853 } else {
1854 Task::ready(Err(anyhow!("Invalid Auth Method")))
1855 }
1856 }
1857
1858 fn prompt(
1859 &self,
1860 params: acp::PromptRequest,
1861 cx: &mut App,
1862 ) -> Task<gpui::Result<acp::PromptResponse>> {
1863 let sessions = self.sessions.lock();
1864 let thread = sessions.get(¶ms.session_id).unwrap();
1865 if let Some(handler) = &self.on_user_message {
1866 let handler = handler.clone();
1867 let thread = thread.clone();
1868 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1869 } else {
1870 Task::ready(Ok(acp::PromptResponse {
1871 stop_reason: acp::StopReason::EndTurn,
1872 }))
1873 }
1874 }
1875
1876 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1877 let sessions = self.sessions.lock();
1878 let thread = sessions.get(&session_id).unwrap().clone();
1879
1880 cx.spawn(async move |cx| {
1881 thread
1882 .update(cx, |thread, cx| thread.cancel(cx))
1883 .unwrap()
1884 .await
1885 })
1886 .detach();
1887 }
1888 }
1889}