1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol::{self as acp};
13use anyhow::{Context as _, Result};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::Formatter;
24use std::process::ExitStatus;
25use std::rc::Rc;
26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
27use ui::App;
28use util::ResultExt;
29
30#[derive(Debug)]
31pub struct UserMessage {
32 pub content: ContentBlock,
33}
34
35impl UserMessage {
36 pub fn from_acp(
37 message: impl IntoIterator<Item = acp::ContentBlock>,
38 language_registry: Arc<LanguageRegistry>,
39 cx: &mut App,
40 ) -> Self {
41 let mut content = ContentBlock::Empty;
42 for chunk in message {
43 content.append(chunk, &language_registry, cx)
44 }
45 Self { content: content }
46 }
47
48 fn to_markdown(&self, cx: &App) -> String {
49 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
50 }
51}
52
53#[derive(Debug, PartialEq)]
54pub struct AssistantMessage {
55 pub chunks: Vec<AssistantMessageChunk>,
56}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
70#[derive(Debug, PartialEq)]
71pub enum AssistantMessageChunk {
72 Message { block: ContentBlock },
73 Thought { block: ContentBlock },
74}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
94pub enum AgentThreadEntry {
95 UserMessage(UserMessage),
96 AssistantMessage(AssistantMessage),
97 ToolCall(ToolCall),
98}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
126 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
127 Some(locations)
128 } else {
129 None
130 }
131 }
132}
133
134#[derive(Debug)]
135pub struct ToolCall {
136 pub id: acp::ToolCallId,
137 pub label: Entity<Markdown>,
138 pub kind: acp::ToolKind,
139 pub content: Vec<ToolCallContent>,
140 pub status: ToolCallStatus,
141 pub locations: Vec<acp::ToolCallLocation>,
142 pub raw_input: Option<serde_json::Value>,
143 pub raw_output: Option<serde_json::Value>,
144}
145
146impl ToolCall {
147 fn from_acp(
148 tool_call: acp::ToolCall,
149 status: ToolCallStatus,
150 language_registry: Arc<LanguageRegistry>,
151 cx: &mut App,
152 ) -> Self {
153 Self {
154 id: tool_call.id,
155 label: cx.new(|cx| {
156 Markdown::new(
157 tool_call.title.into(),
158 Some(language_registry.clone()),
159 None,
160 cx,
161 )
162 }),
163 kind: tool_call.kind,
164 content: tool_call
165 .content
166 .into_iter()
167 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
168 .collect(),
169 locations: tool_call.locations,
170 status,
171 raw_input: tool_call.raw_input,
172 raw_output: tool_call.raw_output,
173 }
174 }
175
176 fn update_fields(
177 &mut self,
178 fields: acp::ToolCallUpdateFields,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) {
182 let acp::ToolCallUpdateFields {
183 kind,
184 status,
185 title,
186 content,
187 locations,
188 raw_input,
189 raw_output,
190 } = fields;
191
192 if let Some(kind) = kind {
193 self.kind = kind;
194 }
195
196 if let Some(status) = status {
197 self.status = ToolCallStatus::Allowed { status };
198 }
199
200 if let Some(title) = title {
201 self.label.update(cx, |label, cx| {
202 label.replace(title, cx);
203 });
204 }
205
206 if let Some(content) = content {
207 self.content = content
208 .into_iter()
209 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
210 .collect();
211 }
212
213 if let Some(locations) = locations {
214 self.locations = locations;
215 }
216
217 if let Some(raw_input) = raw_input {
218 self.raw_input = Some(raw_input);
219 }
220
221 if let Some(raw_output) = raw_output {
222 if self.content.is_empty() {
223 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
224 {
225 self.content
226 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
227 markdown,
228 }));
229 }
230 }
231 self.raw_output = Some(raw_output);
232 }
233 }
234
235 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
236 self.content.iter().filter_map(|content| match content {
237 ToolCallContent::Diff(diff) => Some(diff),
238 ToolCallContent::ContentBlock(_) => None,
239 ToolCallContent::Terminal(_) => None,
240 })
241 }
242
243 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::Terminal(terminal) => Some(terminal),
246 ToolCallContent::ContentBlock(_) => None,
247 ToolCallContent::Diff(_) => None,
248 })
249 }
250
251 fn to_markdown(&self, cx: &App) -> String {
252 let mut markdown = format!(
253 "**Tool Call: {}**\nStatus: {}\n\n",
254 self.label.read(cx).source(),
255 self.status
256 );
257 for content in &self.content {
258 markdown.push_str(content.to_markdown(cx).as_str());
259 markdown.push_str("\n\n");
260 }
261 markdown
262 }
263}
264
265#[derive(Debug)]
266pub enum ToolCallStatus {
267 WaitingForConfirmation {
268 options: Vec<acp::PermissionOption>,
269 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
270 },
271 Allowed {
272 status: acp::ToolCallStatus,
273 },
274 Rejected,
275 Canceled,
276}
277
278impl Display for ToolCallStatus {
279 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280 write!(
281 f,
282 "{}",
283 match self {
284 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
285 ToolCallStatus::Allowed { status } => match status {
286 acp::ToolCallStatus::Pending => "Pending",
287 acp::ToolCallStatus::InProgress => "In Progress",
288 acp::ToolCallStatus::Completed => "Completed",
289 acp::ToolCallStatus::Failed => "Failed",
290 },
291 ToolCallStatus::Rejected => "Rejected",
292 ToolCallStatus::Canceled => "Canceled",
293 }
294 )
295 }
296}
297
298#[derive(Debug, PartialEq, Clone)]
299pub enum ContentBlock {
300 Empty,
301 Markdown { markdown: Entity<Markdown> },
302 ResourceLink { resource_link: acp::ResourceLink },
303}
304
305impl ContentBlock {
306 pub fn new(
307 block: acp::ContentBlock,
308 language_registry: &Arc<LanguageRegistry>,
309 cx: &mut App,
310 ) -> Self {
311 let mut this = Self::Empty;
312 this.append(block, language_registry, cx);
313 this
314 }
315
316 pub fn new_combined(
317 blocks: impl IntoIterator<Item = acp::ContentBlock>,
318 language_registry: Arc<LanguageRegistry>,
319 cx: &mut App,
320 ) -> Self {
321 let mut this = Self::Empty;
322 for block in blocks {
323 this.append(block, &language_registry, cx);
324 }
325 this
326 }
327
328 pub fn append(
329 &mut self,
330 block: acp::ContentBlock,
331 language_registry: &Arc<LanguageRegistry>,
332 cx: &mut App,
333 ) {
334 if matches!(self, ContentBlock::Empty) {
335 if let acp::ContentBlock::ResourceLink(resource_link) = block {
336 *self = ContentBlock::ResourceLink { resource_link };
337 return;
338 }
339 }
340
341 let new_content = self.extract_content_from_block(block);
342
343 match self {
344 ContentBlock::Empty => {
345 *self = Self::create_markdown_block(new_content, language_registry, cx);
346 }
347 ContentBlock::Markdown { markdown } => {
348 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
349 }
350 ContentBlock::ResourceLink { resource_link } => {
351 let existing_content = Self::resource_link_to_content(&resource_link.uri);
352 let combined = format!("{}\n{}", existing_content, new_content);
353
354 *self = Self::create_markdown_block(combined, language_registry, cx);
355 }
356 }
357 }
358
359 fn resource_link_to_content(uri: &str) -> String {
360 if let Some(uri) = MentionUri::parse(&uri).log_err() {
361 uri.to_link()
362 } else {
363 uri.to_string().clone()
364 }
365 }
366
367 fn create_markdown_block(
368 content: String,
369 language_registry: &Arc<LanguageRegistry>,
370 cx: &mut App,
371 ) -> ContentBlock {
372 ContentBlock::Markdown {
373 markdown: cx
374 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
375 }
376 }
377
378 fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
379 match block {
380 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
381 acp::ContentBlock::ResourceLink(resource_link) => {
382 Self::resource_link_to_content(&resource_link.uri)
383 }
384 acp::ContentBlock::Resource(acp::EmbeddedResource {
385 resource:
386 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
387 uri,
388 ..
389 }),
390 ..
391 }) => Self::resource_link_to_content(&uri),
392 acp::ContentBlock::Image(_)
393 | acp::ContentBlock::Audio(_)
394 | acp::ContentBlock::Resource(_) => String::new(),
395 }
396 }
397
398 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
399 match self {
400 ContentBlock::Empty => "",
401 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
402 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
403 }
404 }
405
406 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
407 match self {
408 ContentBlock::Empty => None,
409 ContentBlock::Markdown { markdown } => Some(markdown),
410 ContentBlock::ResourceLink { .. } => None,
411 }
412 }
413
414 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
415 match self {
416 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
417 _ => None,
418 }
419 }
420}
421
422#[derive(Debug)]
423pub enum ToolCallContent {
424 ContentBlock(ContentBlock),
425 Diff(Entity<Diff>),
426 Terminal(Entity<Terminal>),
427}
428
429impl ToolCallContent {
430 pub fn from_acp(
431 content: acp::ToolCallContent,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 match content {
436 acp::ToolCallContent::Content { content } => {
437 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
438 }
439 acp::ToolCallContent::Diff { diff } => {
440 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
441 }
442 }
443 }
444
445 pub fn to_markdown(&self, cx: &App) -> String {
446 match self {
447 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
448 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
449 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
450 }
451 }
452}
453
454#[derive(Debug, PartialEq)]
455pub enum ToolCallUpdate {
456 UpdateFields(acp::ToolCallUpdate),
457 UpdateDiff(ToolCallUpdateDiff),
458 UpdateTerminal(ToolCallUpdateTerminal),
459}
460
461impl ToolCallUpdate {
462 fn id(&self) -> &acp::ToolCallId {
463 match self {
464 Self::UpdateFields(update) => &update.id,
465 Self::UpdateDiff(diff) => &diff.id,
466 Self::UpdateTerminal(terminal) => &terminal.id,
467 }
468 }
469}
470
471impl From<acp::ToolCallUpdate> for ToolCallUpdate {
472 fn from(update: acp::ToolCallUpdate) -> Self {
473 Self::UpdateFields(update)
474 }
475}
476
477impl From<ToolCallUpdateDiff> for ToolCallUpdate {
478 fn from(diff: ToolCallUpdateDiff) -> Self {
479 Self::UpdateDiff(diff)
480 }
481}
482
483#[derive(Debug, PartialEq)]
484pub struct ToolCallUpdateDiff {
485 pub id: acp::ToolCallId,
486 pub diff: Entity<Diff>,
487}
488
489impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
490 fn from(terminal: ToolCallUpdateTerminal) -> Self {
491 Self::UpdateTerminal(terminal)
492 }
493}
494
495#[derive(Debug, PartialEq)]
496pub struct ToolCallUpdateTerminal {
497 pub id: acp::ToolCallId,
498 pub terminal: Entity<Terminal>,
499}
500
501#[derive(Debug, Default)]
502pub struct Plan {
503 pub entries: Vec<PlanEntry>,
504}
505
506#[derive(Debug)]
507pub struct PlanStats<'a> {
508 pub in_progress_entry: Option<&'a PlanEntry>,
509 pub pending: u32,
510 pub completed: u32,
511}
512
513impl Plan {
514 pub fn is_empty(&self) -> bool {
515 self.entries.is_empty()
516 }
517
518 pub fn stats(&self) -> PlanStats<'_> {
519 let mut stats = PlanStats {
520 in_progress_entry: None,
521 pending: 0,
522 completed: 0,
523 };
524
525 for entry in &self.entries {
526 match &entry.status {
527 acp::PlanEntryStatus::Pending => {
528 stats.pending += 1;
529 }
530 acp::PlanEntryStatus::InProgress => {
531 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
532 }
533 acp::PlanEntryStatus::Completed => {
534 stats.completed += 1;
535 }
536 }
537 }
538
539 stats
540 }
541}
542
543#[derive(Debug)]
544pub struct PlanEntry {
545 pub content: Entity<Markdown>,
546 pub priority: acp::PlanEntryPriority,
547 pub status: acp::PlanEntryStatus,
548}
549
550impl PlanEntry {
551 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
552 Self {
553 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
554 priority: entry.priority,
555 status: entry.status,
556 }
557 }
558}
559
560pub struct AcpThread {
561 title: SharedString,
562 entries: Vec<AgentThreadEntry>,
563 plan: Plan,
564 project: Entity<Project>,
565 action_log: Entity<ActionLog>,
566 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
567 send_task: Option<Task<()>>,
568 connection: Rc<dyn AgentConnection>,
569 session_id: acp::SessionId,
570}
571
572pub enum AcpThreadEvent {
573 NewEntry,
574 EntryUpdated(usize),
575 ToolAuthorizationRequired,
576 Stopped,
577 Error,
578 ServerExited(ExitStatus),
579}
580
581impl EventEmitter<AcpThreadEvent> for AcpThread {}
582
583#[derive(PartialEq, Eq)]
584pub enum ThreadStatus {
585 Idle,
586 WaitingForToolConfirmation,
587 Generating,
588}
589
590#[derive(Debug, Clone)]
591pub enum LoadError {
592 Unsupported {
593 error_message: SharedString,
594 upgrade_message: SharedString,
595 upgrade_command: String,
596 },
597 Exited(i32),
598 Other(SharedString),
599}
600
601impl Display for LoadError {
602 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
603 match self {
604 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
605 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
606 LoadError::Other(msg) => write!(f, "{}", msg),
607 }
608 }
609}
610
611impl Error for LoadError {}
612
613impl AcpThread {
614 pub fn new(
615 title: impl Into<SharedString>,
616 connection: Rc<dyn AgentConnection>,
617 project: Entity<Project>,
618 session_id: acp::SessionId,
619 cx: &mut Context<Self>,
620 ) -> Self {
621 let action_log = cx.new(|_| ActionLog::new(project.clone()));
622
623 Self {
624 action_log,
625 shared_buffers: Default::default(),
626 entries: Default::default(),
627 plan: Default::default(),
628 title: title.into(),
629 project,
630 send_task: None,
631 connection,
632 session_id,
633 }
634 }
635
636 pub fn action_log(&self) -> &Entity<ActionLog> {
637 &self.action_log
638 }
639
640 pub fn project(&self) -> &Entity<Project> {
641 &self.project
642 }
643
644 pub fn title(&self) -> SharedString {
645 self.title.clone()
646 }
647
648 pub fn entries(&self) -> &[AgentThreadEntry] {
649 &self.entries
650 }
651
652 pub fn session_id(&self) -> &acp::SessionId {
653 &self.session_id
654 }
655
656 pub fn status(&self) -> ThreadStatus {
657 if self.send_task.is_some() {
658 if self.waiting_for_tool_confirmation() {
659 ThreadStatus::WaitingForToolConfirmation
660 } else {
661 ThreadStatus::Generating
662 }
663 } else {
664 ThreadStatus::Idle
665 }
666 }
667
668 pub fn has_pending_edit_tool_calls(&self) -> bool {
669 for entry in self.entries.iter().rev() {
670 match entry {
671 AgentThreadEntry::UserMessage(_) => return false,
672 AgentThreadEntry::ToolCall(
673 call @ ToolCall {
674 status:
675 ToolCallStatus::Allowed {
676 status:
677 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
678 },
679 ..
680 },
681 ) if call.diffs().next().is_some() => {
682 return true;
683 }
684 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
685 }
686 }
687
688 false
689 }
690
691 pub fn used_tools_since_last_user_message(&self) -> bool {
692 for entry in self.entries.iter().rev() {
693 match entry {
694 AgentThreadEntry::UserMessage(..) => return false,
695 AgentThreadEntry::AssistantMessage(..) => continue,
696 AgentThreadEntry::ToolCall(..) => return true,
697 }
698 }
699
700 false
701 }
702
703 pub fn handle_session_update(
704 &mut self,
705 update: acp::SessionUpdate,
706 cx: &mut Context<Self>,
707 ) -> Result<()> {
708 match update {
709 acp::SessionUpdate::UserMessageChunk { content } => {
710 self.push_user_content_block(content, cx);
711 }
712 acp::SessionUpdate::AgentMessageChunk { content } => {
713 self.push_assistant_content_block(content, false, cx);
714 }
715 acp::SessionUpdate::AgentThoughtChunk { content } => {
716 self.push_assistant_content_block(content, true, cx);
717 }
718 acp::SessionUpdate::ToolCall(tool_call) => {
719 self.upsert_tool_call(tool_call, cx);
720 }
721 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
722 self.update_tool_call(tool_call_update, cx)?;
723 }
724 acp::SessionUpdate::Plan(plan) => {
725 self.update_plan(plan, cx);
726 }
727 }
728 Ok(())
729 }
730
731 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
732 let language_registry = self.project.read(cx).languages().clone();
733 let entries_len = self.entries.len();
734
735 if let Some(last_entry) = self.entries.last_mut()
736 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
737 {
738 content.append(chunk, &language_registry, cx);
739 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
740 } else {
741 let content = ContentBlock::new(chunk, &language_registry, cx);
742 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
743 }
744 }
745
746 pub fn push_assistant_content_block(
747 &mut self,
748 chunk: acp::ContentBlock,
749 is_thought: bool,
750 cx: &mut Context<Self>,
751 ) {
752 let language_registry = self.project.read(cx).languages().clone();
753 let entries_len = self.entries.len();
754 if let Some(last_entry) = self.entries.last_mut()
755 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
756 {
757 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
758 match (chunks.last_mut(), is_thought) {
759 (Some(AssistantMessageChunk::Message { block }), false)
760 | (Some(AssistantMessageChunk::Thought { block }), true) => {
761 block.append(chunk, &language_registry, cx)
762 }
763 _ => {
764 let block = ContentBlock::new(chunk, &language_registry, cx);
765 if is_thought {
766 chunks.push(AssistantMessageChunk::Thought { block })
767 } else {
768 chunks.push(AssistantMessageChunk::Message { block })
769 }
770 }
771 }
772 } else {
773 let block = ContentBlock::new(chunk, &language_registry, cx);
774 let chunk = if is_thought {
775 AssistantMessageChunk::Thought { block }
776 } else {
777 AssistantMessageChunk::Message { block }
778 };
779
780 self.push_entry(
781 AgentThreadEntry::AssistantMessage(AssistantMessage {
782 chunks: vec![chunk],
783 }),
784 cx,
785 );
786 }
787 }
788
789 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
790 self.entries.push(entry);
791 cx.emit(AcpThreadEvent::NewEntry);
792 }
793
794 pub fn update_tool_call(
795 &mut self,
796 update: impl Into<ToolCallUpdate>,
797 cx: &mut Context<Self>,
798 ) -> Result<()> {
799 let update = update.into();
800 let languages = self.project.read(cx).languages().clone();
801
802 let (ix, current_call) = self
803 .tool_call_mut(update.id())
804 .context("Tool call not found")?;
805 match update {
806 ToolCallUpdate::UpdateFields(update) => {
807 current_call.update_fields(update.fields, languages, cx);
808 }
809 ToolCallUpdate::UpdateDiff(update) => {
810 current_call.content.clear();
811 current_call
812 .content
813 .push(ToolCallContent::Diff(update.diff));
814 }
815 ToolCallUpdate::UpdateTerminal(update) => {
816 current_call.content.clear();
817 current_call
818 .content
819 .push(ToolCallContent::Terminal(update.terminal));
820 }
821 }
822
823 cx.emit(AcpThreadEvent::EntryUpdated(ix));
824
825 Ok(())
826 }
827
828 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
829 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
830 let status = ToolCallStatus::Allowed {
831 status: tool_call.status,
832 };
833 self.upsert_tool_call_inner(tool_call, status, cx)
834 }
835
836 pub fn upsert_tool_call_inner(
837 &mut self,
838 tool_call: acp::ToolCall,
839 status: ToolCallStatus,
840 cx: &mut Context<Self>,
841 ) {
842 let language_registry = self.project.read(cx).languages().clone();
843 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
844
845 let location = call.locations.last().cloned();
846
847 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
848 *current_call = call;
849
850 cx.emit(AcpThreadEvent::EntryUpdated(ix));
851 } else {
852 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
853 }
854
855 if let Some(location) = location {
856 self.set_project_location(location, cx)
857 }
858 }
859
860 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
861 // The tool call we are looking for is typically the last one, or very close to the end.
862 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
863 self.entries
864 .iter_mut()
865 .enumerate()
866 .rev()
867 .find_map(|(index, tool_call)| {
868 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
869 && &tool_call.id == id
870 {
871 Some((index, tool_call))
872 } else {
873 None
874 }
875 })
876 }
877
878 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
879 self.project.update(cx, |project, cx| {
880 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
881 return;
882 };
883 let buffer = project.open_buffer(path, cx);
884 cx.spawn(async move |project, cx| {
885 let buffer = buffer.await?;
886
887 project.update(cx, |project, cx| {
888 let position = if let Some(line) = location.line {
889 let snapshot = buffer.read(cx).snapshot();
890 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
891 snapshot.anchor_before(point)
892 } else {
893 Anchor::MIN
894 };
895
896 project.set_agent_location(
897 Some(AgentLocation {
898 buffer: buffer.downgrade(),
899 position,
900 }),
901 cx,
902 );
903 })
904 })
905 .detach_and_log_err(cx);
906 });
907 }
908
909 pub fn request_tool_call_authorization(
910 &mut self,
911 tool_call: acp::ToolCall,
912 options: Vec<acp::PermissionOption>,
913 cx: &mut Context<Self>,
914 ) -> oneshot::Receiver<acp::PermissionOptionId> {
915 let (tx, rx) = oneshot::channel();
916
917 let status = ToolCallStatus::WaitingForConfirmation {
918 options,
919 respond_tx: tx,
920 };
921
922 self.upsert_tool_call_inner(tool_call, status, cx);
923 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
924 rx
925 }
926
927 pub fn authorize_tool_call(
928 &mut self,
929 id: acp::ToolCallId,
930 option_id: acp::PermissionOptionId,
931 option_kind: acp::PermissionOptionKind,
932 cx: &mut Context<Self>,
933 ) {
934 let Some((ix, call)) = self.tool_call_mut(&id) else {
935 return;
936 };
937
938 let new_status = match option_kind {
939 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
940 ToolCallStatus::Rejected
941 }
942 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
943 ToolCallStatus::Allowed {
944 status: acp::ToolCallStatus::InProgress,
945 }
946 }
947 };
948
949 let curr_status = mem::replace(&mut call.status, new_status);
950
951 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
952 respond_tx.send(option_id).log_err();
953 } else if cfg!(debug_assertions) {
954 panic!("tried to authorize an already authorized tool call");
955 }
956
957 cx.emit(AcpThreadEvent::EntryUpdated(ix));
958 }
959
960 /// Returns true if the last turn is awaiting tool authorization
961 pub fn waiting_for_tool_confirmation(&self) -> bool {
962 for entry in self.entries.iter().rev() {
963 match &entry {
964 AgentThreadEntry::ToolCall(call) => match call.status {
965 ToolCallStatus::WaitingForConfirmation { .. } => return true,
966 ToolCallStatus::Allowed { .. }
967 | ToolCallStatus::Rejected
968 | ToolCallStatus::Canceled => continue,
969 },
970 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
971 // Reached the beginning of the turn
972 return false;
973 }
974 }
975 }
976 false
977 }
978
979 pub fn plan(&self) -> &Plan {
980 &self.plan
981 }
982
983 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
984 let new_entries_len = request.entries.len();
985 let mut new_entries = request.entries.into_iter();
986
987 // Reuse existing markdown to prevent flickering
988 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
989 let PlanEntry {
990 content,
991 priority,
992 status,
993 } = old;
994 content.update(cx, |old, cx| {
995 old.replace(new.content, cx);
996 });
997 *priority = new.priority;
998 *status = new.status;
999 }
1000 for new in new_entries {
1001 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1002 }
1003 self.plan.entries.truncate(new_entries_len);
1004
1005 cx.notify();
1006 }
1007
1008 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1009 self.plan
1010 .entries
1011 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1012 cx.notify();
1013 }
1014
1015 #[cfg(any(test, feature = "test-support"))]
1016 pub fn send_raw(
1017 &mut self,
1018 message: &str,
1019 cx: &mut Context<Self>,
1020 ) -> BoxFuture<'static, Result<()>> {
1021 self.send(
1022 vec![acp::ContentBlock::Text(acp::TextContent {
1023 text: message.to_string(),
1024 annotations: None,
1025 })],
1026 cx,
1027 )
1028 }
1029
1030 pub fn send(
1031 &mut self,
1032 message: Vec<acp::ContentBlock>,
1033 cx: &mut Context<Self>,
1034 ) -> BoxFuture<'static, Result<()>> {
1035 let block = ContentBlock::new_combined(
1036 message.clone(),
1037 self.project.read(cx).languages().clone(),
1038 cx,
1039 );
1040 self.push_entry(
1041 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1042 cx,
1043 );
1044 self.clear_completed_plan_entries(cx);
1045
1046 let (tx, rx) = oneshot::channel();
1047 let cancel_task = self.cancel(cx);
1048
1049 self.send_task = Some(cx.spawn(async move |this, cx| {
1050 async {
1051 cancel_task.await;
1052
1053 let result = this
1054 .update(cx, |this, cx| {
1055 this.connection.prompt(
1056 acp::PromptRequest {
1057 prompt: message,
1058 session_id: this.session_id.clone(),
1059 },
1060 cx,
1061 )
1062 })?
1063 .await;
1064
1065 tx.send(result).log_err();
1066
1067 anyhow::Ok(())
1068 }
1069 .await
1070 .log_err();
1071 }));
1072
1073 cx.spawn(async move |this, cx| match rx.await {
1074 Ok(Err(e)) => {
1075 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1076 .log_err();
1077 Err(e)?
1078 }
1079 result => {
1080 let cancelled = matches!(
1081 result,
1082 Ok(Ok(acp::PromptResponse {
1083 stop_reason: acp::StopReason::Cancelled
1084 }))
1085 );
1086
1087 // We only take the task if the current prompt wasn't cancelled.
1088 //
1089 // This prompt may have been cancelled because another one was sent
1090 // while it was still generating. In these cases, dropping `send_task`
1091 // would cause the next generation to be cancelled.
1092 if !cancelled {
1093 this.update(cx, |this, _cx| this.send_task.take()).ok();
1094 }
1095
1096 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1097 .log_err();
1098 Ok(())
1099 }
1100 })
1101 .boxed()
1102 }
1103
1104 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1105 let Some(send_task) = self.send_task.take() else {
1106 return Task::ready(());
1107 };
1108
1109 for entry in self.entries.iter_mut() {
1110 if let AgentThreadEntry::ToolCall(call) = entry {
1111 let cancel = matches!(
1112 call.status,
1113 ToolCallStatus::WaitingForConfirmation { .. }
1114 | ToolCallStatus::Allowed {
1115 status: acp::ToolCallStatus::InProgress
1116 }
1117 );
1118
1119 if cancel {
1120 call.status = ToolCallStatus::Canceled;
1121 }
1122 }
1123 }
1124
1125 self.connection.cancel(&self.session_id, cx);
1126
1127 // Wait for the send task to complete
1128 cx.foreground_executor().spawn(send_task)
1129 }
1130
1131 pub fn read_text_file(
1132 &self,
1133 path: PathBuf,
1134 line: Option<u32>,
1135 limit: Option<u32>,
1136 reuse_shared_snapshot: bool,
1137 cx: &mut Context<Self>,
1138 ) -> Task<Result<String>> {
1139 let project = self.project.clone();
1140 let action_log = self.action_log.clone();
1141 cx.spawn(async move |this, cx| {
1142 let load = project.update(cx, |project, cx| {
1143 let path = project
1144 .project_path_for_absolute_path(&path, cx)
1145 .context("invalid path")?;
1146 anyhow::Ok(project.open_buffer(path, cx))
1147 });
1148 let buffer = load??.await?;
1149
1150 let snapshot = if reuse_shared_snapshot {
1151 this.read_with(cx, |this, _| {
1152 this.shared_buffers.get(&buffer.clone()).cloned()
1153 })
1154 .log_err()
1155 .flatten()
1156 } else {
1157 None
1158 };
1159
1160 let snapshot = if let Some(snapshot) = snapshot {
1161 snapshot
1162 } else {
1163 action_log.update(cx, |action_log, cx| {
1164 action_log.buffer_read(buffer.clone(), cx);
1165 })?;
1166 project.update(cx, |project, cx| {
1167 let position = buffer
1168 .read(cx)
1169 .snapshot()
1170 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1171 project.set_agent_location(
1172 Some(AgentLocation {
1173 buffer: buffer.downgrade(),
1174 position,
1175 }),
1176 cx,
1177 );
1178 })?;
1179
1180 buffer.update(cx, |buffer, _| buffer.snapshot())?
1181 };
1182
1183 this.update(cx, |this, _| {
1184 let text = snapshot.text();
1185 this.shared_buffers.insert(buffer.clone(), snapshot);
1186 if line.is_none() && limit.is_none() {
1187 return Ok(text);
1188 }
1189 let limit = limit.unwrap_or(u32::MAX) as usize;
1190 let Some(line) = line else {
1191 return Ok(text.lines().take(limit).collect::<String>());
1192 };
1193
1194 let count = text.lines().count();
1195 if count < line as usize {
1196 anyhow::bail!("There are only {} lines", count);
1197 }
1198 Ok(text
1199 .lines()
1200 .skip(line as usize + 1)
1201 .take(limit)
1202 .collect::<String>())
1203 })?
1204 })
1205 }
1206
1207 pub fn write_text_file(
1208 &self,
1209 path: PathBuf,
1210 content: String,
1211 cx: &mut Context<Self>,
1212 ) -> Task<Result<()>> {
1213 let project = self.project.clone();
1214 let action_log = self.action_log.clone();
1215 cx.spawn(async move |this, cx| {
1216 let load = project.update(cx, |project, cx| {
1217 let path = project
1218 .project_path_for_absolute_path(&path, cx)
1219 .context("invalid path")?;
1220 anyhow::Ok(project.open_buffer(path, cx))
1221 });
1222 let buffer = load??.await?;
1223 let snapshot = this.update(cx, |this, cx| {
1224 this.shared_buffers
1225 .get(&buffer)
1226 .cloned()
1227 .unwrap_or_else(|| buffer.read(cx).snapshot())
1228 })?;
1229 let edits = cx
1230 .background_executor()
1231 .spawn(async move {
1232 let old_text = snapshot.text();
1233 text_diff(old_text.as_str(), &content)
1234 .into_iter()
1235 .map(|(range, replacement)| {
1236 (
1237 snapshot.anchor_after(range.start)
1238 ..snapshot.anchor_before(range.end),
1239 replacement,
1240 )
1241 })
1242 .collect::<Vec<_>>()
1243 })
1244 .await;
1245 cx.update(|cx| {
1246 project.update(cx, |project, cx| {
1247 project.set_agent_location(
1248 Some(AgentLocation {
1249 buffer: buffer.downgrade(),
1250 position: edits
1251 .last()
1252 .map(|(range, _)| range.end)
1253 .unwrap_or(Anchor::MIN),
1254 }),
1255 cx,
1256 );
1257 });
1258
1259 action_log.update(cx, |action_log, cx| {
1260 action_log.buffer_read(buffer.clone(), cx);
1261 });
1262 buffer.update(cx, |buffer, cx| {
1263 buffer.edit(edits, None, cx);
1264 });
1265 action_log.update(cx, |action_log, cx| {
1266 action_log.buffer_edited(buffer.clone(), cx);
1267 });
1268 })?;
1269 project
1270 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1271 .await
1272 })
1273 }
1274
1275 pub fn to_markdown(&self, cx: &App) -> String {
1276 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1277 }
1278
1279 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1280 cx.emit(AcpThreadEvent::ServerExited(status));
1281 }
1282}
1283
1284fn markdown_for_raw_output(
1285 raw_output: &serde_json::Value,
1286 language_registry: &Arc<LanguageRegistry>,
1287 cx: &mut App,
1288) -> Option<Entity<Markdown>> {
1289 match raw_output {
1290 serde_json::Value::Null => None,
1291 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1292 Markdown::new(
1293 value.to_string().into(),
1294 Some(language_registry.clone()),
1295 None,
1296 cx,
1297 )
1298 })),
1299 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1300 Markdown::new(
1301 value.to_string().into(),
1302 Some(language_registry.clone()),
1303 None,
1304 cx,
1305 )
1306 })),
1307 serde_json::Value::String(value) => Some(cx.new(|cx| {
1308 Markdown::new(
1309 value.clone().into(),
1310 Some(language_registry.clone()),
1311 None,
1312 cx,
1313 )
1314 })),
1315 value => Some(cx.new(|cx| {
1316 Markdown::new(
1317 format!("```json\n{}\n```", value).into(),
1318 Some(language_registry.clone()),
1319 None,
1320 cx,
1321 )
1322 })),
1323 }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328 use super::*;
1329 use anyhow::anyhow;
1330 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1331 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1332 use indoc::indoc;
1333 use project::FakeFs;
1334 use rand::Rng as _;
1335 use serde_json::json;
1336 use settings::SettingsStore;
1337 use smol::stream::StreamExt as _;
1338 use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1339
1340 use util::path;
1341
1342 fn init_test(cx: &mut TestAppContext) {
1343 env_logger::try_init().ok();
1344 cx.update(|cx| {
1345 let settings_store = SettingsStore::test(cx);
1346 cx.set_global(settings_store);
1347 Project::init_settings(cx);
1348 language::init(cx);
1349 });
1350 }
1351
1352 #[gpui::test]
1353 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1354 init_test(cx);
1355
1356 let fs = FakeFs::new(cx.executor());
1357 let project = Project::test(fs, [], cx).await;
1358 let connection = Rc::new(FakeAgentConnection::new());
1359 let thread = cx
1360 .spawn(async move |mut cx| {
1361 connection
1362 .new_thread(project, Path::new(path!("/test")), &mut cx)
1363 .await
1364 })
1365 .await
1366 .unwrap();
1367
1368 // Test creating a new user message
1369 thread.update(cx, |thread, cx| {
1370 thread.push_user_content_block(
1371 acp::ContentBlock::Text(acp::TextContent {
1372 annotations: None,
1373 text: "Hello, ".to_string(),
1374 }),
1375 cx,
1376 );
1377 });
1378
1379 thread.update(cx, |thread, cx| {
1380 assert_eq!(thread.entries.len(), 1);
1381 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1382 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1383 } else {
1384 panic!("Expected UserMessage");
1385 }
1386 });
1387
1388 // Test appending to existing user message
1389 thread.update(cx, |thread, cx| {
1390 thread.push_user_content_block(
1391 acp::ContentBlock::Text(acp::TextContent {
1392 annotations: None,
1393 text: "world!".to_string(),
1394 }),
1395 cx,
1396 );
1397 });
1398
1399 thread.update(cx, |thread, cx| {
1400 assert_eq!(thread.entries.len(), 1);
1401 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1402 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1403 } else {
1404 panic!("Expected UserMessage");
1405 }
1406 });
1407
1408 // Test creating new user message after assistant message
1409 thread.update(cx, |thread, cx| {
1410 thread.push_assistant_content_block(
1411 acp::ContentBlock::Text(acp::TextContent {
1412 annotations: None,
1413 text: "Assistant response".to_string(),
1414 }),
1415 false,
1416 cx,
1417 );
1418 });
1419
1420 thread.update(cx, |thread, cx| {
1421 thread.push_user_content_block(
1422 acp::ContentBlock::Text(acp::TextContent {
1423 annotations: None,
1424 text: "New user message".to_string(),
1425 }),
1426 cx,
1427 );
1428 });
1429
1430 thread.update(cx, |thread, cx| {
1431 assert_eq!(thread.entries.len(), 3);
1432 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1433 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1434 } else {
1435 panic!("Expected UserMessage at index 2");
1436 }
1437 });
1438 }
1439
1440 #[gpui::test]
1441 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1442 init_test(cx);
1443
1444 let fs = FakeFs::new(cx.executor());
1445 let project = Project::test(fs, [], cx).await;
1446 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1447 |_, thread, mut cx| {
1448 async move {
1449 thread.update(&mut cx, |thread, cx| {
1450 thread
1451 .handle_session_update(
1452 acp::SessionUpdate::AgentThoughtChunk {
1453 content: "Thinking ".into(),
1454 },
1455 cx,
1456 )
1457 .unwrap();
1458 thread
1459 .handle_session_update(
1460 acp::SessionUpdate::AgentThoughtChunk {
1461 content: "hard!".into(),
1462 },
1463 cx,
1464 )
1465 .unwrap();
1466 })?;
1467 Ok(acp::PromptResponse {
1468 stop_reason: acp::StopReason::EndTurn,
1469 })
1470 }
1471 .boxed_local()
1472 },
1473 ));
1474
1475 let thread = cx
1476 .spawn(async move |mut cx| {
1477 connection
1478 .new_thread(project, Path::new(path!("/test")), &mut cx)
1479 .await
1480 })
1481 .await
1482 .unwrap();
1483
1484 thread
1485 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1486 .await
1487 .unwrap();
1488
1489 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1490 assert_eq!(
1491 output,
1492 indoc! {r#"
1493 ## User
1494
1495 Hello from Zed!
1496
1497 ## Assistant
1498
1499 <thinking>
1500 Thinking hard!
1501 </thinking>
1502
1503 "#}
1504 );
1505 }
1506
    // Checks that an agent write does not clobber a concurrent user edit. The agent reads the
    // file, the user then inserts a line at the top, and the agent writes its extended version
    // based on what it read. Both the buffer and the file on disk must end up containing the
    // user's insertion as well as the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has read the file, so the user edit is made after
        // the agent's read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Wrapped so the `Fn` handler can take() the one-shot sender and fire it exactly once.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // Agent step 1: read the whole file (no line/limit given).
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Agent step 2: write back a version derived from the text it read.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same file in a buffer, as the user's editor would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the concurrent user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // The user's "zero" line must survive the agent's write.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // `write_text_file` also saves the buffer, so the file on disk must match.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1585
    // Checks that a tool call canceled along with its turn can still be completed by a late
    // `ToolCallUpdate`. Its status is expected to go InProgress -> Canceled -> Completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports a single in-progress Fetch tool call, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[1] is expected to be the tool call (presumably entries[0] is the user prompt).
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        // Cancelling the turn should flip the unfinished tool call to Canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent must still be applied after cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1695
1696 #[gpui::test]
1697 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1698 init_test(cx);
1699 let fs = FakeFs::new(cx.background_executor.clone());
1700 fs.insert_tree(path!("/test"), json!({})).await;
1701 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1702
1703 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1704 move |_, thread, mut cx| {
1705 async move {
1706 thread
1707 .update(&mut cx, |thread, cx| {
1708 thread.handle_session_update(
1709 acp::SessionUpdate::ToolCall(acp::ToolCall {
1710 id: acp::ToolCallId("test".into()),
1711 title: "Label".into(),
1712 kind: acp::ToolKind::Edit,
1713 status: acp::ToolCallStatus::Completed,
1714 content: vec![acp::ToolCallContent::Diff {
1715 diff: acp::Diff {
1716 path: "/test/test.txt".into(),
1717 old_text: None,
1718 new_text: "foo".into(),
1719 },
1720 }],
1721 locations: vec![],
1722 raw_input: None,
1723 raw_output: None,
1724 }),
1725 cx,
1726 )
1727 })
1728 .unwrap()
1729 .unwrap();
1730 Ok(acp::PromptResponse {
1731 stop_reason: acp::StopReason::EndTurn,
1732 })
1733 }
1734 .boxed_local()
1735 }
1736 }));
1737
1738 let thread = connection
1739 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1740 .await
1741 .unwrap();
1742 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1743 .await
1744 .unwrap();
1745
1746 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1747 }
1748
1749 async fn run_until_first_tool_call(
1750 thread: &Entity<AcpThread>,
1751 cx: &mut TestAppContext,
1752 ) -> usize {
1753 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1754
1755 let subscription = cx.update(|cx| {
1756 cx.subscribe(thread, move |thread, _, cx| {
1757 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1758 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1759 return tx.try_send(ix).unwrap();
1760 }
1761 }
1762 })
1763 });
1764
1765 select! {
1766 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1767 panic!("Timeout waiting for tool call")
1768 }
1769 ix = rx.next().fuse() => {
1770 drop(subscription);
1771 ix.unwrap()
1772 }
1773 }
1774 }
1775
    /// In-memory `AgentConnection` used by these tests.
    ///
    /// It creates threads under random session ids, remembers them weakly by session id, and
    /// delegates each prompt to an optional test-supplied handler.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods reported by `auth_methods` and accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by their session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Handler run for every prompt; when `None`, `prompt` ends the turn immediately.
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::PromptRequest,
                    WeakEntity<AcpThread>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                + 'static,
            >,
        >,
    }
