1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
36#[derive(Debug)]
37pub struct UserMessage {
38 pub id: Option<UserMessageId>,
39 pub content: ContentBlock,
40 pub chunks: Vec<acp::ContentBlock>,
41 pub checkpoint: Option<Checkpoint>,
42}
43
44#[derive(Debug)]
45pub struct Checkpoint {
46 git_checkpoint: GitStoreCheckpoint,
47 pub show: bool,
48}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
166#[derive(Debug)]
167pub struct ToolCall {
168 pub id: acp::ToolCallId,
169 pub label: Entity<Markdown>,
170 pub kind: acp::ToolKind,
171 pub content: Vec<ToolCallContent>,
172 pub status: ToolCallStatus,
173 pub locations: Vec<acp::ToolCallLocation>,
174 pub resolved_locations: Vec<Option<AgentLocation>>,
175 pub raw_input: Option<serde_json::Value>,
176 pub raw_output: Option<serde_json::Value>,
177}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 Self {
187 id: tool_call.id,
188 label: cx.new(|cx| {
189 Markdown::new(
190 tool_call.title.into(),
191 Some(language_registry.clone()),
192 None,
193 cx,
194 )
195 }),
196 kind: tool_call.kind,
197 content: tool_call
198 .content
199 .into_iter()
200 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
201 .collect(),
202 locations: tool_call.locations,
203 resolved_locations: Vec::default(),
204 status,
205 raw_input: tool_call.raw_input,
206 raw_output: tool_call.raw_output,
207 }
208 }
209
210 fn update_fields(
211 &mut self,
212 fields: acp::ToolCallUpdateFields,
213 language_registry: Arc<LanguageRegistry>,
214 cx: &mut App,
215 ) {
216 let acp::ToolCallUpdateFields {
217 kind,
218 status,
219 title,
220 content,
221 locations,
222 raw_input,
223 raw_output,
224 } = fields;
225
226 if let Some(kind) = kind {
227 self.kind = kind;
228 }
229
230 if let Some(status) = status {
231 self.status = status.into();
232 }
233
234 if let Some(title) = title {
235 self.label.update(cx, |label, cx| {
236 label.replace(title, cx);
237 });
238 }
239
240 if let Some(content) = content {
241 let new_content_len = content.len();
242 let mut content = content.into_iter();
243
244 // Reuse existing content if we can
245 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
246 old.update_from_acp(new, language_registry.clone(), cx);
247 }
248 for new in content {
249 self.content.push(ToolCallContent::from_acp(
250 new,
251 language_registry.clone(),
252 cx,
253 ))
254 }
255 self.content.truncate(new_content_len);
256 }
257
258 if let Some(locations) = locations {
259 self.locations = locations;
260 }
261
262 if let Some(raw_input) = raw_input {
263 self.raw_input = Some(raw_input);
264 }
265
266 if let Some(raw_output) = raw_output {
267 if self.content.is_empty()
268 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
269 {
270 self.content
271 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
272 markdown,
273 }));
274 }
275 self.raw_output = Some(raw_output);
276 }
277 }
278
279 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
280 self.content.iter().filter_map(|content| match content {
281 ToolCallContent::Diff(diff) => Some(diff),
282 ToolCallContent::ContentBlock(_) => None,
283 ToolCallContent::Terminal(_) => None,
284 })
285 }
286
287 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
288 self.content.iter().filter_map(|content| match content {
289 ToolCallContent::Terminal(terminal) => Some(terminal),
290 ToolCallContent::ContentBlock(_) => None,
291 ToolCallContent::Diff(_) => None,
292 })
293 }
294
295 fn to_markdown(&self, cx: &App) -> String {
296 let mut markdown = format!(
297 "**Tool Call: {}**\nStatus: {}\n\n",
298 self.label.read(cx).source(),
299 self.status
300 );
301 for content in &self.content {
302 markdown.push_str(content.to_markdown(cx).as_str());
303 markdown.push_str("\n\n");
304 }
305 markdown
306 }
307
308 async fn resolve_location(
309 location: acp::ToolCallLocation,
310 project: WeakEntity<Project>,
311 cx: &mut AsyncApp,
312 ) -> Option<AgentLocation> {
313 let buffer = project
314 .update(cx, |project, cx| {
315 project
316 .project_path_for_absolute_path(&location.path, cx)
317 .map(|path| project.open_buffer(path, cx))
318 })
319 .ok()??;
320 let buffer = buffer.await.log_err()?;
321 let position = buffer
322 .update(cx, |buffer, _| {
323 if let Some(row) = location.line {
324 let snapshot = buffer.snapshot();
325 let column = snapshot.indent_size_for_line(row).len;
326 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
327 snapshot.anchor_before(point)
328 } else {
329 Anchor::MIN
330 }
331 })
332 .ok()?;
333
334 Some(AgentLocation {
335 buffer: buffer.downgrade(),
336 position,
337 })
338 }
339
340 fn resolve_locations(
341 &self,
342 project: Entity<Project>,
343 cx: &mut App,
344 ) -> Task<Vec<Option<AgentLocation>>> {
345 let locations = self.locations.clone();
346 project.update(cx, |_, cx| {
347 cx.spawn(async move |project, cx| {
348 let mut new_locations = Vec::new();
349 for location in locations {
350 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
351 }
352 new_locations
353 })
354 })
355 }
356}
357
358#[derive(Debug)]
359pub enum ToolCallStatus {
360 /// The tool call hasn't started running yet, but we start showing it to
361 /// the user.
362 Pending,
363 /// The tool call is waiting for confirmation from the user.
364 WaitingForConfirmation {
365 options: Vec<acp::PermissionOption>,
366 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
367 },
368 /// The tool call is currently running.
369 InProgress,
370 /// The tool call completed successfully.
371 Completed,
372 /// The tool call failed.
373 Failed,
374 /// The user rejected the tool call.
375 Rejected,
376 /// The user canceled generation so the tool call was canceled.
377 Canceled,
378}
379
380impl From<acp::ToolCallStatus> for ToolCallStatus {
381 fn from(status: acp::ToolCallStatus) -> Self {
382 match status {
383 acp::ToolCallStatus::Pending => Self::Pending,
384 acp::ToolCallStatus::InProgress => Self::InProgress,
385 acp::ToolCallStatus::Completed => Self::Completed,
386 acp::ToolCallStatus::Failed => Self::Failed,
387 }
388 }
389}
390
391impl Display for ToolCallStatus {
392 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393 write!(
394 f,
395 "{}",
396 match self {
397 ToolCallStatus::Pending => "Pending",
398 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
399 ToolCallStatus::InProgress => "In Progress",
400 ToolCallStatus::Completed => "Completed",
401 ToolCallStatus::Failed => "Failed",
402 ToolCallStatus::Rejected => "Rejected",
403 ToolCallStatus::Canceled => "Canceled",
404 }
405 )
406 }
407}
408
409#[derive(Debug, PartialEq, Clone)]
410pub enum ContentBlock {
411 Empty,
412 Markdown { markdown: Entity<Markdown> },
413 ResourceLink { resource_link: acp::ResourceLink },
414}
415
416impl ContentBlock {
417 pub fn new(
418 block: acp::ContentBlock,
419 language_registry: &Arc<LanguageRegistry>,
420 cx: &mut App,
421 ) -> Self {
422 let mut this = Self::Empty;
423 this.append(block, language_registry, cx);
424 this
425 }
426
427 pub fn new_combined(
428 blocks: impl IntoIterator<Item = acp::ContentBlock>,
429 language_registry: Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) -> Self {
432 let mut this = Self::Empty;
433 for block in blocks {
434 this.append(block, &language_registry, cx);
435 }
436 this
437 }
438
439 pub fn append(
440 &mut self,
441 block: acp::ContentBlock,
442 language_registry: &Arc<LanguageRegistry>,
443 cx: &mut App,
444 ) {
445 if matches!(self, ContentBlock::Empty)
446 && let acp::ContentBlock::ResourceLink(resource_link) = block
447 {
448 *self = ContentBlock::ResourceLink { resource_link };
449 return;
450 }
451
452 let new_content = self.block_string_contents(block);
453
454 match self {
455 ContentBlock::Empty => {
456 *self = Self::create_markdown_block(new_content, language_registry, cx);
457 }
458 ContentBlock::Markdown { markdown } => {
459 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
460 }
461 ContentBlock::ResourceLink { resource_link } => {
462 let existing_content = Self::resource_link_md(&resource_link.uri);
463 let combined = format!("{}\n{}", existing_content, new_content);
464
465 *self = Self::create_markdown_block(combined, language_registry, cx);
466 }
467 }
468 }
469
470 fn create_markdown_block(
471 content: String,
472 language_registry: &Arc<LanguageRegistry>,
473 cx: &mut App,
474 ) -> ContentBlock {
475 ContentBlock::Markdown {
476 markdown: cx
477 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
478 }
479 }
480
481 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
482 match block {
483 acp::ContentBlock::Text(text_content) => text_content.text,
484 acp::ContentBlock::ResourceLink(resource_link) => {
485 Self::resource_link_md(&resource_link.uri)
486 }
487 acp::ContentBlock::Resource(acp::EmbeddedResource {
488 resource:
489 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
490 uri,
491 ..
492 }),
493 ..
494 }) => Self::resource_link_md(&uri),
495 acp::ContentBlock::Image(image) => Self::image_md(&image),
496 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
497 }
498 }
499
500 fn resource_link_md(uri: &str) -> String {
501 if let Some(uri) = MentionUri::parse(uri).log_err() {
502 uri.as_link().to_string()
503 } else {
504 uri.to_string()
505 }
506 }
507
508 fn image_md(_image: &acp::ImageContent) -> String {
509 "`Image`".into()
510 }
511
512 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
513 match self {
514 ContentBlock::Empty => "",
515 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
516 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
517 }
518 }
519
520 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
521 match self {
522 ContentBlock::Empty => None,
523 ContentBlock::Markdown { markdown } => Some(markdown),
524 ContentBlock::ResourceLink { .. } => None,
525 }
526 }
527
528 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
529 match self {
530 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
531 _ => None,
532 }
533 }
534}
535
536#[derive(Debug)]
537pub enum ToolCallContent {
538 ContentBlock(ContentBlock),
539 Diff(Entity<Diff>),
540 Terminal(Entity<Terminal>),
541}
542
543impl ToolCallContent {
544 pub fn from_acp(
545 content: acp::ToolCallContent,
546 language_registry: Arc<LanguageRegistry>,
547 cx: &mut App,
548 ) -> Self {
549 match content {
550 acp::ToolCallContent::Content { content } => {
551 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
552 }
553 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
554 Diff::finalized(
555 diff.path,
556 diff.old_text,
557 diff.new_text,
558 language_registry,
559 cx,
560 )
561 })),
562 }
563 }
564
565 pub fn update_from_acp(
566 &mut self,
567 new: acp::ToolCallContent,
568 language_registry: Arc<LanguageRegistry>,
569 cx: &mut App,
570 ) {
571 let needs_update = match (&self, &new) {
572 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
573 old_diff.read(cx).needs_update(
574 new_diff.old_text.as_deref().unwrap_or(""),
575 &new_diff.new_text,
576 cx,
577 )
578 }
579 _ => true,
580 };
581
582 if needs_update {
583 *self = Self::from_acp(new, language_registry, cx);
584 }
585 }
586
587 pub fn to_markdown(&self, cx: &App) -> String {
588 match self {
589 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
590 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
591 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
592 }
593 }
594}
595
596#[derive(Debug, PartialEq)]
597pub enum ToolCallUpdate {
598 UpdateFields(acp::ToolCallUpdate),
599 UpdateDiff(ToolCallUpdateDiff),
600 UpdateTerminal(ToolCallUpdateTerminal),
601}
602
603impl ToolCallUpdate {
604 fn id(&self) -> &acp::ToolCallId {
605 match self {
606 Self::UpdateFields(update) => &update.id,
607 Self::UpdateDiff(diff) => &diff.id,
608 Self::UpdateTerminal(terminal) => &terminal.id,
609 }
610 }
611}
612
613impl From<acp::ToolCallUpdate> for ToolCallUpdate {
614 fn from(update: acp::ToolCallUpdate) -> Self {
615 Self::UpdateFields(update)
616 }
617}
618
619impl From<ToolCallUpdateDiff> for ToolCallUpdate {
620 fn from(diff: ToolCallUpdateDiff) -> Self {
621 Self::UpdateDiff(diff)
622 }
623}
624
625#[derive(Debug, PartialEq)]
626pub struct ToolCallUpdateDiff {
627 pub id: acp::ToolCallId,
628 pub diff: Entity<Diff>,
629}
630
631impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
632 fn from(terminal: ToolCallUpdateTerminal) -> Self {
633 Self::UpdateTerminal(terminal)
634 }
635}
636
637#[derive(Debug, PartialEq)]
638pub struct ToolCallUpdateTerminal {
639 pub id: acp::ToolCallId,
640 pub terminal: Entity<Terminal>,
641}
642
643#[derive(Debug, Default)]
644pub struct Plan {
645 pub entries: Vec<PlanEntry>,
646}
647
648#[derive(Debug)]
649pub struct PlanStats<'a> {
650 pub in_progress_entry: Option<&'a PlanEntry>,
651 pub pending: u32,
652 pub completed: u32,
653}
654
655impl Plan {
656 pub fn is_empty(&self) -> bool {
657 self.entries.is_empty()
658 }
659
660 pub fn stats(&self) -> PlanStats<'_> {
661 let mut stats = PlanStats {
662 in_progress_entry: None,
663 pending: 0,
664 completed: 0,
665 };
666
667 for entry in &self.entries {
668 match &entry.status {
669 acp::PlanEntryStatus::Pending => {
670 stats.pending += 1;
671 }
672 acp::PlanEntryStatus::InProgress => {
673 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
674 }
675 acp::PlanEntryStatus::Completed => {
676 stats.completed += 1;
677 }
678 }
679 }
680
681 stats
682 }
683}
684
685#[derive(Debug)]
686pub struct PlanEntry {
687 pub content: Entity<Markdown>,
688 pub priority: acp::PlanEntryPriority,
689 pub status: acp::PlanEntryStatus,
690}
691
692impl PlanEntry {
693 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
694 Self {
695 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
696 priority: entry.priority,
697 status: entry.status,
698 }
699 }
700}
701
702#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
703pub struct TokenUsage {
704 pub max_tokens: u64,
705 pub used_tokens: u64,
706}
707
708impl TokenUsage {
709 pub fn ratio(&self) -> TokenUsageRatio {
710 #[cfg(debug_assertions)]
711 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
712 .unwrap_or("0.8".to_string())
713 .parse()
714 .unwrap();
715 #[cfg(not(debug_assertions))]
716 let warning_threshold: f32 = 0.8;
717
718 // When the maximum is unknown because there is no selected model,
719 // avoid showing the token limit warning.
720 if self.max_tokens == 0 {
721 TokenUsageRatio::Normal
722 } else if self.used_tokens >= self.max_tokens {
723 TokenUsageRatio::Exceeded
724 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
725 TokenUsageRatio::Warning
726 } else {
727 TokenUsageRatio::Normal
728 }
729 }
730}
731
/// Classification of token usage produced by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    Normal,
    /// Usage reached the warning threshold but not the maximum.
    Warning,
    /// Usage reached or passed the maximum.
    Exceeded,
}
738
739#[derive(Debug, Clone)]
740pub struct RetryStatus {
741 pub last_error: SharedString,
742 pub attempt: usize,
743 pub max_attempts: usize,
744 pub started_at: Instant,
745 pub duration: Duration,
746}
747
748pub struct AcpThread {
749 title: SharedString,
750 entries: Vec<AgentThreadEntry>,
751 plan: Plan,
752 project: Entity<Project>,
753 action_log: Entity<ActionLog>,
754 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
755 send_task: Option<Task<()>>,
756 connection: Rc<dyn AgentConnection>,
757 session_id: acp::SessionId,
758 token_usage: Option<TokenUsage>,
759}
760
761#[derive(Debug)]
762pub enum AcpThreadEvent {
763 NewEntry,
764 TitleUpdated,
765 TokenUsageUpdated,
766 EntryUpdated(usize),
767 EntriesRemoved(Range<usize>),
768 ToolAuthorizationRequired,
769 Retry(RetryStatus),
770 Stopped,
771 Error,
772 LoadError(LoadError),
773}
774
775impl EventEmitter<AcpThreadEvent> for AcpThread {}
776
/// Coarse state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call awaits the user's confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing awaits confirmation.
    Generating,
}
783
784#[derive(Debug, Clone)]
785pub enum LoadError {
786 NotInstalled {
787 error_message: SharedString,
788 install_message: SharedString,
789 install_command: String,
790 },
791 Unsupported {
792 error_message: SharedString,
793 upgrade_message: SharedString,
794 upgrade_command: String,
795 },
796 Exited {
797 status: ExitStatus,
798 },
799 Other(SharedString),
800}
801
802impl Display for LoadError {
803 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
804 match self {
805 LoadError::NotInstalled { error_message, .. }
806 | LoadError::Unsupported { error_message, .. } => {
807 write!(f, "{error_message}")
808 }
809 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
810 LoadError::Other(msg) => write!(f, "{}", msg),
811 }
812 }
813}
814
815impl Error for LoadError {}
816
817impl AcpThread {
818 pub fn new(
819 title: impl Into<SharedString>,
820 connection: Rc<dyn AgentConnection>,
821 project: Entity<Project>,
822 action_log: Entity<ActionLog>,
823 session_id: acp::SessionId,
824 ) -> Self {
825 Self {
826 action_log,
827 shared_buffers: Default::default(),
828 entries: Default::default(),
829 plan: Default::default(),
830 title: title.into(),
831 project,
832 send_task: None,
833 connection,
834 session_id,
835 token_usage: None,
836 }
837 }
838
839 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
840 &self.connection
841 }
842
843 pub fn action_log(&self) -> &Entity<ActionLog> {
844 &self.action_log
845 }
846
847 pub fn project(&self) -> &Entity<Project> {
848 &self.project
849 }
850
851 pub fn title(&self) -> SharedString {
852 self.title.clone()
853 }
854
855 pub fn entries(&self) -> &[AgentThreadEntry] {
856 &self.entries
857 }
858
859 pub fn session_id(&self) -> &acp::SessionId {
860 &self.session_id
861 }
862
863 pub fn status(&self) -> ThreadStatus {
864 if self.send_task.is_some() {
865 if self.waiting_for_tool_confirmation() {
866 ThreadStatus::WaitingForToolConfirmation
867 } else {
868 ThreadStatus::Generating
869 }
870 } else {
871 ThreadStatus::Idle
872 }
873 }
874
875 pub fn token_usage(&self) -> Option<&TokenUsage> {
876 self.token_usage.as_ref()
877 }
878
879 pub fn has_pending_edit_tool_calls(&self) -> bool {
880 for entry in self.entries.iter().rev() {
881 match entry {
882 AgentThreadEntry::UserMessage(_) => return false,
883 AgentThreadEntry::ToolCall(
884 call @ ToolCall {
885 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
886 ..
887 },
888 ) if call.diffs().next().is_some() => {
889 return true;
890 }
891 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
892 }
893 }
894
895 false
896 }
897
898 pub fn used_tools_since_last_user_message(&self) -> bool {
899 for entry in self.entries.iter().rev() {
900 match entry {
901 AgentThreadEntry::UserMessage(..) => return false,
902 AgentThreadEntry::AssistantMessage(..) => continue,
903 AgentThreadEntry::ToolCall(..) => return true,
904 }
905 }
906
907 false
908 }
909
910 pub fn handle_session_update(
911 &mut self,
912 update: acp::SessionUpdate,
913 cx: &mut Context<Self>,
914 ) -> Result<(), acp::Error> {
915 match update {
916 acp::SessionUpdate::UserMessageChunk { content } => {
917 self.push_user_content_block(None, content, cx);
918 }
919 acp::SessionUpdate::AgentMessageChunk { content } => {
920 self.push_assistant_content_block(content, false, cx);
921 }
922 acp::SessionUpdate::AgentThoughtChunk { content } => {
923 self.push_assistant_content_block(content, true, cx);
924 }
925 acp::SessionUpdate::ToolCall(tool_call) => {
926 self.upsert_tool_call(tool_call, cx)?;
927 }
928 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
929 self.update_tool_call(tool_call_update, cx)?;
930 }
931 acp::SessionUpdate::Plan(plan) => {
932 self.update_plan(plan, cx);
933 }
934 }
935 Ok(())
936 }
937
938 pub fn push_user_content_block(
939 &mut self,
940 message_id: Option<UserMessageId>,
941 chunk: acp::ContentBlock,
942 cx: &mut Context<Self>,
943 ) {
944 let language_registry = self.project.read(cx).languages().clone();
945 let entries_len = self.entries.len();
946
947 if let Some(last_entry) = self.entries.last_mut()
948 && let AgentThreadEntry::UserMessage(UserMessage {
949 id,
950 content,
951 chunks,
952 ..
953 }) = last_entry
954 {
955 *id = message_id.or(id.take());
956 content.append(chunk.clone(), &language_registry, cx);
957 chunks.push(chunk);
958 let idx = entries_len - 1;
959 cx.emit(AcpThreadEvent::EntryUpdated(idx));
960 } else {
961 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
962 self.push_entry(
963 AgentThreadEntry::UserMessage(UserMessage {
964 id: message_id,
965 content,
966 chunks: vec![chunk],
967 checkpoint: None,
968 }),
969 cx,
970 );
971 }
972 }
973
974 pub fn push_assistant_content_block(
975 &mut self,
976 chunk: acp::ContentBlock,
977 is_thought: bool,
978 cx: &mut Context<Self>,
979 ) {
980 let language_registry = self.project.read(cx).languages().clone();
981 let entries_len = self.entries.len();
982 if let Some(last_entry) = self.entries.last_mut()
983 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
984 {
985 let idx = entries_len - 1;
986 cx.emit(AcpThreadEvent::EntryUpdated(idx));
987 match (chunks.last_mut(), is_thought) {
988 (Some(AssistantMessageChunk::Message { block }), false)
989 | (Some(AssistantMessageChunk::Thought { block }), true) => {
990 block.append(chunk, &language_registry, cx)
991 }
992 _ => {
993 let block = ContentBlock::new(chunk, &language_registry, cx);
994 if is_thought {
995 chunks.push(AssistantMessageChunk::Thought { block })
996 } else {
997 chunks.push(AssistantMessageChunk::Message { block })
998 }
999 }
1000 }
1001 } else {
1002 let block = ContentBlock::new(chunk, &language_registry, cx);
1003 let chunk = if is_thought {
1004 AssistantMessageChunk::Thought { block }
1005 } else {
1006 AssistantMessageChunk::Message { block }
1007 };
1008
1009 self.push_entry(
1010 AgentThreadEntry::AssistantMessage(AssistantMessage {
1011 chunks: vec![chunk],
1012 }),
1013 cx,
1014 );
1015 }
1016 }
1017
1018 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1019 self.entries.push(entry);
1020 cx.emit(AcpThreadEvent::NewEntry);
1021 }
1022
1023 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1024 self.connection.set_title(&self.session_id, cx).is_some()
1025 }
1026
1027 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1028 if title != self.title {
1029 self.title = title.clone();
1030 cx.emit(AcpThreadEvent::TitleUpdated);
1031 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1032 return set_title.run(title, cx);
1033 }
1034 }
1035 Task::ready(Ok(()))
1036 }
1037
1038 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1039 self.token_usage = usage;
1040 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1041 }
1042
1043 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1044 cx.emit(AcpThreadEvent::Retry(status));
1045 }
1046
1047 pub fn update_tool_call(
1048 &mut self,
1049 update: impl Into<ToolCallUpdate>,
1050 cx: &mut Context<Self>,
1051 ) -> Result<()> {
1052 let update = update.into();
1053 let languages = self.project.read(cx).languages().clone();
1054
1055 let (ix, current_call) = self
1056 .tool_call_mut(update.id())
1057 .context("Tool call not found")?;
1058 match update {
1059 ToolCallUpdate::UpdateFields(update) => {
1060 let location_updated = update.fields.locations.is_some();
1061 current_call.update_fields(update.fields, languages, cx);
1062 if location_updated {
1063 self.resolve_locations(update.id, cx);
1064 }
1065 }
1066 ToolCallUpdate::UpdateDiff(update) => {
1067 current_call.content.clear();
1068 current_call
1069 .content
1070 .push(ToolCallContent::Diff(update.diff));
1071 }
1072 ToolCallUpdate::UpdateTerminal(update) => {
1073 current_call.content.clear();
1074 current_call
1075 .content
1076 .push(ToolCallContent::Terminal(update.terminal));
1077 }
1078 }
1079
1080 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1081
1082 Ok(())
1083 }
1084
1085 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1086 pub fn upsert_tool_call(
1087 &mut self,
1088 tool_call: acp::ToolCall,
1089 cx: &mut Context<Self>,
1090 ) -> Result<(), acp::Error> {
1091 let status = tool_call.status.into();
1092 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1093 }
1094
1095 /// Fails if id does not match an existing entry.
1096 pub fn upsert_tool_call_inner(
1097 &mut self,
1098 tool_call_update: acp::ToolCallUpdate,
1099 status: ToolCallStatus,
1100 cx: &mut Context<Self>,
1101 ) -> Result<(), acp::Error> {
1102 let language_registry = self.project.read(cx).languages().clone();
1103 let id = tool_call_update.id.clone();
1104
1105 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1106 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1107 current_call.status = status;
1108
1109 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1110 } else {
1111 let call =
1112 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1113 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1114 };
1115
1116 self.resolve_locations(id, cx);
1117 Ok(())
1118 }
1119
1120 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1121 // The tool call we are looking for is typically the last one, or very close to the end.
1122 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1123 self.entries
1124 .iter_mut()
1125 .enumerate()
1126 .rev()
1127 .find_map(|(index, tool_call)| {
1128 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1129 && &tool_call.id == id
1130 {
1131 Some((index, tool_call))
1132 } else {
1133 None
1134 }
1135 })
1136 }
1137
1138 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1139 self.entries
1140 .iter()
1141 .enumerate()
1142 .rev()
1143 .find_map(|(index, tool_call)| {
1144 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1145 && &tool_call.id == id
1146 {
1147 Some((index, tool_call))
1148 } else {
1149 None
1150 }
1151 })
1152 }
1153
1154 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1155 let project = self.project.clone();
1156 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1157 return;
1158 };
1159 let task = tool_call.resolve_locations(project, cx);
1160 cx.spawn(async move |this, cx| {
1161 let resolved_locations = task.await;
1162 this.update(cx, |this, cx| {
1163 let project = this.project.clone();
1164 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1165 return;
1166 };
1167 if let Some(Some(location)) = resolved_locations.last() {
1168 project.update(cx, |project, cx| {
1169 if let Some(agent_location) = project.agent_location() {
1170 let should_ignore = agent_location.buffer == location.buffer
1171 && location
1172 .buffer
1173 .update(cx, |buffer, _| {
1174 let snapshot = buffer.snapshot();
1175 let old_position =
1176 agent_location.position.to_point(&snapshot);
1177 let new_position = location.position.to_point(&snapshot);
1178 // ignore this so that when we get updates from the edit tool
1179 // the position doesn't reset to the startof line
1180 old_position.row == new_position.row
1181 && old_position.column > new_position.column
1182 })
1183 .ok()
1184 .unwrap_or_default();
1185 if !should_ignore {
1186 project.set_agent_location(Some(location.clone()), cx);
1187 }
1188 }
1189 });
1190 }
1191 if tool_call.resolved_locations != resolved_locations {
1192 tool_call.resolved_locations = resolved_locations;
1193 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1194 }
1195 })
1196 })
1197 .detach();
1198 }
1199
1200 pub fn request_tool_call_authorization(
1201 &mut self,
1202 tool_call: acp::ToolCallUpdate,
1203 options: Vec<acp::PermissionOption>,
1204 cx: &mut Context<Self>,
1205 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1206 let (tx, rx) = oneshot::channel();
1207
1208 let status = ToolCallStatus::WaitingForConfirmation {
1209 options,
1210 respond_tx: tx,
1211 };
1212
1213 self.upsert_tool_call_inner(tool_call, status, cx)?;
1214 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1215 Ok(rx)
1216 }
1217
1218 pub fn authorize_tool_call(
1219 &mut self,
1220 id: acp::ToolCallId,
1221 option_id: acp::PermissionOptionId,
1222 option_kind: acp::PermissionOptionKind,
1223 cx: &mut Context<Self>,
1224 ) {
1225 let Some((ix, call)) = self.tool_call_mut(&id) else {
1226 return;
1227 };
1228
1229 let new_status = match option_kind {
1230 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1231 ToolCallStatus::Rejected
1232 }
1233 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1234 ToolCallStatus::InProgress
1235 }
1236 };
1237
1238 let curr_status = mem::replace(&mut call.status, new_status);
1239
1240 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1241 respond_tx.send(option_id).log_err();
1242 } else if cfg!(debug_assertions) {
1243 panic!("tried to authorize an already authorized tool call");
1244 }
1245
1246 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1247 }
1248
1249 /// Returns true if the last turn is awaiting tool authorization
1250 pub fn waiting_for_tool_confirmation(&self) -> bool {
1251 for entry in self.entries.iter().rev() {
1252 match &entry {
1253 AgentThreadEntry::ToolCall(call) => match call.status {
1254 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1255 ToolCallStatus::Pending
1256 | ToolCallStatus::InProgress
1257 | ToolCallStatus::Completed
1258 | ToolCallStatus::Failed
1259 | ToolCallStatus::Rejected
1260 | ToolCallStatus::Canceled => continue,
1261 },
1262 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1263 // Reached the beginning of the turn
1264 return false;
1265 }
1266 }
1267 }
1268 false
1269 }
1270
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1274
1275 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1276 let new_entries_len = request.entries.len();
1277 let mut new_entries = request.entries.into_iter();
1278
1279 // Reuse existing markdown to prevent flickering
1280 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1281 let PlanEntry {
1282 content,
1283 priority,
1284 status,
1285 } = old;
1286 content.update(cx, |old, cx| {
1287 old.replace(new.content, cx);
1288 });
1289 *priority = new.priority;
1290 *status = new.status;
1291 }
1292 for new in new_entries {
1293 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1294 }
1295 self.plan.entries.truncate(new_entries_len);
1296
1297 cx.notify();
1298 }
1299
1300 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1301 self.plan
1302 .entries
1303 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1304 cx.notify();
1305 }
1306
1307 #[cfg(any(test, feature = "test-support"))]
1308 pub fn send_raw(
1309 &mut self,
1310 message: &str,
1311 cx: &mut Context<Self>,
1312 ) -> BoxFuture<'static, Result<()>> {
1313 self.send(
1314 vec![acp::ContentBlock::Text(acp::TextContent {
1315 text: message.to_string(),
1316 annotations: None,
1317 })],
1318 cx,
1319 )
1320 }
1321
1322 pub fn send(
1323 &mut self,
1324 message: Vec<acp::ContentBlock>,
1325 cx: &mut Context<Self>,
1326 ) -> BoxFuture<'static, Result<()>> {
1327 let block = ContentBlock::new_combined(
1328 message.clone(),
1329 self.project.read(cx).languages().clone(),
1330 cx,
1331 );
1332 let request = acp::PromptRequest {
1333 prompt: message.clone(),
1334 session_id: self.session_id.clone(),
1335 };
1336 let git_store = self.project.read(cx).git_store().clone();
1337
1338 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1339 Some(UserMessageId::new())
1340 } else {
1341 None
1342 };
1343
1344 self.run_turn(cx, async move |this, cx| {
1345 this.update(cx, |this, cx| {
1346 this.push_entry(
1347 AgentThreadEntry::UserMessage(UserMessage {
1348 id: message_id.clone(),
1349 content: block,
1350 chunks: message,
1351 checkpoint: None,
1352 }),
1353 cx,
1354 );
1355 })
1356 .ok();
1357
1358 let old_checkpoint = git_store
1359 .update(cx, |git, cx| git.checkpoint(cx))?
1360 .await
1361 .context("failed to get old checkpoint")
1362 .log_err();
1363 this.update(cx, |this, cx| {
1364 if let Some((_ix, message)) = this.last_user_message() {
1365 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1366 git_checkpoint,
1367 show: false,
1368 });
1369 }
1370 this.connection.prompt(message_id, request, cx)
1371 })?
1372 .await
1373 })
1374 }
1375
1376 pub fn can_resume(&self, cx: &App) -> bool {
1377 self.connection.resume(&self.session_id, cx).is_some()
1378 }
1379
1380 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1381 self.run_turn(cx, async move |this, cx| {
1382 this.update(cx, |this, cx| {
1383 this.connection
1384 .resume(&this.session_id, cx)
1385 .map(|resume| resume.run(cx))
1386 })?
1387 .context("resuming a session is not supported")?
1388 .await
1389 })
1390 }
1391
1392 fn run_turn(
1393 &mut self,
1394 cx: &mut Context<Self>,
1395 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1396 ) -> BoxFuture<'static, Result<()>> {
1397 self.clear_completed_plan_entries(cx);
1398
1399 let (tx, rx) = oneshot::channel();
1400 let cancel_task = self.cancel(cx);
1401
1402 self.send_task = Some(cx.spawn(async move |this, cx| {
1403 cancel_task.await;
1404 tx.send(f(this, cx).await).ok();
1405 }));
1406
1407 cx.spawn(async move |this, cx| {
1408 let response = rx.await;
1409
1410 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1411 .await?;
1412
1413 this.update(cx, |this, cx| {
1414 this.project
1415 .update(cx, |project, cx| project.set_agent_location(None, cx));
1416 match response {
1417 Ok(Err(e)) => {
1418 this.send_task.take();
1419 cx.emit(AcpThreadEvent::Error);
1420 Err(e)
1421 }
1422 result => {
1423 let canceled = matches!(
1424 result,
1425 Ok(Ok(acp::PromptResponse {
1426 stop_reason: acp::StopReason::Cancelled
1427 }))
1428 );
1429
1430 // We only take the task if the current prompt wasn't canceled.
1431 //
1432 // This prompt may have been canceled because another one was sent
1433 // while it was still generating. In these cases, dropping `send_task`
1434 // would cause the next generation to be canceled.
1435 if !canceled {
1436 this.send_task.take();
1437 }
1438
1439 // Truncate entries if the last prompt was refused.
1440 if let Ok(Ok(acp::PromptResponse {
1441 stop_reason: acp::StopReason::Refusal,
1442 })) = result
1443 && let Some((ix, _)) = this.last_user_message()
1444 {
1445 let range = ix..this.entries.len();
1446 this.entries.truncate(ix);
1447 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1448 }
1449
1450 cx.emit(AcpThreadEvent::Stopped);
1451 Ok(())
1452 }
1453 }
1454 })?
1455 })
1456 .boxed()
1457 }
1458
1459 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1460 let Some(send_task) = self.send_task.take() else {
1461 return Task::ready(());
1462 };
1463
1464 for entry in self.entries.iter_mut() {
1465 if let AgentThreadEntry::ToolCall(call) = entry {
1466 let cancel = matches!(
1467 call.status,
1468 ToolCallStatus::Pending
1469 | ToolCallStatus::WaitingForConfirmation { .. }
1470 | ToolCallStatus::InProgress
1471 );
1472
1473 if cancel {
1474 call.status = ToolCallStatus::Canceled;
1475 }
1476 }
1477 }
1478
1479 self.connection.cancel(&self.session_id, cx);
1480
1481 // Wait for the send task to complete
1482 cx.foreground_executor().spawn(send_task)
1483 }
1484
1485 /// Rewinds this thread to before the entry at `index`, removing it and all
1486 /// subsequent entries while reverting any changes made from that point.
1487 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1488 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1489 return Task::ready(Err(anyhow!("not supported")));
1490 };
1491 let Some(message) = self.user_message(&id) else {
1492 return Task::ready(Err(anyhow!("message not found")));
1493 };
1494
1495 let checkpoint = message
1496 .checkpoint
1497 .as_ref()
1498 .map(|c| c.git_checkpoint.clone());
1499
1500 let git_store = self.project.read(cx).git_store().clone();
1501 cx.spawn(async move |this, cx| {
1502 if let Some(checkpoint) = checkpoint {
1503 git_store
1504 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1505 .await?;
1506 }
1507
1508 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1509 this.update(cx, |this, cx| {
1510 if let Some((ix, _)) = this.user_message_mut(&id) {
1511 let range = ix..this.entries.len();
1512 this.entries.truncate(ix);
1513 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1514 }
1515 })
1516 })
1517 }
1518
1519 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1520 let git_store = self.project.read(cx).git_store().clone();
1521
1522 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1523 if let Some(checkpoint) = message.checkpoint.as_ref() {
1524 checkpoint.git_checkpoint.clone()
1525 } else {
1526 return Task::ready(Ok(()));
1527 }
1528 } else {
1529 return Task::ready(Ok(()));
1530 };
1531
1532 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1533 cx.spawn(async move |this, cx| {
1534 let new_checkpoint = new_checkpoint
1535 .await
1536 .context("failed to get new checkpoint")
1537 .log_err();
1538 if let Some(new_checkpoint) = new_checkpoint {
1539 let equal = git_store
1540 .update(cx, |git, cx| {
1541 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1542 })?
1543 .await
1544 .unwrap_or(true);
1545 this.update(cx, |this, cx| {
1546 let (ix, message) = this.last_user_message().context("no user message")?;
1547 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1548 checkpoint.show = !equal;
1549 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1550 anyhow::Ok(())
1551 })??;
1552 }
1553
1554 Ok(())
1555 })
1556 }
1557
1558 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1559 self.entries
1560 .iter_mut()
1561 .enumerate()
1562 .rev()
1563 .find_map(|(ix, entry)| {
1564 if let AgentThreadEntry::UserMessage(message) = entry {
1565 Some((ix, message))
1566 } else {
1567 None
1568 }
1569 })
1570 }
1571
1572 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1573 self.entries.iter().find_map(|entry| {
1574 if let AgentThreadEntry::UserMessage(message) = entry {
1575 if message.id.as_ref() == Some(id) {
1576 Some(message)
1577 } else {
1578 None
1579 }
1580 } else {
1581 None
1582 }
1583 })
1584 }
1585
1586 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1587 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1588 if let AgentThreadEntry::UserMessage(message) = entry {
1589 if message.id.as_ref() == Some(id) {
1590 Some((ix, message))
1591 } else {
1592 None
1593 }
1594 } else {
1595 None
1596 }
1597 })
1598 }
1599
1600 pub fn read_text_file(
1601 &self,
1602 path: PathBuf,
1603 line: Option<u32>,
1604 limit: Option<u32>,
1605 reuse_shared_snapshot: bool,
1606 cx: &mut Context<Self>,
1607 ) -> Task<Result<String>> {
1608 let project = self.project.clone();
1609 let action_log = self.action_log.clone();
1610 cx.spawn(async move |this, cx| {
1611 let load = project.update(cx, |project, cx| {
1612 let path = project
1613 .project_path_for_absolute_path(&path, cx)
1614 .context("invalid path")?;
1615 anyhow::Ok(project.open_buffer(path, cx))
1616 });
1617 let buffer = load??.await?;
1618
1619 let snapshot = if reuse_shared_snapshot {
1620 this.read_with(cx, |this, _| {
1621 this.shared_buffers.get(&buffer.clone()).cloned()
1622 })
1623 .log_err()
1624 .flatten()
1625 } else {
1626 None
1627 };
1628
1629 let snapshot = if let Some(snapshot) = snapshot {
1630 snapshot
1631 } else {
1632 action_log.update(cx, |action_log, cx| {
1633 action_log.buffer_read(buffer.clone(), cx);
1634 })?;
1635 project.update(cx, |project, cx| {
1636 let position = buffer
1637 .read(cx)
1638 .snapshot()
1639 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1640 project.set_agent_location(
1641 Some(AgentLocation {
1642 buffer: buffer.downgrade(),
1643 position,
1644 }),
1645 cx,
1646 );
1647 })?;
1648
1649 buffer.update(cx, |buffer, _| buffer.snapshot())?
1650 };
1651
1652 this.update(cx, |this, _| {
1653 let text = snapshot.text();
1654 this.shared_buffers.insert(buffer.clone(), snapshot);
1655 if line.is_none() && limit.is_none() {
1656 return Ok(text);
1657 }
1658 let limit = limit.unwrap_or(u32::MAX) as usize;
1659 let Some(line) = line else {
1660 return Ok(text.lines().take(limit).collect::<String>());
1661 };
1662
1663 let count = text.lines().count();
1664 if count < line as usize {
1665 anyhow::bail!("There are only {} lines", count);
1666 }
1667 Ok(text
1668 .lines()
1669 .skip(line as usize + 1)
1670 .take(limit)
1671 .collect::<String>())
1672 })?
1673 })
1674 }
1675
1676 pub fn write_text_file(
1677 &self,
1678 path: PathBuf,
1679 content: String,
1680 cx: &mut Context<Self>,
1681 ) -> Task<Result<()>> {
1682 let project = self.project.clone();
1683 let action_log = self.action_log.clone();
1684 cx.spawn(async move |this, cx| {
1685 let load = project.update(cx, |project, cx| {
1686 let path = project
1687 .project_path_for_absolute_path(&path, cx)
1688 .context("invalid path")?;
1689 anyhow::Ok(project.open_buffer(path, cx))
1690 });
1691 let buffer = load??.await?;
1692 let snapshot = this.update(cx, |this, cx| {
1693 this.shared_buffers
1694 .get(&buffer)
1695 .cloned()
1696 .unwrap_or_else(|| buffer.read(cx).snapshot())
1697 })?;
1698 let edits = cx
1699 .background_executor()
1700 .spawn(async move {
1701 let old_text = snapshot.text();
1702 text_diff(old_text.as_str(), &content)
1703 .into_iter()
1704 .map(|(range, replacement)| {
1705 (
1706 snapshot.anchor_after(range.start)
1707 ..snapshot.anchor_before(range.end),
1708 replacement,
1709 )
1710 })
1711 .collect::<Vec<_>>()
1712 })
1713 .await;
1714
1715 project.update(cx, |project, cx| {
1716 project.set_agent_location(
1717 Some(AgentLocation {
1718 buffer: buffer.downgrade(),
1719 position: edits
1720 .last()
1721 .map(|(range, _)| range.end)
1722 .unwrap_or(Anchor::MIN),
1723 }),
1724 cx,
1725 );
1726 })?;
1727
1728 let format_on_save = cx.update(|cx| {
1729 action_log.update(cx, |action_log, cx| {
1730 action_log.buffer_read(buffer.clone(), cx);
1731 });
1732
1733 let format_on_save = buffer.update(cx, |buffer, cx| {
1734 buffer.edit(edits, None, cx);
1735
1736 let settings = language::language_settings::language_settings(
1737 buffer.language().map(|l| l.name()),
1738 buffer.file(),
1739 cx,
1740 );
1741
1742 settings.format_on_save != FormatOnSave::Off
1743 });
1744 action_log.update(cx, |action_log, cx| {
1745 action_log.buffer_edited(buffer.clone(), cx);
1746 });
1747 format_on_save
1748 })?;
1749
1750 if format_on_save {
1751 let format_task = project.update(cx, |project, cx| {
1752 project.format(
1753 HashSet::from_iter([buffer.clone()]),
1754 LspFormatTarget::Buffers,
1755 false,
1756 FormatTrigger::Save,
1757 cx,
1758 )
1759 })?;
1760 format_task.await.log_err();
1761
1762 action_log.update(cx, |action_log, cx| {
1763 action_log.buffer_edited(buffer.clone(), cx);
1764 })?;
1765 }
1766
1767 project
1768 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1769 .await
1770 })
1771 }
1772
1773 pub fn to_markdown(&self, cx: &App) -> String {
1774 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1775 }
1776
    /// Emits a `LoadError` event carrying `error`.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
