1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol::{self as acp};
13use anyhow::{Context as _, Result};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::Formatter;
24use std::process::ExitStatus;
25use std::rc::Rc;
26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
27use ui::App;
28use util::ResultExt;
29
30#[derive(Debug)]
31pub struct UserMessage {
32 pub content: ContentBlock,
33}
34
35impl UserMessage {
36 pub fn from_acp(
37 message: impl IntoIterator<Item = acp::ContentBlock>,
38 language_registry: Arc<LanguageRegistry>,
39 cx: &mut App,
40 ) -> Self {
41 let mut content = ContentBlock::Empty;
42 for chunk in message {
43 content.append(chunk, &language_registry, cx)
44 }
45 Self { content: content }
46 }
47
48 fn to_markdown(&self, cx: &App) -> String {
49 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
50 }
51}
52
53#[derive(Debug, PartialEq)]
54pub struct AssistantMessage {
55 pub chunks: Vec<AssistantMessageChunk>,
56}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
70#[derive(Debug, PartialEq)]
71pub enum AssistantMessageChunk {
72 Message { block: ContentBlock },
73 Thought { block: ContentBlock },
74}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
94pub enum AgentThreadEntry {
95 UserMessage(UserMessage),
96 AssistantMessage(AssistantMessage),
97 ToolCall(ToolCall),
98}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
126 if let AgentThreadEntry::ToolCall(ToolCall {
127 locations,
128 resolved_locations,
129 ..
130 }) = self
131 {
132 Some((
133 locations.get(ix)?.clone(),
134 resolved_locations.get(ix)?.clone()?,
135 ))
136 } else {
137 None
138 }
139 }
140}
141
142#[derive(Debug)]
143pub struct ToolCall {
144 pub id: acp::ToolCallId,
145 pub label: Entity<Markdown>,
146 pub kind: acp::ToolKind,
147 pub content: Vec<ToolCallContent>,
148 pub status: ToolCallStatus,
149 pub locations: Vec<acp::ToolCallLocation>,
150 pub resolved_locations: Vec<Option<AgentLocation>>,
151 pub raw_input: Option<serde_json::Value>,
152 pub raw_output: Option<serde_json::Value>,
153}
154
155impl ToolCall {
156 fn from_acp(
157 tool_call: acp::ToolCall,
158 status: ToolCallStatus,
159 language_registry: Arc<LanguageRegistry>,
160 cx: &mut App,
161 ) -> Self {
162 Self {
163 id: tool_call.id,
164 label: cx.new(|cx| {
165 Markdown::new(
166 tool_call.title.into(),
167 Some(language_registry.clone()),
168 None,
169 cx,
170 )
171 }),
172 kind: tool_call.kind,
173 content: tool_call
174 .content
175 .into_iter()
176 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
177 .collect(),
178 locations: tool_call.locations,
179 resolved_locations: Vec::default(),
180 status,
181 raw_input: tool_call.raw_input,
182 raw_output: tool_call.raw_output,
183 }
184 }
185
186 fn update_fields(
187 &mut self,
188 fields: acp::ToolCallUpdateFields,
189 language_registry: Arc<LanguageRegistry>,
190 cx: &mut App,
191 ) {
192 let acp::ToolCallUpdateFields {
193 kind,
194 status,
195 title,
196 content,
197 locations,
198 raw_input,
199 raw_output,
200 } = fields;
201
202 if let Some(kind) = kind {
203 self.kind = kind;
204 }
205
206 if let Some(status) = status {
207 self.status = ToolCallStatus::Allowed { status };
208 }
209
210 if let Some(title) = title {
211 self.label.update(cx, |label, cx| {
212 label.replace(title, cx);
213 });
214 }
215
216 if let Some(content) = content {
217 self.content = content
218 .into_iter()
219 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
220 .collect();
221 }
222
223 if let Some(locations) = locations {
224 self.locations = locations;
225 }
226
227 if let Some(raw_input) = raw_input {
228 self.raw_input = Some(raw_input);
229 }
230
231 if let Some(raw_output) = raw_output {
232 if self.content.is_empty() {
233 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
234 {
235 self.content
236 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
237 markdown,
238 }));
239 }
240 }
241 self.raw_output = Some(raw_output);
242 }
243 }
244
245 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
246 self.content.iter().filter_map(|content| match content {
247 ToolCallContent::Diff(diff) => Some(diff),
248 ToolCallContent::ContentBlock(_) => None,
249 ToolCallContent::Terminal(_) => None,
250 })
251 }
252
253 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
254 self.content.iter().filter_map(|content| match content {
255 ToolCallContent::Terminal(terminal) => Some(terminal),
256 ToolCallContent::ContentBlock(_) => None,
257 ToolCallContent::Diff(_) => None,
258 })
259 }
260
261 fn to_markdown(&self, cx: &App) -> String {
262 let mut markdown = format!(
263 "**Tool Call: {}**\nStatus: {}\n\n",
264 self.label.read(cx).source(),
265 self.status
266 );
267 for content in &self.content {
268 markdown.push_str(content.to_markdown(cx).as_str());
269 markdown.push_str("\n\n");
270 }
271 markdown
272 }
273
274 async fn resolve_location(
275 location: acp::ToolCallLocation,
276 project: WeakEntity<Project>,
277 cx: &mut AsyncApp,
278 ) -> Option<AgentLocation> {
279 let buffer = project
280 .update(cx, |project, cx| {
281 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
282 Some(project.open_buffer(path, cx))
283 } else {
284 None
285 }
286 })
287 .ok()??;
288 let buffer = buffer.await.log_err()?;
289 let position = buffer
290 .update(cx, |buffer, _| {
291 if let Some(row) = location.line {
292 let snapshot = buffer.snapshot();
293 let column = snapshot.indent_size_for_line(row).len;
294 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
295 snapshot.anchor_before(point)
296 } else {
297 Anchor::MIN
298 }
299 })
300 .ok()?;
301
302 Some(AgentLocation {
303 buffer: buffer.downgrade(),
304 position,
305 })
306 }
307
308 fn resolve_locations(
309 &self,
310 project: Entity<Project>,
311 cx: &mut App,
312 ) -> Task<Vec<Option<AgentLocation>>> {
313 let locations = self.locations.clone();
314 project.update(cx, |_, cx| {
315 cx.spawn(async move |project, cx| {
316 let mut new_locations = Vec::new();
317 for location in locations {
318 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
319 }
320 new_locations
321 })
322 })
323 }
324}
325
326#[derive(Debug)]
327pub enum ToolCallStatus {
328 WaitingForConfirmation {
329 options: Vec<acp::PermissionOption>,
330 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
331 },
332 Allowed {
333 status: acp::ToolCallStatus,
334 },
335 Rejected,
336 Canceled,
337}
338
339impl Display for ToolCallStatus {
340 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
341 write!(
342 f,
343 "{}",
344 match self {
345 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
346 ToolCallStatus::Allowed { status } => match status {
347 acp::ToolCallStatus::Pending => "Pending",
348 acp::ToolCallStatus::InProgress => "In Progress",
349 acp::ToolCallStatus::Completed => "Completed",
350 acp::ToolCallStatus::Failed => "Failed",
351 },
352 ToolCallStatus::Rejected => "Rejected",
353 ToolCallStatus::Canceled => "Canceled",
354 }
355 )
356 }
357}
358
359#[derive(Debug, PartialEq, Clone)]
360pub enum ContentBlock {
361 Empty,
362 Markdown { markdown: Entity<Markdown> },
363 ResourceLink { resource_link: acp::ResourceLink },
364}
365
366impl ContentBlock {
367 pub fn new(
368 block: acp::ContentBlock,
369 language_registry: &Arc<LanguageRegistry>,
370 cx: &mut App,
371 ) -> Self {
372 let mut this = Self::Empty;
373 this.append(block, language_registry, cx);
374 this
375 }
376
377 pub fn new_combined(
378 blocks: impl IntoIterator<Item = acp::ContentBlock>,
379 language_registry: Arc<LanguageRegistry>,
380 cx: &mut App,
381 ) -> Self {
382 let mut this = Self::Empty;
383 for block in blocks {
384 this.append(block, &language_registry, cx);
385 }
386 this
387 }
388
389 pub fn append(
390 &mut self,
391 block: acp::ContentBlock,
392 language_registry: &Arc<LanguageRegistry>,
393 cx: &mut App,
394 ) {
395 if matches!(self, ContentBlock::Empty) {
396 if let acp::ContentBlock::ResourceLink(resource_link) = block {
397 *self = ContentBlock::ResourceLink { resource_link };
398 return;
399 }
400 }
401
402 let new_content = self.extract_content_from_block(block);
403
404 match self {
405 ContentBlock::Empty => {
406 *self = Self::create_markdown_block(new_content, language_registry, cx);
407 }
408 ContentBlock::Markdown { markdown } => {
409 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
410 }
411 ContentBlock::ResourceLink { resource_link } => {
412 let existing_content = Self::resource_link_to_content(&resource_link.uri);
413 let combined = format!("{}\n{}", existing_content, new_content);
414
415 *self = Self::create_markdown_block(combined, language_registry, cx);
416 }
417 }
418 }
419
420 fn resource_link_to_content(uri: &str) -> String {
421 if let Some(uri) = MentionUri::parse(&uri).log_err() {
422 uri.to_link()
423 } else {
424 uri.to_string().clone()
425 }
426 }
427
428 fn create_markdown_block(
429 content: String,
430 language_registry: &Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) -> ContentBlock {
433 ContentBlock::Markdown {
434 markdown: cx
435 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
436 }
437 }
438
439 fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
440 match block {
441 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
442 acp::ContentBlock::ResourceLink(resource_link) => {
443 Self::resource_link_to_content(&resource_link.uri)
444 }
445 acp::ContentBlock::Resource(acp::EmbeddedResource {
446 resource:
447 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
448 uri,
449 ..
450 }),
451 ..
452 }) => Self::resource_link_to_content(&uri),
453 acp::ContentBlock::Image(_)
454 | acp::ContentBlock::Audio(_)
455 | acp::ContentBlock::Resource(_) => String::new(),
456 }
457 }
458
459 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
460 match self {
461 ContentBlock::Empty => "",
462 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
463 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
464 }
465 }
466
467 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
468 match self {
469 ContentBlock::Empty => None,
470 ContentBlock::Markdown { markdown } => Some(markdown),
471 ContentBlock::ResourceLink { .. } => None,
472 }
473 }
474
475 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
476 match self {
477 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
478 _ => None,
479 }
480 }
481}
482
483#[derive(Debug)]
484pub enum ToolCallContent {
485 ContentBlock(ContentBlock),
486 Diff(Entity<Diff>),
487 Terminal(Entity<Terminal>),
488}
489
490impl ToolCallContent {
491 pub fn from_acp(
492 content: acp::ToolCallContent,
493 language_registry: Arc<LanguageRegistry>,
494 cx: &mut App,
495 ) -> Self {
496 match content {
497 acp::ToolCallContent::Content { content } => {
498 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
499 }
500 acp::ToolCallContent::Diff { diff } => {
501 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
502 }
503 }
504 }
505
506 pub fn to_markdown(&self, cx: &App) -> String {
507 match self {
508 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
509 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
510 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
511 }
512 }
513}
514
515#[derive(Debug, PartialEq)]
516pub enum ToolCallUpdate {
517 UpdateFields(acp::ToolCallUpdate),
518 UpdateDiff(ToolCallUpdateDiff),
519 UpdateTerminal(ToolCallUpdateTerminal),
520}
521
522impl ToolCallUpdate {
523 fn id(&self) -> &acp::ToolCallId {
524 match self {
525 Self::UpdateFields(update) => &update.id,
526 Self::UpdateDiff(diff) => &diff.id,
527 Self::UpdateTerminal(terminal) => &terminal.id,
528 }
529 }
530}
531
532impl From<acp::ToolCallUpdate> for ToolCallUpdate {
533 fn from(update: acp::ToolCallUpdate) -> Self {
534 Self::UpdateFields(update)
535 }
536}
537
538impl From<ToolCallUpdateDiff> for ToolCallUpdate {
539 fn from(diff: ToolCallUpdateDiff) -> Self {
540 Self::UpdateDiff(diff)
541 }
542}
543
544#[derive(Debug, PartialEq)]
545pub struct ToolCallUpdateDiff {
546 pub id: acp::ToolCallId,
547 pub diff: Entity<Diff>,
548}
549
550impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
551 fn from(terminal: ToolCallUpdateTerminal) -> Self {
552 Self::UpdateTerminal(terminal)
553 }
554}
555
556#[derive(Debug, PartialEq)]
557pub struct ToolCallUpdateTerminal {
558 pub id: acp::ToolCallId,
559 pub terminal: Entity<Terminal>,
560}
561
562#[derive(Debug, Default)]
563pub struct Plan {
564 pub entries: Vec<PlanEntry>,
565}
566
567#[derive(Debug)]
568pub struct PlanStats<'a> {
569 pub in_progress_entry: Option<&'a PlanEntry>,
570 pub pending: u32,
571 pub completed: u32,
572}
573
574impl Plan {
575 pub fn is_empty(&self) -> bool {
576 self.entries.is_empty()
577 }
578
579 pub fn stats(&self) -> PlanStats<'_> {
580 let mut stats = PlanStats {
581 in_progress_entry: None,
582 pending: 0,
583 completed: 0,
584 };
585
586 for entry in &self.entries {
587 match &entry.status {
588 acp::PlanEntryStatus::Pending => {
589 stats.pending += 1;
590 }
591 acp::PlanEntryStatus::InProgress => {
592 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
593 }
594 acp::PlanEntryStatus::Completed => {
595 stats.completed += 1;
596 }
597 }
598 }
599
600 stats
601 }
602}
603
604#[derive(Debug)]
605pub struct PlanEntry {
606 pub content: Entity<Markdown>,
607 pub priority: acp::PlanEntryPriority,
608 pub status: acp::PlanEntryStatus,
609}
610
611impl PlanEntry {
612 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
613 Self {
614 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
615 priority: entry.priority,
616 status: entry.status,
617 }
618 }
619}
620
621pub struct AcpThread {
622 title: SharedString,
623 entries: Vec<AgentThreadEntry>,
624 plan: Plan,
625 project: Entity<Project>,
626 action_log: Entity<ActionLog>,
627 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
628 send_task: Option<Task<()>>,
629 connection: Rc<dyn AgentConnection>,
630 session_id: acp::SessionId,
631}
632
633pub enum AcpThreadEvent {
634 NewEntry,
635 EntryUpdated(usize),
636 ToolAuthorizationRequired,
637 Stopped,
638 Error,
639 ServerExited(ExitStatus),
640}
641
642impl EventEmitter<AcpThreadEvent> for AcpThread {}
643
644#[derive(PartialEq, Eq)]
645pub enum ThreadStatus {
646 Idle,
647 WaitingForToolConfirmation,
648 Generating,
649}
650
651#[derive(Debug, Clone)]
652pub enum LoadError {
653 Unsupported {
654 error_message: SharedString,
655 upgrade_message: SharedString,
656 upgrade_command: String,
657 },
658 Exited(i32),
659 Other(SharedString),
660}
661
662impl Display for LoadError {
663 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
664 match self {
665 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
666 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
667 LoadError::Other(msg) => write!(f, "{}", msg),
668 }
669 }
670}
671
672impl Error for LoadError {}
673
674impl AcpThread {
675 pub fn new(
676 title: impl Into<SharedString>,
677 connection: Rc<dyn AgentConnection>,
678 project: Entity<Project>,
679 session_id: acp::SessionId,
680 cx: &mut Context<Self>,
681 ) -> Self {
682 let action_log = cx.new(|_| ActionLog::new(project.clone()));
683
684 Self {
685 action_log,
686 shared_buffers: Default::default(),
687 entries: Default::default(),
688 plan: Default::default(),
689 title: title.into(),
690 project,
691 send_task: None,
692 connection,
693 session_id,
694 }
695 }
696
697 pub fn action_log(&self) -> &Entity<ActionLog> {
698 &self.action_log
699 }
700
701 pub fn project(&self) -> &Entity<Project> {
702 &self.project
703 }
704
705 pub fn title(&self) -> SharedString {
706 self.title.clone()
707 }
708
709 pub fn entries(&self) -> &[AgentThreadEntry] {
710 &self.entries
711 }
712
713 pub fn session_id(&self) -> &acp::SessionId {
714 &self.session_id
715 }
716
717 pub fn status(&self) -> ThreadStatus {
718 if self.send_task.is_some() {
719 if self.waiting_for_tool_confirmation() {
720 ThreadStatus::WaitingForToolConfirmation
721 } else {
722 ThreadStatus::Generating
723 }
724 } else {
725 ThreadStatus::Idle
726 }
727 }
728
729 pub fn has_pending_edit_tool_calls(&self) -> bool {
730 for entry in self.entries.iter().rev() {
731 match entry {
732 AgentThreadEntry::UserMessage(_) => return false,
733 AgentThreadEntry::ToolCall(
734 call @ ToolCall {
735 status:
736 ToolCallStatus::Allowed {
737 status:
738 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
739 },
740 ..
741 },
742 ) if call.diffs().next().is_some() => {
743 return true;
744 }
745 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
746 }
747 }
748
749 false
750 }
751
752 pub fn used_tools_since_last_user_message(&self) -> bool {
753 for entry in self.entries.iter().rev() {
754 match entry {
755 AgentThreadEntry::UserMessage(..) => return false,
756 AgentThreadEntry::AssistantMessage(..) => continue,
757 AgentThreadEntry::ToolCall(..) => return true,
758 }
759 }
760
761 false
762 }
763
764 pub fn handle_session_update(
765 &mut self,
766 update: acp::SessionUpdate,
767 cx: &mut Context<Self>,
768 ) -> Result<()> {
769 match update {
770 acp::SessionUpdate::UserMessageChunk { content } => {
771 self.push_user_content_block(content, cx);
772 }
773 acp::SessionUpdate::AgentMessageChunk { content } => {
774 self.push_assistant_content_block(content, false, cx);
775 }
776 acp::SessionUpdate::AgentThoughtChunk { content } => {
777 self.push_assistant_content_block(content, true, cx);
778 }
779 acp::SessionUpdate::ToolCall(tool_call) => {
780 self.upsert_tool_call(tool_call, cx);
781 }
782 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
783 self.update_tool_call(tool_call_update, cx)?;
784 }
785 acp::SessionUpdate::Plan(plan) => {
786 self.update_plan(plan, cx);
787 }
788 }
789 Ok(())
790 }
791
792 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
793 let language_registry = self.project.read(cx).languages().clone();
794 let entries_len = self.entries.len();
795
796 if let Some(last_entry) = self.entries.last_mut()
797 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
798 {
799 content.append(chunk, &language_registry, cx);
800 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
801 } else {
802 let content = ContentBlock::new(chunk, &language_registry, cx);
803 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
804 }
805 }
806
807 pub fn push_assistant_content_block(
808 &mut self,
809 chunk: acp::ContentBlock,
810 is_thought: bool,
811 cx: &mut Context<Self>,
812 ) {
813 let language_registry = self.project.read(cx).languages().clone();
814 let entries_len = self.entries.len();
815 if let Some(last_entry) = self.entries.last_mut()
816 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
817 {
818 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
819 match (chunks.last_mut(), is_thought) {
820 (Some(AssistantMessageChunk::Message { block }), false)
821 | (Some(AssistantMessageChunk::Thought { block }), true) => {
822 block.append(chunk, &language_registry, cx)
823 }
824 _ => {
825 let block = ContentBlock::new(chunk, &language_registry, cx);
826 if is_thought {
827 chunks.push(AssistantMessageChunk::Thought { block })
828 } else {
829 chunks.push(AssistantMessageChunk::Message { block })
830 }
831 }
832 }
833 } else {
834 let block = ContentBlock::new(chunk, &language_registry, cx);
835 let chunk = if is_thought {
836 AssistantMessageChunk::Thought { block }
837 } else {
838 AssistantMessageChunk::Message { block }
839 };
840
841 self.push_entry(
842 AgentThreadEntry::AssistantMessage(AssistantMessage {
843 chunks: vec![chunk],
844 }),
845 cx,
846 );
847 }
848 }
849
850 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
851 self.entries.push(entry);
852 cx.emit(AcpThreadEvent::NewEntry);
853 }
854
855 pub fn update_tool_call(
856 &mut self,
857 update: impl Into<ToolCallUpdate>,
858 cx: &mut Context<Self>,
859 ) -> Result<()> {
860 let update = update.into();
861 let languages = self.project.read(cx).languages().clone();
862
863 let (ix, current_call) = self
864 .tool_call_mut(update.id())
865 .context("Tool call not found")?;
866 match update {
867 ToolCallUpdate::UpdateFields(update) => {
868 let location_updated = update.fields.locations.is_some();
869 current_call.update_fields(update.fields, languages, cx);
870 if location_updated {
871 self.resolve_locations(update.id.clone(), cx);
872 }
873 }
874 ToolCallUpdate::UpdateDiff(update) => {
875 current_call.content.clear();
876 current_call
877 .content
878 .push(ToolCallContent::Diff(update.diff));
879 }
880 ToolCallUpdate::UpdateTerminal(update) => {
881 current_call.content.clear();
882 current_call
883 .content
884 .push(ToolCallContent::Terminal(update.terminal));
885 }
886 }
887
888 cx.emit(AcpThreadEvent::EntryUpdated(ix));
889
890 Ok(())
891 }
892
893 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
894 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
895 let status = ToolCallStatus::Allowed {
896 status: tool_call.status,
897 };
898 self.upsert_tool_call_inner(tool_call, status, cx)
899 }
900
901 pub fn upsert_tool_call_inner(
902 &mut self,
903 tool_call: acp::ToolCall,
904 status: ToolCallStatus,
905 cx: &mut Context<Self>,
906 ) {
907 let language_registry = self.project.read(cx).languages().clone();
908 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
909 let id = call.id.clone();
910
911 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
912 *current_call = call;
913
914 cx.emit(AcpThreadEvent::EntryUpdated(ix));
915 } else {
916 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
917 };
918
919 self.resolve_locations(id, cx);
920 }
921
922 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
923 // The tool call we are looking for is typically the last one, or very close to the end.
924 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
925 self.entries
926 .iter_mut()
927 .enumerate()
928 .rev()
929 .find_map(|(index, tool_call)| {
930 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
931 && &tool_call.id == id
932 {
933 Some((index, tool_call))
934 } else {
935 None
936 }
937 })
938 }
939
940 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
941 let project = self.project.clone();
942 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
943 return;
944 };
945 let task = tool_call.resolve_locations(project, cx);
946 cx.spawn(async move |this, cx| {
947 let resolved_locations = task.await;
948 this.update(cx, |this, cx| {
949 let project = this.project.clone();
950 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
951 return;
952 };
953 if let Some(Some(location)) = resolved_locations.last() {
954 project.update(cx, |project, cx| {
955 if let Some(agent_location) = project.agent_location() {
956 let should_ignore = agent_location.buffer == location.buffer
957 && location
958 .buffer
959 .update(cx, |buffer, _| {
960 let snapshot = buffer.snapshot();
961 let old_position =
962 agent_location.position.to_point(&snapshot);
963 let new_position = location.position.to_point(&snapshot);
964 // ignore this so that when we get updates from the edit tool
965 // the position doesn't reset to the startof line
966 old_position.row == new_position.row
967 && old_position.column > new_position.column
968 })
969 .ok()
970 .unwrap_or_default();
971 if !should_ignore {
972 project.set_agent_location(Some(location.clone()), cx);
973 }
974 }
975 });
976 }
977 if tool_call.resolved_locations != resolved_locations {
978 tool_call.resolved_locations = resolved_locations;
979 cx.emit(AcpThreadEvent::EntryUpdated(ix));
980 }
981 })
982 })
983 .detach();
984 }
985
986 pub fn request_tool_call_authorization(
987 &mut self,
988 tool_call: acp::ToolCall,
989 options: Vec<acp::PermissionOption>,
990 cx: &mut Context<Self>,
991 ) -> oneshot::Receiver<acp::PermissionOptionId> {
992 let (tx, rx) = oneshot::channel();
993
994 let status = ToolCallStatus::WaitingForConfirmation {
995 options,
996 respond_tx: tx,
997 };
998
999 self.upsert_tool_call_inner(tool_call, status, cx);
1000 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1001 rx
1002 }
1003
1004 pub fn authorize_tool_call(
1005 &mut self,
1006 id: acp::ToolCallId,
1007 option_id: acp::PermissionOptionId,
1008 option_kind: acp::PermissionOptionKind,
1009 cx: &mut Context<Self>,
1010 ) {
1011 let Some((ix, call)) = self.tool_call_mut(&id) else {
1012 return;
1013 };
1014
1015 let new_status = match option_kind {
1016 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1017 ToolCallStatus::Rejected
1018 }
1019 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1020 ToolCallStatus::Allowed {
1021 status: acp::ToolCallStatus::InProgress,
1022 }
1023 }
1024 };
1025
1026 let curr_status = mem::replace(&mut call.status, new_status);
1027
1028 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1029 respond_tx.send(option_id).log_err();
1030 } else if cfg!(debug_assertions) {
1031 panic!("tried to authorize an already authorized tool call");
1032 }
1033
1034 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1035 }
1036
1037 /// Returns true if the last turn is awaiting tool authorization
1038 pub fn waiting_for_tool_confirmation(&self) -> bool {
1039 for entry in self.entries.iter().rev() {
1040 match &entry {
1041 AgentThreadEntry::ToolCall(call) => match call.status {
1042 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1043 ToolCallStatus::Allowed { .. }
1044 | ToolCallStatus::Rejected
1045 | ToolCallStatus::Canceled => continue,
1046 },
1047 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1048 // Reached the beginning of the turn
1049 return false;
1050 }
1051 }
1052 }
1053 false
1054 }
1055
1056 pub fn plan(&self) -> &Plan {
1057 &self.plan
1058 }
1059
1060 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1061 let new_entries_len = request.entries.len();
1062 let mut new_entries = request.entries.into_iter();
1063
1064 // Reuse existing markdown to prevent flickering
1065 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1066 let PlanEntry {
1067 content,
1068 priority,
1069 status,
1070 } = old;
1071 content.update(cx, |old, cx| {
1072 old.replace(new.content, cx);
1073 });
1074 *priority = new.priority;
1075 *status = new.status;
1076 }
1077 for new in new_entries {
1078 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1079 }
1080 self.plan.entries.truncate(new_entries_len);
1081
1082 cx.notify();
1083 }
1084
1085 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1086 self.plan
1087 .entries
1088 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1089 cx.notify();
1090 }
1091
1092 #[cfg(any(test, feature = "test-support"))]
1093 pub fn send_raw(
1094 &mut self,
1095 message: &str,
1096 cx: &mut Context<Self>,
1097 ) -> BoxFuture<'static, Result<()>> {
1098 self.send(
1099 vec![acp::ContentBlock::Text(acp::TextContent {
1100 text: message.to_string(),
1101 annotations: None,
1102 })],
1103 cx,
1104 )
1105 }
1106
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Steps, in order:
    /// 1. Push `message` as a user message entry.
    /// 2. Clear completed plan entries.
    /// 3. Cancel any in-flight prompt.
    /// 4. Spawn a task that waits for that cancellation and then calls
    ///    `connection.prompt`; this task is stored as `send_task`.
    ///
    /// The returned future resolves when the prompt's result arrives. If the
    /// prompt returned an error, it clears `send_task`, emits `Error` and returns
    /// the error. Otherwise it emits `Stopped` and returns `Ok(())`. This includes
    /// the case where the result channel is dropped without a value.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` carries the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous prompt (if any) finish before starting this one.
                cancel_task.await;

                // If the thread entity is gone, `?` returns early and `tx` is dropped
                // unsent, so `rx` resolves with an error.
                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |this, cx| {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error)
                })
                .log_err();
                Err(e)?
            }
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1183
1184 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1185 let Some(send_task) = self.send_task.take() else {
1186 return Task::ready(());
1187 };
1188
1189 for entry in self.entries.iter_mut() {
1190 if let AgentThreadEntry::ToolCall(call) = entry {
1191 let cancel = matches!(
1192 call.status,
1193 ToolCallStatus::WaitingForConfirmation { .. }
1194 | ToolCallStatus::Allowed {
1195 status: acp::ToolCallStatus::InProgress
1196 }
1197 );
1198
1199 if cancel {
1200 call.status = ToolCallStatus::Canceled;
1201 }
1202 }
1203 }
1204
1205 self.connection.cancel(&self.session_id, cx);
1206
1207 // Wait for the send task to complete
1208 cx.foreground_executor().spawn(send_task)
1209 }
1210
1211 pub fn read_text_file(
1212 &self,
1213 path: PathBuf,
1214 line: Option<u32>,
1215 limit: Option<u32>,
1216 reuse_shared_snapshot: bool,
1217 cx: &mut Context<Self>,
1218 ) -> Task<Result<String>> {
1219 let project = self.project.clone();
1220 let action_log = self.action_log.clone();
1221 cx.spawn(async move |this, cx| {
1222 let load = project.update(cx, |project, cx| {
1223 let path = project
1224 .project_path_for_absolute_path(&path, cx)
1225 .context("invalid path")?;
1226 anyhow::Ok(project.open_buffer(path, cx))
1227 });
1228 let buffer = load??.await?;
1229
1230 let snapshot = if reuse_shared_snapshot {
1231 this.read_with(cx, |this, _| {
1232 this.shared_buffers.get(&buffer.clone()).cloned()
1233 })
1234 .log_err()
1235 .flatten()
1236 } else {
1237 None
1238 };
1239
1240 let snapshot = if let Some(snapshot) = snapshot {
1241 snapshot
1242 } else {
1243 action_log.update(cx, |action_log, cx| {
1244 action_log.buffer_read(buffer.clone(), cx);
1245 })?;
1246 project.update(cx, |project, cx| {
1247 let position = buffer
1248 .read(cx)
1249 .snapshot()
1250 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1251 project.set_agent_location(
1252 Some(AgentLocation {
1253 buffer: buffer.downgrade(),
1254 position,
1255 }),
1256 cx,
1257 );
1258 })?;
1259
1260 buffer.update(cx, |buffer, _| buffer.snapshot())?
1261 };
1262
1263 this.update(cx, |this, _| {
1264 let text = snapshot.text();
1265 this.shared_buffers.insert(buffer.clone(), snapshot);
1266 if line.is_none() && limit.is_none() {
1267 return Ok(text);
1268 }
1269 let limit = limit.unwrap_or(u32::MAX) as usize;
1270 let Some(line) = line else {
1271 return Ok(text.lines().take(limit).collect::<String>());
1272 };
1273
1274 let count = text.lines().count();
1275 if count < line as usize {
1276 anyhow::bail!("There are only {} lines", count);
1277 }
1278 Ok(text
1279 .lines()
1280 .skip(line as usize + 1)
1281 .take(limit)
1282 .collect::<String>())
1283 })?
1284 })
1285 }
1286
1287 pub fn write_text_file(
1288 &self,
1289 path: PathBuf,
1290 content: String,
1291 cx: &mut Context<Self>,
1292 ) -> Task<Result<()>> {
1293 let project = self.project.clone();
1294 let action_log = self.action_log.clone();
1295 cx.spawn(async move |this, cx| {
1296 let load = project.update(cx, |project, cx| {
1297 let path = project
1298 .project_path_for_absolute_path(&path, cx)
1299 .context("invalid path")?;
1300 anyhow::Ok(project.open_buffer(path, cx))
1301 });
1302 let buffer = load??.await?;
1303 let snapshot = this.update(cx, |this, cx| {
1304 this.shared_buffers
1305 .get(&buffer)
1306 .cloned()
1307 .unwrap_or_else(|| buffer.read(cx).snapshot())
1308 })?;
1309 let edits = cx
1310 .background_executor()
1311 .spawn(async move {
1312 let old_text = snapshot.text();
1313 text_diff(old_text.as_str(), &content)
1314 .into_iter()
1315 .map(|(range, replacement)| {
1316 (
1317 snapshot.anchor_after(range.start)
1318 ..snapshot.anchor_before(range.end),
1319 replacement,
1320 )
1321 })
1322 .collect::<Vec<_>>()
1323 })
1324 .await;
1325 cx.update(|cx| {
1326 project.update(cx, |project, cx| {
1327 project.set_agent_location(
1328 Some(AgentLocation {
1329 buffer: buffer.downgrade(),
1330 position: edits
1331 .last()
1332 .map(|(range, _)| range.end)
1333 .unwrap_or(Anchor::MIN),
1334 }),
1335 cx,
1336 );
1337 });
1338
1339 action_log.update(cx, |action_log, cx| {
1340 action_log.buffer_read(buffer.clone(), cx);
1341 });
1342 buffer.update(cx, |buffer, cx| {
1343 buffer.edit(edits, None, cx);
1344 });
1345 action_log.update(cx, |action_log, cx| {
1346 action_log.buffer_edited(buffer.clone(), cx);
1347 });
1348 })?;
1349 project
1350 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1351 .await
1352 })
1353 }
1354
1355 pub fn to_markdown(&self, cx: &App) -> String {
1356 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1357 }
1358
1359 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1360 cx.emit(AcpThreadEvent::ServerExited(status));
1361 }
1362}
1363
1364fn markdown_for_raw_output(
1365 raw_output: &serde_json::Value,
1366 language_registry: &Arc<LanguageRegistry>,
1367 cx: &mut App,
1368) -> Option<Entity<Markdown>> {
1369 match raw_output {
1370 serde_json::Value::Null => None,
1371 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1372 Markdown::new(
1373 value.to_string().into(),
1374 Some(language_registry.clone()),
1375 None,
1376 cx,
1377 )
1378 })),
1379 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1380 Markdown::new(
1381 value.to_string().into(),
1382 Some(language_registry.clone()),
1383 None,
1384 cx,
1385 )
1386 })),
1387 serde_json::Value::String(value) => Some(cx.new(|cx| {
1388 Markdown::new(
1389 value.clone().into(),
1390 Some(language_registry.clone()),
1391 None,
1392 cx,
1393 )
1394 })),
1395 value => Some(cx.new(|cx| {
1396 Markdown::new(
1397 format!("```json\n{}\n```", value).into(),
1398 Some(language_registry.clone()),
1399 None,
1400 cx,
1401 )
1402 })),
1403 }
1404}
1405
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};

    use util::path;

    /// Shared setup: logger, test settings store, project settings and languages.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user chunks merge into one message; a user chunk after an
    /// assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive thought chunks are concatenated into a single `<thinking>` block.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// Agent writes are diffed against the snapshot the agent read, so user
    /// edits made in between are preserved and the result is saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call canceled locally can still be moved to `Completed` by a
    /// later update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call that arrives already `Completed` leaves no pending edits.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            // Use `path!` like the rest of the tests so the
                                            // path is valid on every platform.
                                            path: path!("/test/test.txt").into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits, for at most 10 seconds, until an event from `thread` finds a tool
    /// call entry, and returns that entry's index.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // Every later event sends again. If the bounded channel
                        // is already full, the receiver still has an index
                        // pending, so dropping the extra send is harmless and
                        // must not panic.
                        tx.try_send(ix).ok();
                        return;
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory `AgentConnection`. Prompts are answered by an optional
    /// `on_user_message` handler, or with `EndTurn` when none is set.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs the handler that `prompt` runs for every prompt request.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        fn prompt(
            &self,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            // `session_id` is already a reference, so pass it to `get` as-is.
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}