1791
1792 impl FakeAgentConnection {
1793 fn new() -> Self {
1794 Self {
1795 auth_methods: Vec::new(),
1796 on_user_message: None,
1797 sessions: Arc::default(),
1798 }
1799 }
1800
1801 #[expect(unused)]
1802 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1803 self.auth_methods = auth_methods;
1804 self
1805 }
1806
1807 fn on_user_message(
1808 mut self,
1809 handler: impl Fn(
1810 acp::PromptRequest,
1811 WeakEntity<AcpThread>,
1812 AsyncApp,
1813 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1814 + 'static,
1815 ) -> Self {
1816 self.on_user_message.replace(Rc::new(handler));
1817 self
1818 }
1819 }
1820
1821 impl AgentConnection for FakeAgentConnection {
1822 fn auth_methods(&self) -> &[acp::AuthMethod] {
1823 &self.auth_methods
1824 }
1825
1826 fn new_thread(
1827 self: Rc<Self>,
1828 project: Entity<Project>,
1829 _cwd: &Path,
1830 cx: &mut gpui::AsyncApp,
1831 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1832 let session_id = acp::SessionId(
1833 rand::thread_rng()
1834 .sample_iter(&rand::distributions::Alphanumeric)
1835 .take(7)
1836 .map(char::from)
1837 .collect::<String>()
1838 .into(),
1839 );
1840 let thread = cx
1841 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1842 .unwrap();
1843 self.sessions.lock().insert(session_id, thread.downgrade());
1844 Task::ready(Ok(thread))
1845 }
1846
1847 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1848 if self.auth_methods().iter().any(|m| m.id == method) {
1849 Task::ready(Ok(()))
1850 } else {
1851 Task::ready(Err(anyhow!("Invalid Auth Method")))
1852 }
1853 }
1854
1855 fn prompt(
1856 &self,
1857 params: acp::PromptRequest,
1858 cx: &mut App,
1859 ) -> Task<gpui::Result<acp::PromptResponse>> {
1860 let sessions = self.sessions.lock();
1861 let thread = sessions.get(¶ms.session_id).unwrap();
1862 if let Some(handler) = &self.on_user_message {
1863 let handler = handler.clone();
1864 let thread = thread.clone();
1865 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1866 } else {
1867 Task::ready(Ok(acp::PromptResponse {
1868 stop_reason: acp::StopReason::EndTurn,
1869 }))
1870 }
1871 }
1872
1873 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1874 let sessions = self.sessions.lock();
1875 let thread = sessions.get(&session_id).unwrap().clone();
1876
1877 cx.spawn(async move |cx| {
1878 thread
1879 .update(cx, |thread, cx| thread.cancel(cx))
1880 .unwrap()
1881 .await
1882 })
1883 .detach();
1884 }
1885 }
1886}