1780}
1781
1782fn markdown_for_raw_output(
1783 raw_output: &serde_json::Value,
1784 language_registry: &Arc<LanguageRegistry>,
1785 cx: &mut App,
1786) -> Option<Entity<Markdown>> {
1787 match raw_output {
1788 serde_json::Value::Null => None,
1789 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1790 Markdown::new(
1791 value.to_string().into(),
1792 Some(language_registry.clone()),
1793 None,
1794 cx,
1795 )
1796 })),
1797 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1798 Markdown::new(
1799 value.to_string().into(),
1800 Some(language_registry.clone()),
1801 None,
1802 cx,
1803 )
1804 })),
1805 serde_json::Value::String(value) => Some(cx.new(|cx| {
1806 Markdown::new(
1807 value.clone().into(),
1808 Some(language_registry.clone()),
1809 None,
1810 cx,
1811 )
1812 })),
1813 value => Some(cx.new(|cx| {
1814 Markdown::new(
1815 format!("```json\n{}\n```", value).into(),
1816 Some(language_registry.clone()),
1817 None,
1818 cx,
1819 )
1820 })),
1821 }
1822}
1823
1824#[cfg(test)]
1825mod tests {
1826 use super::*;
1827 use anyhow::anyhow;
1828 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1829 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1830 use indoc::indoc;
1831 use project::{FakeFs, Fs};
1832 use rand::Rng as _;
1833 use serde_json::json;
1834 use settings::SettingsStore;
1835 use smol::stream::StreamExt as _;
1836 use std::{
1837 any::Any,
1838 cell::RefCell,
1839 path::Path,
1840 rc::Rc,
1841 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1842 time::Duration,
1843 };
1844 use util::path;
1845
    /// Shared setup for the tests in this module: initializes logging
    /// (ignoring the error if it was already initialized), installs a test
    /// settings store as a global, and initializes project settings and the
    /// `language` crate.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1855
    /// Checks how `push_user_content_block` groups content into entries:
    /// consecutive user blocks are merged into one message (which adopts the
    /// id supplied with the later block), while a user block that follows an
    /// assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the text was appended and the id was adopted.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1947
    /// Checks that two consecutive `AgentThoughtChunk` updates are rendered
    /// as a single `<thinking>` block in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2010
    /// Checks that an agent write is merged with a user edit made after the
    /// agent read the file: the user inserts "zero\n" at the top between the
    /// agent's read and write, and both the buffer and the file on disk end
    /// up containing the user's line as well as the agent's additions.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the user's edit before the
        // agent's write is applied.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2089
    /// Checks that a tool call marked `Canceled` by `cancel` is still moved
    /// to `Completed` when a later update from the agent reports that it
    /// completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the local cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2189
    /// Checks that an edit tool call reported as already `Completed` (with a
    /// diff attached) does not count as a pending edit tool call once the
    /// turn has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2243
    /// Exercises checkpoints end to end:
    ///
    /// * a turn during which the fake agent creates a file marks the user
    ///   message with a visible checkpoint;
    /// * a turn with no file changes leaves the user message unmarked;
    /// * rewinding to a user message truncates the history from that message
    ///   on and restores the file system to the state before it.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt makes the fake agent
        // create a new empty file named after an increasing counter.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                // Entry 2 is the second user message ("ipsum").
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2423
2424 #[gpui::test]
2425 async fn test_refusal(cx: &mut TestAppContext) {
2426 init_test(cx);
2427 let fs = FakeFs::new(cx.background_executor.clone());
2428 fs.insert_tree(path!("/"), json!({})).await;
2429 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2430
2431 let refuse_next = Arc::new(AtomicBool::new(false));
2432 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2433 let refuse_next = refuse_next.clone();
2434 move |request, thread, mut cx| {
2435 let refuse_next = refuse_next.clone();
2436 async move {
2437 if refuse_next.load(SeqCst) {
2438 return Ok(acp::PromptResponse {
2439 stop_reason: acp::StopReason::Refusal,
2440 });
2441 }
2442
2443 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2444 panic!("expected text content block");
2445 };
2446 thread.update(&mut cx, |thread, cx| {
2447 thread
2448 .handle_session_update(
2449 acp::SessionUpdate::AgentMessageChunk {
2450 content: content.text.to_uppercase().into(),
2451 },
2452 cx,
2453 )
2454 .unwrap();
2455 })?;
2456 Ok(acp::PromptResponse {
2457 stop_reason: acp::StopReason::EndTurn,
2458 })
2459 }
2460 .boxed_local()
2461 }
2462 }));
2463 let thread = cx
2464 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2465 .await
2466 .unwrap();
2467
2468 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2469 .await
2470 .unwrap();
2471 thread.read_with(cx, |thread, cx| {
2472 assert_eq!(
2473 thread.to_markdown(cx),
2474 indoc! {"
2475 ## User
2476
2477 hello
2478
2479 ## Assistant
2480
2481 HELLO
2482
2483 "}
2484 );
2485 });
2486
2487 // Simulate refusing the second message, ensuring the conversation gets
2488 // truncated to before sending it.
2489 refuse_next.store(true, SeqCst);
2490 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2491 .await
2492 .unwrap();
2493 thread.read_with(cx, |thread, cx| {
2494 assert_eq!(
2495 thread.to_markdown(cx),
2496 indoc! {"
2497 ## User
2498
2499 hello
2500
2501 ## Assistant
2502
2503 HELLO
2504
2505 "}
2506 );
2507 });
2508 }
2509
2510 async fn run_until_first_tool_call(
2511 thread: &Entity<AcpThread>,
2512 cx: &mut TestAppContext,
2513 ) -> usize {
2514 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2515
2516 let subscription = cx.update(|cx| {
2517 cx.subscribe(thread, move |thread, _, cx| {
2518 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2519 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2520 return tx.try_send(ix).unwrap();
2521 }
2522 }
2523 })
2524 });
2525
2526 select! {
2527 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2528 panic!("Timeout waiting for tool call")
2529 }
2530 ix = rx.next().fuse() => {
2531 drop(subscription);
2532 ix.unwrap()
2533 }
2534 }
2535 }
2536
2537 #[derive(Clone, Default)]
2538 struct FakeAgentConnection {
2539 auth_methods: Vec<acp::AuthMethod>,
2540 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2541 on_user_message: Option<
2542 Rc<
2543 dyn Fn(
2544 acp::PromptRequest,
2545 WeakEntity<AcpThread>,
2546 AsyncApp,
2547 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2548 + 'static,
2549 >,
2550 >,
2551 }
2552
2553 impl FakeAgentConnection {
2554 fn new() -> Self {
2555 Self {
2556 auth_methods: Vec::new(),
2557 on_user_message: None,
2558 sessions: Arc::default(),
2559 }
2560 }
2561
2562 #[expect(unused)]
2563 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2564 self.auth_methods = auth_methods;
2565 self
2566 }
2567
2568 fn on_user_message(
2569 mut self,
2570 handler: impl Fn(
2571 acp::PromptRequest,
2572 WeakEntity<AcpThread>,
2573 AsyncApp,
2574 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2575 + 'static,
2576 ) -> Self {
2577 self.on_user_message.replace(Rc::new(handler));
2578 self
2579 }
2580 }
2581
2582 impl AgentConnection for FakeAgentConnection {
2583 fn auth_methods(&self) -> &[acp::AuthMethod] {
2584 &self.auth_methods
2585 }
2586
2587 fn new_thread(
2588 self: Rc<Self>,
2589 project: Entity<Project>,
2590 _cwd: &Path,
2591 cx: &mut App,
2592 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2593 let session_id = acp::SessionId(
2594 rand::thread_rng()
2595 .sample_iter(&rand::distributions::Alphanumeric)
2596 .take(7)
2597 .map(char::from)
2598 .collect::<String>()
2599 .into(),
2600 );
2601 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2602 let thread = cx.new(|_cx| {
2603 AcpThread::new(
2604 "Test",
2605 self.clone(),
2606 project,
2607 action_log,
2608 session_id.clone(),
2609 )
2610 });
2611 self.sessions.lock().insert(session_id, thread.downgrade());
2612 Task::ready(Ok(thread))
2613 }
2614
2615 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2616 if self.auth_methods().iter().any(|m| m.id == method) {
2617 Task::ready(Ok(()))
2618 } else {
2619 Task::ready(Err(anyhow!("Invalid Auth Method")))
2620 }
2621 }
2622
2623 fn prompt(
2624 &self,
2625 _id: Option<UserMessageId>,
2626 params: acp::PromptRequest,
2627 cx: &mut App,
2628 ) -> Task<gpui::Result<acp::PromptResponse>> {
2629 let sessions = self.sessions.lock();
2630 let thread = sessions.get(¶ms.session_id).unwrap();
2631 if let Some(handler) = &self.on_user_message {
2632 let handler = handler.clone();
2633 let thread = thread.clone();
2634 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2635 } else {
2636 Task::ready(Ok(acp::PromptResponse {
2637 stop_reason: acp::StopReason::EndTurn,
2638 }))
2639 }
2640 }
2641
2642 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
2643 acp::PromptCapabilities {
2644 image: true,
2645 audio: true,
2646 embedded_context: true,
2647 }
2648 }
2649
2650 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2651 let sessions = self.sessions.lock();
2652 let thread = sessions.get(session_id).unwrap().clone();
2653
2654 cx.spawn(async move |cx| {
2655 thread
2656 .update(cx, |thread, cx| thread.cancel(cx))
2657 .unwrap()
2658 .await
2659 })
2660 .detach();
2661 }
2662
2663 fn truncate(
2664 &self,
2665 session_id: &acp::SessionId,
2666 _cx: &App,
2667 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2668 Some(Rc::new(FakeAgentSessionEditor {
2669 _session_id: session_id.clone(),
2670 }))
2671 }
2672
2673 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2674 self
2675 }
2676 }
2677
/// Truncation handle returned by `FakeAgentConnection::truncate`.
struct FakeAgentSessionEditor {
    /// Session this handle was created for; stored but never read.
    _session_id: acp::SessionId,
}
2681
impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// No-op truncation: ignores both arguments and returns an already-completed `Ok(())` task.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
2687}