1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
36#[derive(Debug)]
37pub struct UserMessage {
38 pub id: Option<UserMessageId>,
39 pub content: ContentBlock,
40 pub chunks: Vec<acp::ContentBlock>,
41 pub checkpoint: Option<Checkpoint>,
42}
43
44#[derive(Debug)]
45pub struct Checkpoint {
46 git_checkpoint: GitStoreCheckpoint,
47 pub show: bool,
48}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
166#[derive(Debug)]
167pub struct ToolCall {
168 pub id: acp::ToolCallId,
169 pub label: Entity<Markdown>,
170 pub kind: acp::ToolKind,
171 pub content: Vec<ToolCallContent>,
172 pub status: ToolCallStatus,
173 pub locations: Vec<acp::ToolCallLocation>,
174 pub resolved_locations: Vec<Option<AgentLocation>>,
175 pub raw_input: Option<serde_json::Value>,
176 pub raw_output: Option<serde_json::Value>,
177}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 Self {
187 id: tool_call.id,
188 label: cx.new(|cx| {
189 Markdown::new(
190 tool_call.title.into(),
191 Some(language_registry.clone()),
192 None,
193 cx,
194 )
195 }),
196 kind: tool_call.kind,
197 content: tool_call
198 .content
199 .into_iter()
200 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
201 .collect(),
202 locations: tool_call.locations,
203 resolved_locations: Vec::default(),
204 status,
205 raw_input: tool_call.raw_input,
206 raw_output: tool_call.raw_output,
207 }
208 }
209
210 fn update_fields(
211 &mut self,
212 fields: acp::ToolCallUpdateFields,
213 language_registry: Arc<LanguageRegistry>,
214 cx: &mut App,
215 ) {
216 let acp::ToolCallUpdateFields {
217 kind,
218 status,
219 title,
220 content,
221 locations,
222 raw_input,
223 raw_output,
224 } = fields;
225
226 if let Some(kind) = kind {
227 self.kind = kind;
228 }
229
230 if let Some(status) = status {
231 self.status = status.into();
232 }
233
234 if let Some(title) = title {
235 self.label.update(cx, |label, cx| {
236 label.replace(title, cx);
237 });
238 }
239
240 if let Some(content) = content {
241 let new_content_len = content.len();
242 let mut content = content.into_iter();
243
244 // Reuse existing content if we can
245 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
246 old.update_from_acp(new, language_registry.clone(), cx);
247 }
248 for new in content {
249 self.content.push(ToolCallContent::from_acp(
250 new,
251 language_registry.clone(),
252 cx,
253 ))
254 }
255 self.content.truncate(new_content_len);
256 }
257
258 if let Some(locations) = locations {
259 self.locations = locations;
260 }
261
262 if let Some(raw_input) = raw_input {
263 self.raw_input = Some(raw_input);
264 }
265
266 if let Some(raw_output) = raw_output {
267 if self.content.is_empty()
268 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
269 {
270 self.content
271 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
272 markdown,
273 }));
274 }
275 self.raw_output = Some(raw_output);
276 }
277 }
278
279 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
280 self.content.iter().filter_map(|content| match content {
281 ToolCallContent::Diff(diff) => Some(diff),
282 ToolCallContent::ContentBlock(_) => None,
283 ToolCallContent::Terminal(_) => None,
284 })
285 }
286
287 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
288 self.content.iter().filter_map(|content| match content {
289 ToolCallContent::Terminal(terminal) => Some(terminal),
290 ToolCallContent::ContentBlock(_) => None,
291 ToolCallContent::Diff(_) => None,
292 })
293 }
294
295 fn to_markdown(&self, cx: &App) -> String {
296 let mut markdown = format!(
297 "**Tool Call: {}**\nStatus: {}\n\n",
298 self.label.read(cx).source(),
299 self.status
300 );
301 for content in &self.content {
302 markdown.push_str(content.to_markdown(cx).as_str());
303 markdown.push_str("\n\n");
304 }
305 markdown
306 }
307
308 async fn resolve_location(
309 location: acp::ToolCallLocation,
310 project: WeakEntity<Project>,
311 cx: &mut AsyncApp,
312 ) -> Option<AgentLocation> {
313 let buffer = project
314 .update(cx, |project, cx| {
315 project
316 .project_path_for_absolute_path(&location.path, cx)
317 .map(|path| project.open_buffer(path, cx))
318 })
319 .ok()??;
320 let buffer = buffer.await.log_err()?;
321 let position = buffer
322 .update(cx, |buffer, _| {
323 if let Some(row) = location.line {
324 let snapshot = buffer.snapshot();
325 let column = snapshot.indent_size_for_line(row).len;
326 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
327 snapshot.anchor_before(point)
328 } else {
329 Anchor::MIN
330 }
331 })
332 .ok()?;
333
334 Some(AgentLocation {
335 buffer: buffer.downgrade(),
336 position,
337 })
338 }
339
340 fn resolve_locations(
341 &self,
342 project: Entity<Project>,
343 cx: &mut App,
344 ) -> Task<Vec<Option<AgentLocation>>> {
345 let locations = self.locations.clone();
346 project.update(cx, |_, cx| {
347 cx.spawn(async move |project, cx| {
348 let mut new_locations = Vec::new();
349 for location in locations {
350 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
351 }
352 new_locations
353 })
354 })
355 }
356}
357
358#[derive(Debug)]
359pub enum ToolCallStatus {
360 /// The tool call hasn't started running yet, but we start showing it to
361 /// the user.
362 Pending,
363 /// The tool call is waiting for confirmation from the user.
364 WaitingForConfirmation {
365 options: Vec<acp::PermissionOption>,
366 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
367 },
368 /// The tool call is currently running.
369 InProgress,
370 /// The tool call completed successfully.
371 Completed,
372 /// The tool call failed.
373 Failed,
374 /// The user rejected the tool call.
375 Rejected,
376 /// The user canceled generation so the tool call was canceled.
377 Canceled,
378}
379
380impl From<acp::ToolCallStatus> for ToolCallStatus {
381 fn from(status: acp::ToolCallStatus) -> Self {
382 match status {
383 acp::ToolCallStatus::Pending => Self::Pending,
384 acp::ToolCallStatus::InProgress => Self::InProgress,
385 acp::ToolCallStatus::Completed => Self::Completed,
386 acp::ToolCallStatus::Failed => Self::Failed,
387 }
388 }
389}
390
391impl Display for ToolCallStatus {
392 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393 write!(
394 f,
395 "{}",
396 match self {
397 ToolCallStatus::Pending => "Pending",
398 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
399 ToolCallStatus::InProgress => "In Progress",
400 ToolCallStatus::Completed => "Completed",
401 ToolCallStatus::Failed => "Failed",
402 ToolCallStatus::Rejected => "Rejected",
403 ToolCallStatus::Canceled => "Canceled",
404 }
405 )
406 }
407}
408
409#[derive(Debug, PartialEq, Clone)]
410pub enum ContentBlock {
411 Empty,
412 Markdown { markdown: Entity<Markdown> },
413 ResourceLink { resource_link: acp::ResourceLink },
414}
415
416impl ContentBlock {
417 pub fn new(
418 block: acp::ContentBlock,
419 language_registry: &Arc<LanguageRegistry>,
420 cx: &mut App,
421 ) -> Self {
422 let mut this = Self::Empty;
423 this.append(block, language_registry, cx);
424 this
425 }
426
427 pub fn new_combined(
428 blocks: impl IntoIterator<Item = acp::ContentBlock>,
429 language_registry: Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) -> Self {
432 let mut this = Self::Empty;
433 for block in blocks {
434 this.append(block, &language_registry, cx);
435 }
436 this
437 }
438
439 pub fn append(
440 &mut self,
441 block: acp::ContentBlock,
442 language_registry: &Arc<LanguageRegistry>,
443 cx: &mut App,
444 ) {
445 if matches!(self, ContentBlock::Empty)
446 && let acp::ContentBlock::ResourceLink(resource_link) = block
447 {
448 *self = ContentBlock::ResourceLink { resource_link };
449 return;
450 }
451
452 let new_content = self.block_string_contents(block);
453
454 match self {
455 ContentBlock::Empty => {
456 *self = Self::create_markdown_block(new_content, language_registry, cx);
457 }
458 ContentBlock::Markdown { markdown } => {
459 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
460 }
461 ContentBlock::ResourceLink { resource_link } => {
462 let existing_content = Self::resource_link_md(&resource_link.uri);
463 let combined = format!("{}\n{}", existing_content, new_content);
464
465 *self = Self::create_markdown_block(combined, language_registry, cx);
466 }
467 }
468 }
469
470 fn create_markdown_block(
471 content: String,
472 language_registry: &Arc<LanguageRegistry>,
473 cx: &mut App,
474 ) -> ContentBlock {
475 ContentBlock::Markdown {
476 markdown: cx
477 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
478 }
479 }
480
481 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
482 match block {
483 acp::ContentBlock::Text(text_content) => text_content.text,
484 acp::ContentBlock::ResourceLink(resource_link) => {
485 Self::resource_link_md(&resource_link.uri)
486 }
487 acp::ContentBlock::Resource(acp::EmbeddedResource {
488 resource:
489 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
490 uri,
491 ..
492 }),
493 ..
494 }) => Self::resource_link_md(&uri),
495 acp::ContentBlock::Image(image) => Self::image_md(&image),
496 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
497 }
498 }
499
500 fn resource_link_md(uri: &str) -> String {
501 if let Some(uri) = MentionUri::parse(uri).log_err() {
502 uri.as_link().to_string()
503 } else {
504 uri.to_string()
505 }
506 }
507
508 fn image_md(_image: &acp::ImageContent) -> String {
509 "`Image`".into()
510 }
511
512 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
513 match self {
514 ContentBlock::Empty => "",
515 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
516 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
517 }
518 }
519
520 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
521 match self {
522 ContentBlock::Empty => None,
523 ContentBlock::Markdown { markdown } => Some(markdown),
524 ContentBlock::ResourceLink { .. } => None,
525 }
526 }
527
528 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
529 match self {
530 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
531 _ => None,
532 }
533 }
534}
535
536#[derive(Debug)]
537pub enum ToolCallContent {
538 ContentBlock(ContentBlock),
539 Diff(Entity<Diff>),
540 Terminal(Entity<Terminal>),
541}
542
543impl ToolCallContent {
544 pub fn from_acp(
545 content: acp::ToolCallContent,
546 language_registry: Arc<LanguageRegistry>,
547 cx: &mut App,
548 ) -> Self {
549 match content {
550 acp::ToolCallContent::Content { content } => {
551 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
552 }
553 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
554 Diff::finalized(
555 diff.path,
556 diff.old_text,
557 diff.new_text,
558 language_registry,
559 cx,
560 )
561 })),
562 }
563 }
564
565 pub fn update_from_acp(
566 &mut self,
567 new: acp::ToolCallContent,
568 language_registry: Arc<LanguageRegistry>,
569 cx: &mut App,
570 ) {
571 let needs_update = match (&self, &new) {
572 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
573 old_diff.read(cx).needs_update(
574 new_diff.old_text.as_deref().unwrap_or(""),
575 &new_diff.new_text,
576 cx,
577 )
578 }
579 _ => true,
580 };
581
582 if needs_update {
583 *self = Self::from_acp(new, language_registry, cx);
584 }
585 }
586
587 pub fn to_markdown(&self, cx: &App) -> String {
588 match self {
589 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
590 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
591 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
592 }
593 }
594}
595
596#[derive(Debug, PartialEq)]
597pub enum ToolCallUpdate {
598 UpdateFields(acp::ToolCallUpdate),
599 UpdateDiff(ToolCallUpdateDiff),
600 UpdateTerminal(ToolCallUpdateTerminal),
601}
602
603impl ToolCallUpdate {
604 fn id(&self) -> &acp::ToolCallId {
605 match self {
606 Self::UpdateFields(update) => &update.id,
607 Self::UpdateDiff(diff) => &diff.id,
608 Self::UpdateTerminal(terminal) => &terminal.id,
609 }
610 }
611}
612
613impl From<acp::ToolCallUpdate> for ToolCallUpdate {
614 fn from(update: acp::ToolCallUpdate) -> Self {
615 Self::UpdateFields(update)
616 }
617}
618
619impl From<ToolCallUpdateDiff> for ToolCallUpdate {
620 fn from(diff: ToolCallUpdateDiff) -> Self {
621 Self::UpdateDiff(diff)
622 }
623}
624
625#[derive(Debug, PartialEq)]
626pub struct ToolCallUpdateDiff {
627 pub id: acp::ToolCallId,
628 pub diff: Entity<Diff>,
629}
630
631impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
632 fn from(terminal: ToolCallUpdateTerminal) -> Self {
633 Self::UpdateTerminal(terminal)
634 }
635}
636
637#[derive(Debug, PartialEq)]
638pub struct ToolCallUpdateTerminal {
639 pub id: acp::ToolCallId,
640 pub terminal: Entity<Terminal>,
641}
642
643#[derive(Debug, Default)]
644pub struct Plan {
645 pub entries: Vec<PlanEntry>,
646}
647
648#[derive(Debug)]
649pub struct PlanStats<'a> {
650 pub in_progress_entry: Option<&'a PlanEntry>,
651 pub pending: u32,
652 pub completed: u32,
653}
654
655impl Plan {
656 pub fn is_empty(&self) -> bool {
657 self.entries.is_empty()
658 }
659
660 pub fn stats(&self) -> PlanStats<'_> {
661 let mut stats = PlanStats {
662 in_progress_entry: None,
663 pending: 0,
664 completed: 0,
665 };
666
667 for entry in &self.entries {
668 match &entry.status {
669 acp::PlanEntryStatus::Pending => {
670 stats.pending += 1;
671 }
672 acp::PlanEntryStatus::InProgress => {
673 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
674 }
675 acp::PlanEntryStatus::Completed => {
676 stats.completed += 1;
677 }
678 }
679 }
680
681 stats
682 }
683}
684
685#[derive(Debug)]
686pub struct PlanEntry {
687 pub content: Entity<Markdown>,
688 pub priority: acp::PlanEntryPriority,
689 pub status: acp::PlanEntryStatus,
690}
691
692impl PlanEntry {
693 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
694 Self {
695 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
696 priority: entry.priority,
697 status: entry.status,
698 }
699 }
700}
701
702#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
703pub struct TokenUsage {
704 pub max_tokens: u64,
705 pub used_tokens: u64,
706}
707
708impl TokenUsage {
709 pub fn ratio(&self) -> TokenUsageRatio {
710 #[cfg(debug_assertions)]
711 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
712 .unwrap_or("0.8".to_string())
713 .parse()
714 .unwrap();
715 #[cfg(not(debug_assertions))]
716 let warning_threshold: f32 = 0.8;
717
718 // When the maximum is unknown because there is no selected model,
719 // avoid showing the token limit warning.
720 if self.max_tokens == 0 {
721 TokenUsageRatio::Normal
722 } else if self.used_tokens >= self.max_tokens {
723 TokenUsageRatio::Exceeded
724 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
725 TokenUsageRatio::Warning
726 } else {
727 TokenUsageRatio::Normal
728 }
729 }
730}
731
/// Classification of a [`TokenUsage`] relative to its limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the limit is unknown.
    Normal,
    /// Usage is at or above the warning threshold but below the limit.
    Warning,
    /// Usage has reached or passed the limit.
    Exceeded,
}
738
739#[derive(Debug, Clone)]
740pub struct RetryStatus {
741 pub last_error: SharedString,
742 pub attempt: usize,
743 pub max_attempts: usize,
744 pub started_at: Instant,
745 pub duration: Duration,
746}
747
748pub struct AcpThread {
749 title: SharedString,
750 entries: Vec<AgentThreadEntry>,
751 plan: Plan,
752 project: Entity<Project>,
753 action_log: Entity<ActionLog>,
754 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
755 send_task: Option<Task<()>>,
756 connection: Rc<dyn AgentConnection>,
757 session_id: acp::SessionId,
758 token_usage: Option<TokenUsage>,
759}
760
761#[derive(Debug)]
762pub enum AcpThreadEvent {
763 NewEntry,
764 TitleUpdated,
765 TokenUsageUpdated,
766 EntryUpdated(usize),
767 EntriesRemoved(Range<usize>),
768 ToolAuthorizationRequired,
769 Retry(RetryStatus),
770 Stopped,
771 Error,
772 LoadError(LoadError),
773}
774
775impl EventEmitter<AcpThreadEvent> for AcpThread {}
776
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call awaits user confirmation.
    WaitingForToolConfirmation,
    /// A send task is running and nothing awaits confirmation.
    Generating,
}
783
784#[derive(Debug, Clone)]
785pub enum LoadError {
786 NotInstalled {
787 error_message: SharedString,
788 install_message: SharedString,
789 install_command: String,
790 },
791 Unsupported {
792 error_message: SharedString,
793 upgrade_message: SharedString,
794 upgrade_command: String,
795 },
796 Exited {
797 status: ExitStatus,
798 },
799 Other(SharedString),
800}
801
802impl Display for LoadError {
803 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
804 match self {
805 LoadError::NotInstalled { error_message, .. }
806 | LoadError::Unsupported { error_message, .. } => {
807 write!(f, "{error_message}")
808 }
809 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
810 LoadError::Other(msg) => write!(f, "{}", msg),
811 }
812 }
813}
814
815impl Error for LoadError {}
816
817impl AcpThread {
818 pub fn new(
819 title: impl Into<SharedString>,
820 connection: Rc<dyn AgentConnection>,
821 project: Entity<Project>,
822 action_log: Entity<ActionLog>,
823 session_id: acp::SessionId,
824 ) -> Self {
825 Self {
826 action_log,
827 shared_buffers: Default::default(),
828 entries: Default::default(),
829 plan: Default::default(),
830 title: title.into(),
831 project,
832 send_task: None,
833 connection,
834 session_id,
835 token_usage: None,
836 }
837 }
838
839 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
840 &self.connection
841 }
842
843 pub fn action_log(&self) -> &Entity<ActionLog> {
844 &self.action_log
845 }
846
847 pub fn project(&self) -> &Entity<Project> {
848 &self.project
849 }
850
851 pub fn title(&self) -> SharedString {
852 self.title.clone()
853 }
854
855 pub fn entries(&self) -> &[AgentThreadEntry] {
856 &self.entries
857 }
858
859 pub fn session_id(&self) -> &acp::SessionId {
860 &self.session_id
861 }
862
863 pub fn status(&self) -> ThreadStatus {
864 if self.send_task.is_some() {
865 if self.waiting_for_tool_confirmation() {
866 ThreadStatus::WaitingForToolConfirmation
867 } else {
868 ThreadStatus::Generating
869 }
870 } else {
871 ThreadStatus::Idle
872 }
873 }
874
875 pub fn token_usage(&self) -> Option<&TokenUsage> {
876 self.token_usage.as_ref()
877 }
878
879 pub fn has_pending_edit_tool_calls(&self) -> bool {
880 for entry in self.entries.iter().rev() {
881 match entry {
882 AgentThreadEntry::UserMessage(_) => return false,
883 AgentThreadEntry::ToolCall(
884 call @ ToolCall {
885 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
886 ..
887 },
888 ) if call.diffs().next().is_some() => {
889 return true;
890 }
891 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
892 }
893 }
894
895 false
896 }
897
898 pub fn used_tools_since_last_user_message(&self) -> bool {
899 for entry in self.entries.iter().rev() {
900 match entry {
901 AgentThreadEntry::UserMessage(..) => return false,
902 AgentThreadEntry::AssistantMessage(..) => continue,
903 AgentThreadEntry::ToolCall(..) => return true,
904 }
905 }
906
907 false
908 }
909
910 pub fn handle_session_update(
911 &mut self,
912 update: acp::SessionUpdate,
913 cx: &mut Context<Self>,
914 ) -> Result<(), acp::Error> {
915 match update {
916 acp::SessionUpdate::UserMessageChunk { content } => {
917 self.push_user_content_block(None, content, cx);
918 }
919 acp::SessionUpdate::AgentMessageChunk { content } => {
920 self.push_assistant_content_block(content, false, cx);
921 }
922 acp::SessionUpdate::AgentThoughtChunk { content } => {
923 self.push_assistant_content_block(content, true, cx);
924 }
925 acp::SessionUpdate::ToolCall(tool_call) => {
926 self.upsert_tool_call(tool_call, cx)?;
927 }
928 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
929 self.update_tool_call(tool_call_update, cx)?;
930 }
931 acp::SessionUpdate::Plan(plan) => {
932 self.update_plan(plan, cx);
933 }
934 }
935 Ok(())
936 }
937
938 pub fn push_user_content_block(
939 &mut self,
940 message_id: Option<UserMessageId>,
941 chunk: acp::ContentBlock,
942 cx: &mut Context<Self>,
943 ) {
944 let language_registry = self.project.read(cx).languages().clone();
945 let entries_len = self.entries.len();
946
947 if let Some(last_entry) = self.entries.last_mut()
948 && let AgentThreadEntry::UserMessage(UserMessage {
949 id,
950 content,
951 chunks,
952 ..
953 }) = last_entry
954 {
955 *id = message_id.or(id.take());
956 content.append(chunk.clone(), &language_registry, cx);
957 chunks.push(chunk);
958 let idx = entries_len - 1;
959 cx.emit(AcpThreadEvent::EntryUpdated(idx));
960 } else {
961 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
962 self.push_entry(
963 AgentThreadEntry::UserMessage(UserMessage {
964 id: message_id,
965 content,
966 chunks: vec![chunk],
967 checkpoint: None,
968 }),
969 cx,
970 );
971 }
972 }
973
974 pub fn push_assistant_content_block(
975 &mut self,
976 chunk: acp::ContentBlock,
977 is_thought: bool,
978 cx: &mut Context<Self>,
979 ) {
980 let language_registry = self.project.read(cx).languages().clone();
981 let entries_len = self.entries.len();
982 if let Some(last_entry) = self.entries.last_mut()
983 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
984 {
985 let idx = entries_len - 1;
986 cx.emit(AcpThreadEvent::EntryUpdated(idx));
987 match (chunks.last_mut(), is_thought) {
988 (Some(AssistantMessageChunk::Message { block }), false)
989 | (Some(AssistantMessageChunk::Thought { block }), true) => {
990 block.append(chunk, &language_registry, cx)
991 }
992 _ => {
993 let block = ContentBlock::new(chunk, &language_registry, cx);
994 if is_thought {
995 chunks.push(AssistantMessageChunk::Thought { block })
996 } else {
997 chunks.push(AssistantMessageChunk::Message { block })
998 }
999 }
1000 }
1001 } else {
1002 let block = ContentBlock::new(chunk, &language_registry, cx);
1003 let chunk = if is_thought {
1004 AssistantMessageChunk::Thought { block }
1005 } else {
1006 AssistantMessageChunk::Message { block }
1007 };
1008
1009 self.push_entry(
1010 AgentThreadEntry::AssistantMessage(AssistantMessage {
1011 chunks: vec![chunk],
1012 }),
1013 cx,
1014 );
1015 }
1016 }
1017
1018 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1019 self.entries.push(entry);
1020 cx.emit(AcpThreadEvent::NewEntry);
1021 }
1022
1023 pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
1024 self.title = title;
1025 cx.emit(AcpThreadEvent::TitleUpdated);
1026 Ok(())
1027 }
1028
1029 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1030 self.token_usage = usage;
1031 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1032 }
1033
1034 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1035 cx.emit(AcpThreadEvent::Retry(status));
1036 }
1037
1038 pub fn update_tool_call(
1039 &mut self,
1040 update: impl Into<ToolCallUpdate>,
1041 cx: &mut Context<Self>,
1042 ) -> Result<()> {
1043 let update = update.into();
1044 let languages = self.project.read(cx).languages().clone();
1045
1046 let (ix, current_call) = self
1047 .tool_call_mut(update.id())
1048 .context("Tool call not found")?;
1049 match update {
1050 ToolCallUpdate::UpdateFields(update) => {
1051 let location_updated = update.fields.locations.is_some();
1052 current_call.update_fields(update.fields, languages, cx);
1053 if location_updated {
1054 self.resolve_locations(update.id, cx);
1055 }
1056 }
1057 ToolCallUpdate::UpdateDiff(update) => {
1058 current_call.content.clear();
1059 current_call
1060 .content
1061 .push(ToolCallContent::Diff(update.diff));
1062 }
1063 ToolCallUpdate::UpdateTerminal(update) => {
1064 current_call.content.clear();
1065 current_call
1066 .content
1067 .push(ToolCallContent::Terminal(update.terminal));
1068 }
1069 }
1070
1071 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1072
1073 Ok(())
1074 }
1075
1076 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1077 pub fn upsert_tool_call(
1078 &mut self,
1079 tool_call: acp::ToolCall,
1080 cx: &mut Context<Self>,
1081 ) -> Result<(), acp::Error> {
1082 let status = tool_call.status.into();
1083 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1084 }
1085
1086 /// Fails if id does not match an existing entry.
1087 pub fn upsert_tool_call_inner(
1088 &mut self,
1089 tool_call_update: acp::ToolCallUpdate,
1090 status: ToolCallStatus,
1091 cx: &mut Context<Self>,
1092 ) -> Result<(), acp::Error> {
1093 let language_registry = self.project.read(cx).languages().clone();
1094 let id = tool_call_update.id.clone();
1095
1096 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1097 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1098 current_call.status = status;
1099
1100 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1101 } else {
1102 let call =
1103 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1104 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1105 };
1106
1107 self.resolve_locations(id, cx);
1108 Ok(())
1109 }
1110
1111 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1112 // The tool call we are looking for is typically the last one, or very close to the end.
1113 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1114 self.entries
1115 .iter_mut()
1116 .enumerate()
1117 .rev()
1118 .find_map(|(index, tool_call)| {
1119 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1120 && &tool_call.id == id
1121 {
1122 Some((index, tool_call))
1123 } else {
1124 None
1125 }
1126 })
1127 }
1128
1129 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1130 self.entries
1131 .iter()
1132 .enumerate()
1133 .rev()
1134 .find_map(|(index, tool_call)| {
1135 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1136 && &tool_call.id == id
1137 {
1138 Some((index, tool_call))
1139 } else {
1140 None
1141 }
1142 })
1143 }
1144
1145 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1146 let project = self.project.clone();
1147 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1148 return;
1149 };
1150 let task = tool_call.resolve_locations(project, cx);
1151 cx.spawn(async move |this, cx| {
1152 let resolved_locations = task.await;
1153 this.update(cx, |this, cx| {
1154 let project = this.project.clone();
1155 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1156 return;
1157 };
1158 if let Some(Some(location)) = resolved_locations.last() {
1159 project.update(cx, |project, cx| {
1160 if let Some(agent_location) = project.agent_location() {
1161 let should_ignore = agent_location.buffer == location.buffer
1162 && location
1163 .buffer
1164 .update(cx, |buffer, _| {
1165 let snapshot = buffer.snapshot();
1166 let old_position =
1167 agent_location.position.to_point(&snapshot);
1168 let new_position = location.position.to_point(&snapshot);
1169 // ignore this so that when we get updates from the edit tool
1170 // the position doesn't reset to the startof line
1171 old_position.row == new_position.row
1172 && old_position.column > new_position.column
1173 })
1174 .ok()
1175 .unwrap_or_default();
1176 if !should_ignore {
1177 project.set_agent_location(Some(location.clone()), cx);
1178 }
1179 }
1180 });
1181 }
1182 if tool_call.resolved_locations != resolved_locations {
1183 tool_call.resolved_locations = resolved_locations;
1184 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1185 }
1186 })
1187 })
1188 .detach();
1189 }
1190
1191 pub fn request_tool_call_authorization(
1192 &mut self,
1193 tool_call: acp::ToolCallUpdate,
1194 options: Vec<acp::PermissionOption>,
1195 cx: &mut Context<Self>,
1196 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1197 let (tx, rx) = oneshot::channel();
1198
1199 let status = ToolCallStatus::WaitingForConfirmation {
1200 options,
1201 respond_tx: tx,
1202 };
1203
1204 self.upsert_tool_call_inner(tool_call, status, cx)?;
1205 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1206 Ok(rx)
1207 }
1208
1209 pub fn authorize_tool_call(
1210 &mut self,
1211 id: acp::ToolCallId,
1212 option_id: acp::PermissionOptionId,
1213 option_kind: acp::PermissionOptionKind,
1214 cx: &mut Context<Self>,
1215 ) {
1216 let Some((ix, call)) = self.tool_call_mut(&id) else {
1217 return;
1218 };
1219
1220 let new_status = match option_kind {
1221 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1222 ToolCallStatus::Rejected
1223 }
1224 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1225 ToolCallStatus::InProgress
1226 }
1227 };
1228
1229 let curr_status = mem::replace(&mut call.status, new_status);
1230
1231 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1232 respond_tx.send(option_id).log_err();
1233 } else if cfg!(debug_assertions) {
1234 panic!("tried to authorize an already authorized tool call");
1235 }
1236
1237 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1238 }
1239
1240 /// Returns true if the last turn is awaiting tool authorization
1241 pub fn waiting_for_tool_confirmation(&self) -> bool {
1242 for entry in self.entries.iter().rev() {
1243 match &entry {
1244 AgentThreadEntry::ToolCall(call) => match call.status {
1245 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1246 ToolCallStatus::Pending
1247 | ToolCallStatus::InProgress
1248 | ToolCallStatus::Completed
1249 | ToolCallStatus::Failed
1250 | ToolCallStatus::Rejected
1251 | ToolCallStatus::Canceled => continue,
1252 },
1253 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1254 // Reached the beginning of the turn
1255 return false;
1256 }
1257 }
1258 }
1259 false
1260 }
1261
1262 pub fn plan(&self) -> &Plan {
1263 &self.plan
1264 }
1265
1266 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1267 let new_entries_len = request.entries.len();
1268 let mut new_entries = request.entries.into_iter();
1269
1270 // Reuse existing markdown to prevent flickering
1271 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1272 let PlanEntry {
1273 content,
1274 priority,
1275 status,
1276 } = old;
1277 content.update(cx, |old, cx| {
1278 old.replace(new.content, cx);
1279 });
1280 *priority = new.priority;
1281 *status = new.status;
1282 }
1283 for new in new_entries {
1284 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1285 }
1286 self.plan.entries.truncate(new_entries_len);
1287
1288 cx.notify();
1289 }
1290
1291 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1292 self.plan
1293 .entries
1294 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1295 cx.notify();
1296 }
1297
/// Test helper: sends `message` as a single, unannotated text block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_owned(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1312
1313 pub fn send(
1314 &mut self,
1315 message: Vec<acp::ContentBlock>,
1316 cx: &mut Context<Self>,
1317 ) -> BoxFuture<'static, Result<()>> {
1318 let block = ContentBlock::new_combined(
1319 message.clone(),
1320 self.project.read(cx).languages().clone(),
1321 cx,
1322 );
1323 let request = acp::PromptRequest {
1324 prompt: message.clone(),
1325 session_id: self.session_id.clone(),
1326 };
1327 let git_store = self.project.read(cx).git_store().clone();
1328
1329 let message_id = if self
1330 .connection
1331 .session_editor(&self.session_id, cx)
1332 .is_some()
1333 {
1334 Some(UserMessageId::new())
1335 } else {
1336 None
1337 };
1338
1339 self.run_turn(cx, async move |this, cx| {
1340 this.update(cx, |this, cx| {
1341 this.push_entry(
1342 AgentThreadEntry::UserMessage(UserMessage {
1343 id: message_id.clone(),
1344 content: block,
1345 chunks: message,
1346 checkpoint: None,
1347 }),
1348 cx,
1349 );
1350 })
1351 .ok();
1352
1353 let old_checkpoint = git_store
1354 .update(cx, |git, cx| git.checkpoint(cx))?
1355 .await
1356 .context("failed to get old checkpoint")
1357 .log_err();
1358 this.update(cx, |this, cx| {
1359 if let Some((_ix, message)) = this.last_user_message() {
1360 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1361 git_checkpoint,
1362 show: false,
1363 });
1364 }
1365 this.connection.prompt(message_id, request, cx)
1366 })?
1367 .await
1368 })
1369 }
1370
1371 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1372 self.run_turn(cx, async move |this, cx| {
1373 this.update(cx, |this, cx| {
1374 this.connection
1375 .resume(&this.session_id, cx)
1376 .map(|resume| resume.run(cx))
1377 })?
1378 .context("resuming a session is not supported")?
1379 .await
1380 })
1381 }
1382
1383 fn run_turn(
1384 &mut self,
1385 cx: &mut Context<Self>,
1386 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1387 ) -> BoxFuture<'static, Result<()>> {
1388 self.clear_completed_plan_entries(cx);
1389
1390 let (tx, rx) = oneshot::channel();
1391 let cancel_task = self.cancel(cx);
1392
1393 self.send_task = Some(cx.spawn(async move |this, cx| {
1394 cancel_task.await;
1395 tx.send(f(this, cx).await).ok();
1396 }));
1397
1398 cx.spawn(async move |this, cx| {
1399 let response = rx.await;
1400
1401 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1402 .await?;
1403
1404 this.update(cx, |this, cx| {
1405 this.project
1406 .update(cx, |project, cx| project.set_agent_location(None, cx));
1407 match response {
1408 Ok(Err(e)) => {
1409 this.send_task.take();
1410 cx.emit(AcpThreadEvent::Error);
1411 Err(e)
1412 }
1413 result => {
1414 let canceled = matches!(
1415 result,
1416 Ok(Ok(acp::PromptResponse {
1417 stop_reason: acp::StopReason::Cancelled
1418 }))
1419 );
1420
1421 // We only take the task if the current prompt wasn't canceled.
1422 //
1423 // This prompt may have been canceled because another one was sent
1424 // while it was still generating. In these cases, dropping `send_task`
1425 // would cause the next generation to be canceled.
1426 if !canceled {
1427 this.send_task.take();
1428 }
1429
1430 // Truncate entries if the last prompt was refused.
1431 if let Ok(Ok(acp::PromptResponse {
1432 stop_reason: acp::StopReason::Refusal,
1433 })) = result
1434 && let Some((ix, _)) = this.last_user_message()
1435 {
1436 let range = ix..this.entries.len();
1437 this.entries.truncate(ix);
1438 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1439 }
1440
1441 cx.emit(AcpThreadEvent::Stopped);
1442 Ok(())
1443 }
1444 }
1445 })?
1446 })
1447 .boxed()
1448 }
1449
1450 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1451 let Some(send_task) = self.send_task.take() else {
1452 return Task::ready(());
1453 };
1454
1455 for entry in self.entries.iter_mut() {
1456 if let AgentThreadEntry::ToolCall(call) = entry {
1457 let cancel = matches!(
1458 call.status,
1459 ToolCallStatus::Pending
1460 | ToolCallStatus::WaitingForConfirmation { .. }
1461 | ToolCallStatus::InProgress
1462 );
1463
1464 if cancel {
1465 call.status = ToolCallStatus::Canceled;
1466 }
1467 }
1468 }
1469
1470 self.connection.cancel(&self.session_id, cx);
1471
1472 // Wait for the send task to complete
1473 cx.foreground_executor().spawn(send_task)
1474 }
1475
1476 /// Rewinds this thread to before the entry at `index`, removing it and all
1477 /// subsequent entries while reverting any changes made from that point.
1478 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1479 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1480 return Task::ready(Err(anyhow!("not supported")));
1481 };
1482 let Some(message) = self.user_message(&id) else {
1483 return Task::ready(Err(anyhow!("message not found")));
1484 };
1485
1486 let checkpoint = message
1487 .checkpoint
1488 .as_ref()
1489 .map(|c| c.git_checkpoint.clone());
1490
1491 let git_store = self.project.read(cx).git_store().clone();
1492 cx.spawn(async move |this, cx| {
1493 if let Some(checkpoint) = checkpoint {
1494 git_store
1495 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1496 .await?;
1497 }
1498
1499 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1500 .await?;
1501 this.update(cx, |this, cx| {
1502 if let Some((ix, _)) = this.user_message_mut(&id) {
1503 let range = ix..this.entries.len();
1504 this.entries.truncate(ix);
1505 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1506 }
1507 })
1508 })
1509 }
1510
1511 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1512 let git_store = self.project.read(cx).git_store().clone();
1513
1514 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1515 if let Some(checkpoint) = message.checkpoint.as_ref() {
1516 checkpoint.git_checkpoint.clone()
1517 } else {
1518 return Task::ready(Ok(()));
1519 }
1520 } else {
1521 return Task::ready(Ok(()));
1522 };
1523
1524 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1525 cx.spawn(async move |this, cx| {
1526 let new_checkpoint = new_checkpoint
1527 .await
1528 .context("failed to get new checkpoint")
1529 .log_err();
1530 if let Some(new_checkpoint) = new_checkpoint {
1531 let equal = git_store
1532 .update(cx, |git, cx| {
1533 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1534 })?
1535 .await
1536 .unwrap_or(true);
1537 this.update(cx, |this, cx| {
1538 let (ix, message) = this.last_user_message().context("no user message")?;
1539 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1540 checkpoint.show = !equal;
1541 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1542 anyhow::Ok(())
1543 })??;
1544 }
1545
1546 Ok(())
1547 })
1548 }
1549
1550 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1551 self.entries
1552 .iter_mut()
1553 .enumerate()
1554 .rev()
1555 .find_map(|(ix, entry)| {
1556 if let AgentThreadEntry::UserMessage(message) = entry {
1557 Some((ix, message))
1558 } else {
1559 None
1560 }
1561 })
1562 }
1563
1564 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1565 self.entries.iter().find_map(|entry| {
1566 if let AgentThreadEntry::UserMessage(message) = entry {
1567 if message.id.as_ref() == Some(id) {
1568 Some(message)
1569 } else {
1570 None
1571 }
1572 } else {
1573 None
1574 }
1575 })
1576 }
1577
1578 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1579 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1580 if let AgentThreadEntry::UserMessage(message) = entry {
1581 if message.id.as_ref() == Some(id) {
1582 Some((ix, message))
1583 } else {
1584 None
1585 }
1586 } else {
1587 None
1588 }
1589 })
1590 }
1591
1592 pub fn read_text_file(
1593 &self,
1594 path: PathBuf,
1595 line: Option<u32>,
1596 limit: Option<u32>,
1597 reuse_shared_snapshot: bool,
1598 cx: &mut Context<Self>,
1599 ) -> Task<Result<String>> {
1600 let project = self.project.clone();
1601 let action_log = self.action_log.clone();
1602 cx.spawn(async move |this, cx| {
1603 let load = project.update(cx, |project, cx| {
1604 let path = project
1605 .project_path_for_absolute_path(&path, cx)
1606 .context("invalid path")?;
1607 anyhow::Ok(project.open_buffer(path, cx))
1608 });
1609 let buffer = load??.await?;
1610
1611 let snapshot = if reuse_shared_snapshot {
1612 this.read_with(cx, |this, _| {
1613 this.shared_buffers.get(&buffer.clone()).cloned()
1614 })
1615 .log_err()
1616 .flatten()
1617 } else {
1618 None
1619 };
1620
1621 let snapshot = if let Some(snapshot) = snapshot {
1622 snapshot
1623 } else {
1624 action_log.update(cx, |action_log, cx| {
1625 action_log.buffer_read(buffer.clone(), cx);
1626 })?;
1627 project.update(cx, |project, cx| {
1628 let position = buffer
1629 .read(cx)
1630 .snapshot()
1631 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1632 project.set_agent_location(
1633 Some(AgentLocation {
1634 buffer: buffer.downgrade(),
1635 position,
1636 }),
1637 cx,
1638 );
1639 })?;
1640
1641 buffer.update(cx, |buffer, _| buffer.snapshot())?
1642 };
1643
1644 this.update(cx, |this, _| {
1645 let text = snapshot.text();
1646 this.shared_buffers.insert(buffer.clone(), snapshot);
1647 if line.is_none() && limit.is_none() {
1648 return Ok(text);
1649 }
1650 let limit = limit.unwrap_or(u32::MAX) as usize;
1651 let Some(line) = line else {
1652 return Ok(text.lines().take(limit).collect::<String>());
1653 };
1654
1655 let count = text.lines().count();
1656 if count < line as usize {
1657 anyhow::bail!("There are only {} lines", count);
1658 }
1659 Ok(text
1660 .lines()
1661 .skip(line as usize + 1)
1662 .take(limit)
1663 .collect::<String>())
1664 })?
1665 })
1666 }
1667
1668 pub fn write_text_file(
1669 &self,
1670 path: PathBuf,
1671 content: String,
1672 cx: &mut Context<Self>,
1673 ) -> Task<Result<()>> {
1674 let project = self.project.clone();
1675 let action_log = self.action_log.clone();
1676 cx.spawn(async move |this, cx| {
1677 let load = project.update(cx, |project, cx| {
1678 let path = project
1679 .project_path_for_absolute_path(&path, cx)
1680 .context("invalid path")?;
1681 anyhow::Ok(project.open_buffer(path, cx))
1682 });
1683 let buffer = load??.await?;
1684 let snapshot = this.update(cx, |this, cx| {
1685 this.shared_buffers
1686 .get(&buffer)
1687 .cloned()
1688 .unwrap_or_else(|| buffer.read(cx).snapshot())
1689 })?;
1690 let edits = cx
1691 .background_executor()
1692 .spawn(async move {
1693 let old_text = snapshot.text();
1694 text_diff(old_text.as_str(), &content)
1695 .into_iter()
1696 .map(|(range, replacement)| {
1697 (
1698 snapshot.anchor_after(range.start)
1699 ..snapshot.anchor_before(range.end),
1700 replacement,
1701 )
1702 })
1703 .collect::<Vec<_>>()
1704 })
1705 .await;
1706
1707 project.update(cx, |project, cx| {
1708 project.set_agent_location(
1709 Some(AgentLocation {
1710 buffer: buffer.downgrade(),
1711 position: edits
1712 .last()
1713 .map(|(range, _)| range.end)
1714 .unwrap_or(Anchor::MIN),
1715 }),
1716 cx,
1717 );
1718 })?;
1719
1720 let format_on_save = cx.update(|cx| {
1721 action_log.update(cx, |action_log, cx| {
1722 action_log.buffer_read(buffer.clone(), cx);
1723 });
1724
1725 let format_on_save = buffer.update(cx, |buffer, cx| {
1726 buffer.edit(edits, None, cx);
1727
1728 let settings = language::language_settings::language_settings(
1729 buffer.language().map(|l| l.name()),
1730 buffer.file(),
1731 cx,
1732 );
1733
1734 settings.format_on_save != FormatOnSave::Off
1735 });
1736 action_log.update(cx, |action_log, cx| {
1737 action_log.buffer_edited(buffer.clone(), cx);
1738 });
1739 format_on_save
1740 })?;
1741
1742 if format_on_save {
1743 let format_task = project.update(cx, |project, cx| {
1744 project.format(
1745 HashSet::from_iter([buffer.clone()]),
1746 LspFormatTarget::Buffers,
1747 false,
1748 FormatTrigger::Save,
1749 cx,
1750 )
1751 })?;
1752 format_task.await.log_err();
1753
1754 action_log.update(cx, |action_log, cx| {
1755 action_log.buffer_edited(buffer.clone(), cx);
1756 })?;
1757 }
1758
1759 project
1760 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1761 .await
1762 })
1763 }
1764
1765 pub fn to_markdown(&self, cx: &App) -> String {
1766 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1767 }
1768
1769 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1770 cx.emit(AcpThreadEvent::LoadError(error));
1771 }
1772}
1773
1774fn markdown_for_raw_output(
1775 raw_output: &serde_json::Value,
1776 language_registry: &Arc<LanguageRegistry>,
1777 cx: &mut App,
1778) -> Option<Entity<Markdown>> {
1779 match raw_output {
1780 serde_json::Value::Null => None,
1781 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1782 Markdown::new(
1783 value.to_string().into(),
1784 Some(language_registry.clone()),
1785 None,
1786 cx,
1787 )
1788 })),
1789 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1790 Markdown::new(
1791 value.to_string().into(),
1792 Some(language_registry.clone()),
1793 None,
1794 cx,
1795 )
1796 })),
1797 serde_json::Value::String(value) => Some(cx.new(|cx| {
1798 Markdown::new(
1799 value.clone().into(),
1800 Some(language_registry.clone()),
1801 None,
1802 cx,
1803 )
1804 })),
1805 value => Some(cx.new(|cx| {
1806 Markdown::new(
1807 format!("```json\n{}\n```", value).into(),
1808 Some(language_registry.clone()),
1809 None,
1810 cx,
1811 )
1812 })),
1813 }
1814}
1815
1816#[cfg(test)]
1817mod tests {
1818 use super::*;
1819 use anyhow::anyhow;
1820 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1821 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1822 use indoc::indoc;
1823 use project::{FakeFs, Fs};
1824 use rand::Rng as _;
1825 use serde_json::json;
1826 use settings::SettingsStore;
1827 use smol::stream::StreamExt as _;
1828 use std::{
1829 any::Any,
1830 cell::RefCell,
1831 path::Path,
1832 rc::Rc,
1833 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1834 time::Duration,
1835 };
1836 use util::path;
1837
1838 fn init_test(cx: &mut TestAppContext) {
1839 env_logger::try_init().ok();
1840 cx.update(|cx| {
1841 let settings_store = SettingsStore::test(cx);
1842 cx.set_global(settings_store);
1843 Project::init_settings(cx);
1844 language::init(cx);
1845 });
1846 }
1847
1848 #[gpui::test]
1849 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1850 init_test(cx);
1851
1852 let fs = FakeFs::new(cx.executor());
1853 let project = Project::test(fs, [], cx).await;
1854 let connection = Rc::new(FakeAgentConnection::new());
1855 let thread = cx
1856 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1857 .await
1858 .unwrap();
1859
1860 // Test creating a new user message
1861 thread.update(cx, |thread, cx| {
1862 thread.push_user_content_block(
1863 None,
1864 acp::ContentBlock::Text(acp::TextContent {
1865 annotations: None,
1866 text: "Hello, ".to_string(),
1867 }),
1868 cx,
1869 );
1870 });
1871
1872 thread.update(cx, |thread, cx| {
1873 assert_eq!(thread.entries.len(), 1);
1874 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1875 assert_eq!(user_msg.id, None);
1876 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1877 } else {
1878 panic!("Expected UserMessage");
1879 }
1880 });
1881
1882 // Test appending to existing user message
1883 let message_1_id = UserMessageId::new();
1884 thread.update(cx, |thread, cx| {
1885 thread.push_user_content_block(
1886 Some(message_1_id.clone()),
1887 acp::ContentBlock::Text(acp::TextContent {
1888 annotations: None,
1889 text: "world!".to_string(),
1890 }),
1891 cx,
1892 );
1893 });
1894
1895 thread.update(cx, |thread, cx| {
1896 assert_eq!(thread.entries.len(), 1);
1897 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1898 assert_eq!(user_msg.id, Some(message_1_id));
1899 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1900 } else {
1901 panic!("Expected UserMessage");
1902 }
1903 });
1904
1905 // Test creating new user message after assistant message
1906 thread.update(cx, |thread, cx| {
1907 thread.push_assistant_content_block(
1908 acp::ContentBlock::Text(acp::TextContent {
1909 annotations: None,
1910 text: "Assistant response".to_string(),
1911 }),
1912 false,
1913 cx,
1914 );
1915 });
1916
1917 let message_2_id = UserMessageId::new();
1918 thread.update(cx, |thread, cx| {
1919 thread.push_user_content_block(
1920 Some(message_2_id.clone()),
1921 acp::ContentBlock::Text(acp::TextContent {
1922 annotations: None,
1923 text: "New user message".to_string(),
1924 }),
1925 cx,
1926 );
1927 });
1928
1929 thread.update(cx, |thread, cx| {
1930 assert_eq!(thread.entries.len(), 3);
1931 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1932 assert_eq!(user_msg.id, Some(message_2_id));
1933 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1934 } else {
1935 panic!("Expected UserMessage at index 2");
1936 }
1937 });
1938 }
1939
1940 #[gpui::test]
1941 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1942 init_test(cx);
1943
1944 let fs = FakeFs::new(cx.executor());
1945 let project = Project::test(fs, [], cx).await;
1946 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1947 |_, thread, mut cx| {
1948 async move {
1949 thread.update(&mut cx, |thread, cx| {
1950 thread
1951 .handle_session_update(
1952 acp::SessionUpdate::AgentThoughtChunk {
1953 content: "Thinking ".into(),
1954 },
1955 cx,
1956 )
1957 .unwrap();
1958 thread
1959 .handle_session_update(
1960 acp::SessionUpdate::AgentThoughtChunk {
1961 content: "hard!".into(),
1962 },
1963 cx,
1964 )
1965 .unwrap();
1966 })?;
1967 Ok(acp::PromptResponse {
1968 stop_reason: acp::StopReason::EndTurn,
1969 })
1970 }
1971 .boxed_local()
1972 },
1973 ));
1974
1975 let thread = cx
1976 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1977 .await
1978 .unwrap();
1979
1980 thread
1981 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1982 .await
1983 .unwrap();
1984
1985 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1986 assert_eq!(
1987 output,
1988 indoc! {r#"
1989 ## User
1990
1991 Hello from Zed!
1992
1993 ## Assistant
1994
1995 <thinking>
1996 Thinking hard!
1997 </thinking>
1998
1999 "#}
2000 );
2001 }
2002
2003 #[gpui::test]
2004 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2005 init_test(cx);
2006
2007 let fs = FakeFs::new(cx.executor());
2008 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2009 .await;
2010 let project = Project::test(fs.clone(), [], cx).await;
2011 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2012 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2013 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2014 move |_, thread, mut cx| {
2015 let read_file_tx = read_file_tx.clone();
2016 async move {
2017 let content = thread
2018 .update(&mut cx, |thread, cx| {
2019 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2020 })
2021 .unwrap()
2022 .await
2023 .unwrap();
2024 assert_eq!(content, "one\ntwo\nthree\n");
2025 read_file_tx.take().unwrap().send(()).unwrap();
2026 thread
2027 .update(&mut cx, |thread, cx| {
2028 thread.write_text_file(
2029 path!("/tmp/foo").into(),
2030 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2031 cx,
2032 )
2033 })
2034 .unwrap()
2035 .await
2036 .unwrap();
2037 Ok(acp::PromptResponse {
2038 stop_reason: acp::StopReason::EndTurn,
2039 })
2040 }
2041 .boxed_local()
2042 },
2043 ));
2044
2045 let (worktree, pathbuf) = project
2046 .update(cx, |project, cx| {
2047 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2048 })
2049 .await
2050 .unwrap();
2051 let buffer = project
2052 .update(cx, |project, cx| {
2053 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2054 })
2055 .await
2056 .unwrap();
2057
2058 let thread = cx
2059 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2060 .await
2061 .unwrap();
2062
2063 let request = thread.update(cx, |thread, cx| {
2064 thread.send_raw("Extend the count in /tmp/foo", cx)
2065 });
2066 read_file_rx.await.ok();
2067 buffer.update(cx, |buffer, cx| {
2068 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2069 });
2070 cx.run_until_parked();
2071 assert_eq!(
2072 buffer.read_with(cx, |buffer, _| buffer.text()),
2073 "zero\none\ntwo\nthree\nfour\nfive\n"
2074 );
2075 assert_eq!(
2076 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2077 "zero\none\ntwo\nthree\nfour\nfive\n"
2078 );
2079 request.await.unwrap();
2080 }
2081
2082 #[gpui::test]
2083 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2084 init_test(cx);
2085
2086 let fs = FakeFs::new(cx.executor());
2087 let project = Project::test(fs, [], cx).await;
2088 let id = acp::ToolCallId("test".into());
2089
2090 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2091 let id = id.clone();
2092 move |_, thread, mut cx| {
2093 let id = id.clone();
2094 async move {
2095 thread
2096 .update(&mut cx, |thread, cx| {
2097 thread.handle_session_update(
2098 acp::SessionUpdate::ToolCall(acp::ToolCall {
2099 id: id.clone(),
2100 title: "Label".into(),
2101 kind: acp::ToolKind::Fetch,
2102 status: acp::ToolCallStatus::InProgress,
2103 content: vec![],
2104 locations: vec![],
2105 raw_input: None,
2106 raw_output: None,
2107 }),
2108 cx,
2109 )
2110 })
2111 .unwrap()
2112 .unwrap();
2113 Ok(acp::PromptResponse {
2114 stop_reason: acp::StopReason::EndTurn,
2115 })
2116 }
2117 .boxed_local()
2118 }
2119 }));
2120
2121 let thread = cx
2122 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2123 .await
2124 .unwrap();
2125
2126 let request = thread.update(cx, |thread, cx| {
2127 thread.send_raw("Fetch https://example.com", cx)
2128 });
2129
2130 run_until_first_tool_call(&thread, cx).await;
2131
2132 thread.read_with(cx, |thread, _| {
2133 assert!(matches!(
2134 thread.entries[1],
2135 AgentThreadEntry::ToolCall(ToolCall {
2136 status: ToolCallStatus::InProgress,
2137 ..
2138 })
2139 ));
2140 });
2141
2142 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2143
2144 thread.read_with(cx, |thread, _| {
2145 assert!(matches!(
2146 &thread.entries[1],
2147 AgentThreadEntry::ToolCall(ToolCall {
2148 status: ToolCallStatus::Canceled,
2149 ..
2150 })
2151 ));
2152 });
2153
2154 thread
2155 .update(cx, |thread, cx| {
2156 thread.handle_session_update(
2157 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2158 id,
2159 fields: acp::ToolCallUpdateFields {
2160 status: Some(acp::ToolCallStatus::Completed),
2161 ..Default::default()
2162 },
2163 }),
2164 cx,
2165 )
2166 })
2167 .unwrap();
2168
2169 request.await.unwrap();
2170
2171 thread.read_with(cx, |thread, _| {
2172 assert!(matches!(
2173 thread.entries[1],
2174 AgentThreadEntry::ToolCall(ToolCall {
2175 status: ToolCallStatus::Completed,
2176 ..
2177 })
2178 ));
2179 });
2180 }
2181
2182 #[gpui::test]
2183 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2184 init_test(cx);
2185 let fs = FakeFs::new(cx.background_executor.clone());
2186 fs.insert_tree(path!("/test"), json!({})).await;
2187 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2188
2189 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2190 move |_, thread, mut cx| {
2191 async move {
2192 thread
2193 .update(&mut cx, |thread, cx| {
2194 thread.handle_session_update(
2195 acp::SessionUpdate::ToolCall(acp::ToolCall {
2196 id: acp::ToolCallId("test".into()),
2197 title: "Label".into(),
2198 kind: acp::ToolKind::Edit,
2199 status: acp::ToolCallStatus::Completed,
2200 content: vec![acp::ToolCallContent::Diff {
2201 diff: acp::Diff {
2202 path: "/test/test.txt".into(),
2203 old_text: None,
2204 new_text: "foo".into(),
2205 },
2206 }],
2207 locations: vec![],
2208 raw_input: None,
2209 raw_output: None,
2210 }),
2211 cx,
2212 )
2213 })
2214 .unwrap()
2215 .unwrap();
2216 Ok(acp::PromptResponse {
2217 stop_reason: acp::StopReason::EndTurn,
2218 })
2219 }
2220 .boxed_local()
2221 }
2222 }));
2223
2224 let thread = cx
2225 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2226 .await
2227 .unwrap();
2228
2229 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2230 .await
2231 .unwrap();
2232
2233 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2234 }
2235
2236 #[gpui::test(iterations = 10)]
2237 async fn test_checkpoints(cx: &mut TestAppContext) {
2238 init_test(cx);
2239 let fs = FakeFs::new(cx.background_executor.clone());
2240 fs.insert_tree(
2241 path!("/test"),
2242 json!({
2243 ".git": {}
2244 }),
2245 )
2246 .await;
2247 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2248
2249 let simulate_changes = Arc::new(AtomicBool::new(true));
2250 let next_filename = Arc::new(AtomicUsize::new(0));
2251 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2252 let simulate_changes = simulate_changes.clone();
2253 let next_filename = next_filename.clone();
2254 let fs = fs.clone();
2255 move |request, thread, mut cx| {
2256 let fs = fs.clone();
2257 let simulate_changes = simulate_changes.clone();
2258 let next_filename = next_filename.clone();
2259 async move {
2260 if simulate_changes.load(SeqCst) {
2261 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2262 fs.write(Path::new(&filename), b"").await?;
2263 }
2264
2265 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2266 panic!("expected text content block");
2267 };
2268 thread.update(&mut cx, |thread, cx| {
2269 thread
2270 .handle_session_update(
2271 acp::SessionUpdate::AgentMessageChunk {
2272 content: content.text.to_uppercase().into(),
2273 },
2274 cx,
2275 )
2276 .unwrap();
2277 })?;
2278 Ok(acp::PromptResponse {
2279 stop_reason: acp::StopReason::EndTurn,
2280 })
2281 }
2282 .boxed_local()
2283 }
2284 }));
2285 let thread = cx
2286 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2287 .await
2288 .unwrap();
2289
2290 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2291 .await
2292 .unwrap();
2293 thread.read_with(cx, |thread, cx| {
2294 assert_eq!(
2295 thread.to_markdown(cx),
2296 indoc! {"
2297 ## User (checkpoint)
2298
2299 Lorem
2300
2301 ## Assistant
2302
2303 LOREM
2304
2305 "}
2306 );
2307 });
2308 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2309
2310 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2311 .await
2312 .unwrap();
2313 thread.read_with(cx, |thread, cx| {
2314 assert_eq!(
2315 thread.to_markdown(cx),
2316 indoc! {"
2317 ## User (checkpoint)
2318
2319 Lorem
2320
2321 ## Assistant
2322
2323 LOREM
2324
2325 ## User (checkpoint)
2326
2327 ipsum
2328
2329 ## Assistant
2330
2331 IPSUM
2332
2333 "}
2334 );
2335 });
2336 assert_eq!(
2337 fs.files(),
2338 vec![
2339 Path::new(path!("/test/file-0")),
2340 Path::new(path!("/test/file-1"))
2341 ]
2342 );
2343
2344 // Checkpoint isn't stored when there are no changes.
2345 simulate_changes.store(false, SeqCst);
2346 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2347 .await
2348 .unwrap();
2349 thread.read_with(cx, |thread, cx| {
2350 assert_eq!(
2351 thread.to_markdown(cx),
2352 indoc! {"
2353 ## User (checkpoint)
2354
2355 Lorem
2356
2357 ## Assistant
2358
2359 LOREM
2360
2361 ## User (checkpoint)
2362
2363 ipsum
2364
2365 ## Assistant
2366
2367 IPSUM
2368
2369 ## User
2370
2371 dolor
2372
2373 ## Assistant
2374
2375 DOLOR
2376
2377 "}
2378 );
2379 });
2380 assert_eq!(
2381 fs.files(),
2382 vec![
2383 Path::new(path!("/test/file-0")),
2384 Path::new(path!("/test/file-1"))
2385 ]
2386 );
2387
2388 // Rewinding the conversation truncates the history and restores the checkpoint.
2389 thread
2390 .update(cx, |thread, cx| {
2391 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2392 panic!("unexpected entries {:?}", thread.entries)
2393 };
2394 thread.rewind(message.id.clone().unwrap(), cx)
2395 })
2396 .await
2397 .unwrap();
2398 thread.read_with(cx, |thread, cx| {
2399 assert_eq!(
2400 thread.to_markdown(cx),
2401 indoc! {"
2402 ## User (checkpoint)
2403
2404 Lorem
2405
2406 ## Assistant
2407
2408 LOREM
2409
2410 "}
2411 );
2412 });
2413 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2414 }
2415
2416 #[gpui::test]
2417 async fn test_refusal(cx: &mut TestAppContext) {
2418 init_test(cx);
2419 let fs = FakeFs::new(cx.background_executor.clone());
2420 fs.insert_tree(path!("/"), json!({})).await;
2421 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2422
2423 let refuse_next = Arc::new(AtomicBool::new(false));
2424 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2425 let refuse_next = refuse_next.clone();
2426 move |request, thread, mut cx| {
2427 let refuse_next = refuse_next.clone();
2428 async move {
2429 if refuse_next.load(SeqCst) {
2430 return Ok(acp::PromptResponse {
2431 stop_reason: acp::StopReason::Refusal,
2432 });
2433 }
2434
2435 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2436 panic!("expected text content block");
2437 };
2438 thread.update(&mut cx, |thread, cx| {
2439 thread
2440 .handle_session_update(
2441 acp::SessionUpdate::AgentMessageChunk {
2442 content: content.text.to_uppercase().into(),
2443 },
2444 cx,
2445 )
2446 .unwrap();
2447 })?;
2448 Ok(acp::PromptResponse {
2449 stop_reason: acp::StopReason::EndTurn,
2450 })
2451 }
2452 .boxed_local()
2453 }
2454 }));
2455 let thread = cx
2456 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2457 .await
2458 .unwrap();
2459
2460 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2461 .await
2462 .unwrap();
2463 thread.read_with(cx, |thread, cx| {
2464 assert_eq!(
2465 thread.to_markdown(cx),
2466 indoc! {"
2467 ## User
2468
2469 hello
2470
2471 ## Assistant
2472
2473 HELLO
2474
2475 "}
2476 );
2477 });
2478
2479 // Simulate refusing the second message, ensuring the conversation gets
2480 // truncated to before sending it.
2481 refuse_next.store(true, SeqCst);
2482 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2483 .await
2484 .unwrap();
2485 thread.read_with(cx, |thread, cx| {
2486 assert_eq!(
2487 thread.to_markdown(cx),
2488 indoc! {"
2489 ## User
2490
2491 hello
2492
2493 ## Assistant
2494
2495 HELLO
2496
2497 "}
2498 );
2499 });
2500 }
2501
2502 async fn run_until_first_tool_call(
2503 thread: &Entity<AcpThread>,
2504 cx: &mut TestAppContext,
2505 ) -> usize {
2506 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2507
2508 let subscription = cx.update(|cx| {
2509 cx.subscribe(thread, move |thread, _, cx| {
2510 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2511 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2512 return tx.try_send(ix).unwrap();
2513 }
2514 }
2515 })
2516 });
2517
2518 select! {
2519 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2520 panic!("Timeout waiting for tool call")
2521 }
2522 ix = rx.next().fuse() => {
2523 drop(subscription);
2524 ix.unwrap()
2525 }
2526 }
2527 }
2528
2529 #[derive(Clone, Default)]
2530 struct FakeAgentConnection {
2531 auth_methods: Vec<acp::AuthMethod>,
2532 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2533 on_user_message: Option<
2534 Rc<
2535 dyn Fn(
2536 acp::PromptRequest,
2537 WeakEntity<AcpThread>,
2538 AsyncApp,
2539 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2540 + 'static,
2541 >,
2542 >,
2543 }
2544
2545 impl FakeAgentConnection {
2546 fn new() -> Self {
2547 Self {
2548 auth_methods: Vec::new(),
2549 on_user_message: None,
2550 sessions: Arc::default(),
2551 }
2552 }
2553
2554 #[expect(unused)]
2555 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2556 self.auth_methods = auth_methods;
2557 self
2558 }
2559
2560 fn on_user_message(
2561 mut self,
2562 handler: impl Fn(
2563 acp::PromptRequest,
2564 WeakEntity<AcpThread>,
2565 AsyncApp,
2566 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2567 + 'static,
2568 ) -> Self {
2569 self.on_user_message.replace(Rc::new(handler));
2570 self
2571 }
2572 }
2573
2574 impl AgentConnection for FakeAgentConnection {
2575 fn auth_methods(&self) -> &[acp::AuthMethod] {
2576 &self.auth_methods
2577 }
2578
2579 fn new_thread(
2580 self: Rc<Self>,
2581 project: Entity<Project>,
2582 _cwd: &Path,
2583 cx: &mut App,
2584 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2585 let session_id = acp::SessionId(
2586 rand::thread_rng()
2587 .sample_iter(&rand::distributions::Alphanumeric)
2588 .take(7)
2589 .map(char::from)
2590 .collect::<String>()
2591 .into(),
2592 );
2593 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2594 let thread = cx.new(|_cx| {
2595 AcpThread::new(
2596 "Test",
2597 self.clone(),
2598 project,
2599 action_log,
2600 session_id.clone(),
2601 )
2602 });
2603 self.sessions.lock().insert(session_id, thread.downgrade());
2604 Task::ready(Ok(thread))
2605 }
2606
2607 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2608 if self.auth_methods().iter().any(|m| m.id == method) {
2609 Task::ready(Ok(()))
2610 } else {
2611 Task::ready(Err(anyhow!("Invalid Auth Method")))
2612 }
2613 }
2614
2615 fn prompt(
2616 &self,
2617 _id: Option<UserMessageId>,
2618 params: acp::PromptRequest,
2619 cx: &mut App,
2620 ) -> Task<gpui::Result<acp::PromptResponse>> {
2621 let sessions = self.sessions.lock();
2622 let thread = sessions.get(¶ms.session_id).unwrap();
2623 if let Some(handler) = &self.on_user_message {
2624 let handler = handler.clone();
2625 let thread = thread.clone();
2626 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2627 } else {
2628 Task::ready(Ok(acp::PromptResponse {
2629 stop_reason: acp::StopReason::EndTurn,
2630 }))
2631 }
2632 }
2633
2634 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
2635 acp::PromptCapabilities {
2636 image: true,
2637 audio: true,
2638 embedded_context: true,
2639 }
2640 }
2641
2642 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2643 let sessions = self.sessions.lock();
2644 let thread = sessions.get(session_id).unwrap().clone();
2645
2646 cx.spawn(async move |cx| {
2647 thread
2648 .update(cx, |thread, cx| thread.cancel(cx))
2649 .unwrap()
2650 .await
2651 })
2652 .detach();
2653 }
2654
2655 fn session_editor(
2656 &self,
2657 session_id: &acp::SessionId,
2658 _cx: &mut App,
2659 ) -> Option<Rc<dyn AgentSessionEditor>> {
2660 Some(Rc::new(FakeAgentSessionEditor {
2661 _session_id: session_id.clone(),
2662 }))
2663 }
2664
2665 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2666 self
2667 }
2668 }
2669
    /// Session editor returned by `FakeAgentConnection::session_editor`.
    struct FakeAgentSessionEditor {
        // Id of the session this editor was created for. Nothing in the visible
        // code reads it, hence the leading underscore.
        _session_id: acp::SessionId,
    }
2673
    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// No-op truncation: ignores both arguments and returns a ready task
        /// holding `Ok(())`.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2679}