1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
36#[derive(Debug)]
37pub struct UserMessage {
38 pub id: Option<UserMessageId>,
39 pub content: ContentBlock,
40 pub chunks: Vec<acp::ContentBlock>,
41 pub checkpoint: Option<Checkpoint>,
42}
43
44#[derive(Debug)]
45pub struct Checkpoint {
46 git_checkpoint: GitStoreCheckpoint,
47 pub show: bool,
48}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
166#[derive(Debug)]
167pub struct ToolCall {
168 pub id: acp::ToolCallId,
169 pub label: Entity<Markdown>,
170 pub kind: acp::ToolKind,
171 pub content: Vec<ToolCallContent>,
172 pub status: ToolCallStatus,
173 pub locations: Vec<acp::ToolCallLocation>,
174 pub resolved_locations: Vec<Option<AgentLocation>>,
175 pub raw_input: Option<serde_json::Value>,
176 pub raw_output: Option<serde_json::Value>,
177}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 Self {
187 id: tool_call.id,
188 label: cx.new(|cx| {
189 Markdown::new(
190 tool_call.title.into(),
191 Some(language_registry.clone()),
192 None,
193 cx,
194 )
195 }),
196 kind: tool_call.kind,
197 content: tool_call
198 .content
199 .into_iter()
200 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
201 .collect(),
202 locations: tool_call.locations,
203 resolved_locations: Vec::default(),
204 status,
205 raw_input: tool_call.raw_input,
206 raw_output: tool_call.raw_output,
207 }
208 }
209
210 fn update_fields(
211 &mut self,
212 fields: acp::ToolCallUpdateFields,
213 language_registry: Arc<LanguageRegistry>,
214 cx: &mut App,
215 ) {
216 let acp::ToolCallUpdateFields {
217 kind,
218 status,
219 title,
220 content,
221 locations,
222 raw_input,
223 raw_output,
224 } = fields;
225
226 if let Some(kind) = kind {
227 self.kind = kind;
228 }
229
230 if let Some(status) = status {
231 self.status = status.into();
232 }
233
234 if let Some(title) = title {
235 self.label.update(cx, |label, cx| {
236 label.replace(title, cx);
237 });
238 }
239
240 if let Some(content) = content {
241 let new_content_len = content.len();
242 let mut content = content.into_iter();
243
244 // Reuse existing content if we can
245 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
246 old.update_from_acp(new, language_registry.clone(), cx);
247 }
248 for new in content {
249 self.content.push(ToolCallContent::from_acp(
250 new,
251 language_registry.clone(),
252 cx,
253 ))
254 }
255 self.content.truncate(new_content_len);
256 }
257
258 if let Some(locations) = locations {
259 self.locations = locations;
260 }
261
262 if let Some(raw_input) = raw_input {
263 self.raw_input = Some(raw_input);
264 }
265
266 if let Some(raw_output) = raw_output {
267 if self.content.is_empty()
268 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
269 {
270 self.content
271 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
272 markdown,
273 }));
274 }
275 self.raw_output = Some(raw_output);
276 }
277 }
278
279 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
280 self.content.iter().filter_map(|content| match content {
281 ToolCallContent::Diff(diff) => Some(diff),
282 ToolCallContent::ContentBlock(_) => None,
283 ToolCallContent::Terminal(_) => None,
284 })
285 }
286
287 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
288 self.content.iter().filter_map(|content| match content {
289 ToolCallContent::Terminal(terminal) => Some(terminal),
290 ToolCallContent::ContentBlock(_) => None,
291 ToolCallContent::Diff(_) => None,
292 })
293 }
294
295 fn to_markdown(&self, cx: &App) -> String {
296 let mut markdown = format!(
297 "**Tool Call: {}**\nStatus: {}\n\n",
298 self.label.read(cx).source(),
299 self.status
300 );
301 for content in &self.content {
302 markdown.push_str(content.to_markdown(cx).as_str());
303 markdown.push_str("\n\n");
304 }
305 markdown
306 }
307
308 async fn resolve_location(
309 location: acp::ToolCallLocation,
310 project: WeakEntity<Project>,
311 cx: &mut AsyncApp,
312 ) -> Option<AgentLocation> {
313 let buffer = project
314 .update(cx, |project, cx| {
315 project
316 .project_path_for_absolute_path(&location.path, cx)
317 .map(|path| project.open_buffer(path, cx))
318 })
319 .ok()??;
320 let buffer = buffer.await.log_err()?;
321 let position = buffer
322 .update(cx, |buffer, _| {
323 if let Some(row) = location.line {
324 let snapshot = buffer.snapshot();
325 let column = snapshot.indent_size_for_line(row).len;
326 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
327 snapshot.anchor_before(point)
328 } else {
329 Anchor::MIN
330 }
331 })
332 .ok()?;
333
334 Some(AgentLocation {
335 buffer: buffer.downgrade(),
336 position,
337 })
338 }
339
340 fn resolve_locations(
341 &self,
342 project: Entity<Project>,
343 cx: &mut App,
344 ) -> Task<Vec<Option<AgentLocation>>> {
345 let locations = self.locations.clone();
346 project.update(cx, |_, cx| {
347 cx.spawn(async move |project, cx| {
348 let mut new_locations = Vec::new();
349 for location in locations {
350 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
351 }
352 new_locations
353 })
354 })
355 }
356}
357
358#[derive(Debug)]
359pub enum ToolCallStatus {
360 /// The tool call hasn't started running yet, but we start showing it to
361 /// the user.
362 Pending,
363 /// The tool call is waiting for confirmation from the user.
364 WaitingForConfirmation {
365 options: Vec<acp::PermissionOption>,
366 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
367 },
368 /// The tool call is currently running.
369 InProgress,
370 /// The tool call completed successfully.
371 Completed,
372 /// The tool call failed.
373 Failed,
374 /// The user rejected the tool call.
375 Rejected,
376 /// The user canceled generation so the tool call was canceled.
377 Canceled,
378}
379
380impl From<acp::ToolCallStatus> for ToolCallStatus {
381 fn from(status: acp::ToolCallStatus) -> Self {
382 match status {
383 acp::ToolCallStatus::Pending => Self::Pending,
384 acp::ToolCallStatus::InProgress => Self::InProgress,
385 acp::ToolCallStatus::Completed => Self::Completed,
386 acp::ToolCallStatus::Failed => Self::Failed,
387 }
388 }
389}
390
391impl Display for ToolCallStatus {
392 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393 write!(
394 f,
395 "{}",
396 match self {
397 ToolCallStatus::Pending => "Pending",
398 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
399 ToolCallStatus::InProgress => "In Progress",
400 ToolCallStatus::Completed => "Completed",
401 ToolCallStatus::Failed => "Failed",
402 ToolCallStatus::Rejected => "Rejected",
403 ToolCallStatus::Canceled => "Canceled",
404 }
405 )
406 }
407}
408
409#[derive(Debug, PartialEq, Clone)]
410pub enum ContentBlock {
411 Empty,
412 Markdown { markdown: Entity<Markdown> },
413 ResourceLink { resource_link: acp::ResourceLink },
414}
415
416impl ContentBlock {
417 pub fn new(
418 block: acp::ContentBlock,
419 language_registry: &Arc<LanguageRegistry>,
420 cx: &mut App,
421 ) -> Self {
422 let mut this = Self::Empty;
423 this.append(block, language_registry, cx);
424 this
425 }
426
427 pub fn new_combined(
428 blocks: impl IntoIterator<Item = acp::ContentBlock>,
429 language_registry: Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) -> Self {
432 let mut this = Self::Empty;
433 for block in blocks {
434 this.append(block, &language_registry, cx);
435 }
436 this
437 }
438
439 pub fn append(
440 &mut self,
441 block: acp::ContentBlock,
442 language_registry: &Arc<LanguageRegistry>,
443 cx: &mut App,
444 ) {
445 if matches!(self, ContentBlock::Empty)
446 && let acp::ContentBlock::ResourceLink(resource_link) = block
447 {
448 *self = ContentBlock::ResourceLink { resource_link };
449 return;
450 }
451
452 let new_content = self.block_string_contents(block);
453
454 match self {
455 ContentBlock::Empty => {
456 *self = Self::create_markdown_block(new_content, language_registry, cx);
457 }
458 ContentBlock::Markdown { markdown } => {
459 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
460 }
461 ContentBlock::ResourceLink { resource_link } => {
462 let existing_content = Self::resource_link_md(&resource_link.uri);
463 let combined = format!("{}\n{}", existing_content, new_content);
464
465 *self = Self::create_markdown_block(combined, language_registry, cx);
466 }
467 }
468 }
469
470 fn create_markdown_block(
471 content: String,
472 language_registry: &Arc<LanguageRegistry>,
473 cx: &mut App,
474 ) -> ContentBlock {
475 ContentBlock::Markdown {
476 markdown: cx
477 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
478 }
479 }
480
481 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
482 match block {
483 acp::ContentBlock::Text(text_content) => text_content.text,
484 acp::ContentBlock::ResourceLink(resource_link) => {
485 Self::resource_link_md(&resource_link.uri)
486 }
487 acp::ContentBlock::Resource(acp::EmbeddedResource {
488 resource:
489 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
490 uri,
491 ..
492 }),
493 ..
494 }) => Self::resource_link_md(&uri),
495 acp::ContentBlock::Image(image) => Self::image_md(&image),
496 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
497 }
498 }
499
500 fn resource_link_md(uri: &str) -> String {
501 if let Some(uri) = MentionUri::parse(uri).log_err() {
502 uri.as_link().to_string()
503 } else {
504 uri.to_string()
505 }
506 }
507
508 fn image_md(_image: &acp::ImageContent) -> String {
509 "`Image`".into()
510 }
511
512 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
513 match self {
514 ContentBlock::Empty => "",
515 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
516 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
517 }
518 }
519
520 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
521 match self {
522 ContentBlock::Empty => None,
523 ContentBlock::Markdown { markdown } => Some(markdown),
524 ContentBlock::ResourceLink { .. } => None,
525 }
526 }
527
528 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
529 match self {
530 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
531 _ => None,
532 }
533 }
534}
535
536#[derive(Debug)]
537pub enum ToolCallContent {
538 ContentBlock(ContentBlock),
539 Diff(Entity<Diff>),
540 Terminal(Entity<Terminal>),
541}
542
543impl ToolCallContent {
544 pub fn from_acp(
545 content: acp::ToolCallContent,
546 language_registry: Arc<LanguageRegistry>,
547 cx: &mut App,
548 ) -> Self {
549 match content {
550 acp::ToolCallContent::Content { content } => {
551 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
552 }
553 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
554 Diff::finalized(
555 diff.path,
556 diff.old_text,
557 diff.new_text,
558 language_registry,
559 cx,
560 )
561 })),
562 }
563 }
564
565 pub fn update_from_acp(
566 &mut self,
567 new: acp::ToolCallContent,
568 language_registry: Arc<LanguageRegistry>,
569 cx: &mut App,
570 ) {
571 let needs_update = match (&self, &new) {
572 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
573 old_diff.read(cx).needs_update(
574 new_diff.old_text.as_deref().unwrap_or(""),
575 &new_diff.new_text,
576 cx,
577 )
578 }
579 _ => true,
580 };
581
582 if needs_update {
583 *self = Self::from_acp(new, language_registry, cx);
584 }
585 }
586
587 pub fn to_markdown(&self, cx: &App) -> String {
588 match self {
589 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
590 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
591 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
592 }
593 }
594}
595
596#[derive(Debug, PartialEq)]
597pub enum ToolCallUpdate {
598 UpdateFields(acp::ToolCallUpdate),
599 UpdateDiff(ToolCallUpdateDiff),
600 UpdateTerminal(ToolCallUpdateTerminal),
601}
602
603impl ToolCallUpdate {
604 fn id(&self) -> &acp::ToolCallId {
605 match self {
606 Self::UpdateFields(update) => &update.id,
607 Self::UpdateDiff(diff) => &diff.id,
608 Self::UpdateTerminal(terminal) => &terminal.id,
609 }
610 }
611}
612
613impl From<acp::ToolCallUpdate> for ToolCallUpdate {
614 fn from(update: acp::ToolCallUpdate) -> Self {
615 Self::UpdateFields(update)
616 }
617}
618
619impl From<ToolCallUpdateDiff> for ToolCallUpdate {
620 fn from(diff: ToolCallUpdateDiff) -> Self {
621 Self::UpdateDiff(diff)
622 }
623}
624
625#[derive(Debug, PartialEq)]
626pub struct ToolCallUpdateDiff {
627 pub id: acp::ToolCallId,
628 pub diff: Entity<Diff>,
629}
630
631impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
632 fn from(terminal: ToolCallUpdateTerminal) -> Self {
633 Self::UpdateTerminal(terminal)
634 }
635}
636
637#[derive(Debug, PartialEq)]
638pub struct ToolCallUpdateTerminal {
639 pub id: acp::ToolCallId,
640 pub terminal: Entity<Terminal>,
641}
642
643#[derive(Debug, Default)]
644pub struct Plan {
645 pub entries: Vec<PlanEntry>,
646}
647
648#[derive(Debug)]
649pub struct PlanStats<'a> {
650 pub in_progress_entry: Option<&'a PlanEntry>,
651 pub pending: u32,
652 pub completed: u32,
653}
654
655impl Plan {
656 pub fn is_empty(&self) -> bool {
657 self.entries.is_empty()
658 }
659
660 pub fn stats(&self) -> PlanStats<'_> {
661 let mut stats = PlanStats {
662 in_progress_entry: None,
663 pending: 0,
664 completed: 0,
665 };
666
667 for entry in &self.entries {
668 match &entry.status {
669 acp::PlanEntryStatus::Pending => {
670 stats.pending += 1;
671 }
672 acp::PlanEntryStatus::InProgress => {
673 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
674 }
675 acp::PlanEntryStatus::Completed => {
676 stats.completed += 1;
677 }
678 }
679 }
680
681 stats
682 }
683}
684
685#[derive(Debug)]
686pub struct PlanEntry {
687 pub content: Entity<Markdown>,
688 pub priority: acp::PlanEntryPriority,
689 pub status: acp::PlanEntryStatus,
690}
691
692impl PlanEntry {
693 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
694 Self {
695 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
696 priority: entry.priority,
697 status: entry.status,
698 }
699 }
700}
701
702#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
703pub struct TokenUsage {
704 pub max_tokens: u64,
705 pub used_tokens: u64,
706}
707
708impl TokenUsage {
709 pub fn ratio(&self) -> TokenUsageRatio {
710 #[cfg(debug_assertions)]
711 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
712 .unwrap_or("0.8".to_string())
713 .parse()
714 .unwrap();
715 #[cfg(not(debug_assertions))]
716 let warning_threshold: f32 = 0.8;
717
718 // When the maximum is unknown because there is no selected model,
719 // avoid showing the token limit warning.
720 if self.max_tokens == 0 {
721 TokenUsageRatio::Normal
722 } else if self.used_tokens >= self.max_tokens {
723 TokenUsageRatio::Exceeded
724 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
725 TokenUsageRatio::Warning
726 } else {
727 TokenUsageRatio::Normal
728 }
729 }
730}
731
732#[derive(Debug, Clone, PartialEq, Eq)]
733pub enum TokenUsageRatio {
734 Normal,
735 Warning,
736 Exceeded,
737}
738
739#[derive(Debug, Clone)]
740pub struct RetryStatus {
741 pub last_error: SharedString,
742 pub attempt: usize,
743 pub max_attempts: usize,
744 pub started_at: Instant,
745 pub duration: Duration,
746}
747
748pub struct AcpThread {
749 title: SharedString,
750 entries: Vec<AgentThreadEntry>,
751 plan: Plan,
752 project: Entity<Project>,
753 action_log: Entity<ActionLog>,
754 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
755 send_task: Option<Task<()>>,
756 connection: Rc<dyn AgentConnection>,
757 session_id: acp::SessionId,
758 token_usage: Option<TokenUsage>,
759 prompt_capabilities: acp::PromptCapabilities,
760 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
761}
762
763#[derive(Debug)]
764pub enum AcpThreadEvent {
765 NewEntry,
766 TitleUpdated,
767 TokenUsageUpdated,
768 EntryUpdated(usize),
769 EntriesRemoved(Range<usize>),
770 ToolAuthorizationRequired,
771 Retry(RetryStatus),
772 Stopped,
773 Error,
774 LoadError(LoadError),
775 PromptCapabilitiesUpdated,
776}
777
778impl EventEmitter<AcpThreadEvent> for AcpThread {}
779
/// Coarse thread state as reported by `AcpThread::status`.
#[derive(Debug, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is active.
    Idle,
    /// A send task is active and `waiting_for_tool_confirmation()` is true.
    WaitingForToolConfirmation,
    /// A send task is active and nothing is waiting on the user.
    Generating,
}
786
787#[derive(Debug, Clone)]
788pub enum LoadError {
789 NotInstalled {
790 error_message: SharedString,
791 install_message: SharedString,
792 install_command: String,
793 },
794 Unsupported {
795 error_message: SharedString,
796 upgrade_message: SharedString,
797 upgrade_command: String,
798 },
799 Exited {
800 status: ExitStatus,
801 },
802 Other(SharedString),
803}
804
805impl Display for LoadError {
806 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
807 match self {
808 LoadError::NotInstalled { error_message, .. }
809 | LoadError::Unsupported { error_message, .. } => {
810 write!(f, "{error_message}")
811 }
812 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
813 LoadError::Other(msg) => write!(f, "{}", msg),
814 }
815 }
816}
817
818impl Error for LoadError {}
819
820impl AcpThread {
821 pub fn new(
822 title: impl Into<SharedString>,
823 connection: Rc<dyn AgentConnection>,
824 project: Entity<Project>,
825 action_log: Entity<ActionLog>,
826 session_id: acp::SessionId,
827 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
828 cx: &mut Context<Self>,
829 ) -> Self {
830 let prompt_capabilities = *prompt_capabilities_rx.borrow();
831 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
832 loop {
833 let caps = prompt_capabilities_rx.recv().await?;
834 this.update(cx, |this, cx| {
835 this.prompt_capabilities = caps;
836 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
837 })?;
838 }
839 });
840
841 Self {
842 action_log,
843 shared_buffers: Default::default(),
844 entries: Default::default(),
845 plan: Default::default(),
846 title: title.into(),
847 project,
848 send_task: None,
849 connection,
850 session_id,
851 token_usage: None,
852 prompt_capabilities,
853 _observe_prompt_capabilities: task,
854 }
855 }
856
857 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
858 self.prompt_capabilities
859 }
860
861 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
862 &self.connection
863 }
864
865 pub fn action_log(&self) -> &Entity<ActionLog> {
866 &self.action_log
867 }
868
869 pub fn project(&self) -> &Entity<Project> {
870 &self.project
871 }
872
873 pub fn title(&self) -> SharedString {
874 self.title.clone()
875 }
876
877 pub fn entries(&self) -> &[AgentThreadEntry] {
878 &self.entries
879 }
880
881 pub fn session_id(&self) -> &acp::SessionId {
882 &self.session_id
883 }
884
885 pub fn status(&self) -> ThreadStatus {
886 if self.send_task.is_some() {
887 if self.waiting_for_tool_confirmation() {
888 ThreadStatus::WaitingForToolConfirmation
889 } else {
890 ThreadStatus::Generating
891 }
892 } else {
893 ThreadStatus::Idle
894 }
895 }
896
897 pub fn token_usage(&self) -> Option<&TokenUsage> {
898 self.token_usage.as_ref()
899 }
900
901 pub fn has_pending_edit_tool_calls(&self) -> bool {
902 for entry in self.entries.iter().rev() {
903 match entry {
904 AgentThreadEntry::UserMessage(_) => return false,
905 AgentThreadEntry::ToolCall(
906 call @ ToolCall {
907 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
908 ..
909 },
910 ) if call.diffs().next().is_some() => {
911 return true;
912 }
913 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
914 }
915 }
916
917 false
918 }
919
920 pub fn used_tools_since_last_user_message(&self) -> bool {
921 for entry in self.entries.iter().rev() {
922 match entry {
923 AgentThreadEntry::UserMessage(..) => return false,
924 AgentThreadEntry::AssistantMessage(..) => continue,
925 AgentThreadEntry::ToolCall(..) => return true,
926 }
927 }
928
929 false
930 }
931
932 pub fn handle_session_update(
933 &mut self,
934 update: acp::SessionUpdate,
935 cx: &mut Context<Self>,
936 ) -> Result<(), acp::Error> {
937 match update {
938 acp::SessionUpdate::UserMessageChunk { content } => {
939 self.push_user_content_block(None, content, cx);
940 }
941 acp::SessionUpdate::AgentMessageChunk { content } => {
942 self.push_assistant_content_block(content, false, cx);
943 }
944 acp::SessionUpdate::AgentThoughtChunk { content } => {
945 self.push_assistant_content_block(content, true, cx);
946 }
947 acp::SessionUpdate::ToolCall(tool_call) => {
948 self.upsert_tool_call(tool_call, cx)?;
949 }
950 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
951 self.update_tool_call(tool_call_update, cx)?;
952 }
953 acp::SessionUpdate::Plan(plan) => {
954 self.update_plan(plan, cx);
955 }
956 }
957 Ok(())
958 }
959
960 pub fn push_user_content_block(
961 &mut self,
962 message_id: Option<UserMessageId>,
963 chunk: acp::ContentBlock,
964 cx: &mut Context<Self>,
965 ) {
966 let language_registry = self.project.read(cx).languages().clone();
967 let entries_len = self.entries.len();
968
969 if let Some(last_entry) = self.entries.last_mut()
970 && let AgentThreadEntry::UserMessage(UserMessage {
971 id,
972 content,
973 chunks,
974 ..
975 }) = last_entry
976 {
977 *id = message_id.or(id.take());
978 content.append(chunk.clone(), &language_registry, cx);
979 chunks.push(chunk);
980 let idx = entries_len - 1;
981 cx.emit(AcpThreadEvent::EntryUpdated(idx));
982 } else {
983 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
984 self.push_entry(
985 AgentThreadEntry::UserMessage(UserMessage {
986 id: message_id,
987 content,
988 chunks: vec![chunk],
989 checkpoint: None,
990 }),
991 cx,
992 );
993 }
994 }
995
996 pub fn push_assistant_content_block(
997 &mut self,
998 chunk: acp::ContentBlock,
999 is_thought: bool,
1000 cx: &mut Context<Self>,
1001 ) {
1002 let language_registry = self.project.read(cx).languages().clone();
1003 let entries_len = self.entries.len();
1004 if let Some(last_entry) = self.entries.last_mut()
1005 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1006 {
1007 let idx = entries_len - 1;
1008 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1009 match (chunks.last_mut(), is_thought) {
1010 (Some(AssistantMessageChunk::Message { block }), false)
1011 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1012 block.append(chunk, &language_registry, cx)
1013 }
1014 _ => {
1015 let block = ContentBlock::new(chunk, &language_registry, cx);
1016 if is_thought {
1017 chunks.push(AssistantMessageChunk::Thought { block })
1018 } else {
1019 chunks.push(AssistantMessageChunk::Message { block })
1020 }
1021 }
1022 }
1023 } else {
1024 let block = ContentBlock::new(chunk, &language_registry, cx);
1025 let chunk = if is_thought {
1026 AssistantMessageChunk::Thought { block }
1027 } else {
1028 AssistantMessageChunk::Message { block }
1029 };
1030
1031 self.push_entry(
1032 AgentThreadEntry::AssistantMessage(AssistantMessage {
1033 chunks: vec![chunk],
1034 }),
1035 cx,
1036 );
1037 }
1038 }
1039
1040 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1041 self.entries.push(entry);
1042 cx.emit(AcpThreadEvent::NewEntry);
1043 }
1044
1045 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1046 self.connection.set_title(&self.session_id, cx).is_some()
1047 }
1048
1049 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1050 if title != self.title {
1051 self.title = title.clone();
1052 cx.emit(AcpThreadEvent::TitleUpdated);
1053 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1054 return set_title.run(title, cx);
1055 }
1056 }
1057 Task::ready(Ok(()))
1058 }
1059
1060 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1061 self.token_usage = usage;
1062 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1063 }
1064
1065 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1066 cx.emit(AcpThreadEvent::Retry(status));
1067 }
1068
1069 pub fn update_tool_call(
1070 &mut self,
1071 update: impl Into<ToolCallUpdate>,
1072 cx: &mut Context<Self>,
1073 ) -> Result<()> {
1074 let update = update.into();
1075 let languages = self.project.read(cx).languages().clone();
1076
1077 let (ix, current_call) = self
1078 .tool_call_mut(update.id())
1079 .context("Tool call not found")?;
1080 match update {
1081 ToolCallUpdate::UpdateFields(update) => {
1082 let location_updated = update.fields.locations.is_some();
1083 current_call.update_fields(update.fields, languages, cx);
1084 if location_updated {
1085 self.resolve_locations(update.id, cx);
1086 }
1087 }
1088 ToolCallUpdate::UpdateDiff(update) => {
1089 current_call.content.clear();
1090 current_call
1091 .content
1092 .push(ToolCallContent::Diff(update.diff));
1093 }
1094 ToolCallUpdate::UpdateTerminal(update) => {
1095 current_call.content.clear();
1096 current_call
1097 .content
1098 .push(ToolCallContent::Terminal(update.terminal));
1099 }
1100 }
1101
1102 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1103
1104 Ok(())
1105 }
1106
1107 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1108 pub fn upsert_tool_call(
1109 &mut self,
1110 tool_call: acp::ToolCall,
1111 cx: &mut Context<Self>,
1112 ) -> Result<(), acp::Error> {
1113 let status = tool_call.status.into();
1114 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1115 }
1116
1117 /// Fails if id does not match an existing entry.
1118 pub fn upsert_tool_call_inner(
1119 &mut self,
1120 tool_call_update: acp::ToolCallUpdate,
1121 status: ToolCallStatus,
1122 cx: &mut Context<Self>,
1123 ) -> Result<(), acp::Error> {
1124 let language_registry = self.project.read(cx).languages().clone();
1125 let id = tool_call_update.id.clone();
1126
1127 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1128 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1129 current_call.status = status;
1130
1131 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1132 } else {
1133 let call =
1134 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1135 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1136 };
1137
1138 self.resolve_locations(id, cx);
1139 Ok(())
1140 }
1141
1142 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1143 // The tool call we are looking for is typically the last one, or very close to the end.
1144 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1145 self.entries
1146 .iter_mut()
1147 .enumerate()
1148 .rev()
1149 .find_map(|(index, tool_call)| {
1150 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1151 && &tool_call.id == id
1152 {
1153 Some((index, tool_call))
1154 } else {
1155 None
1156 }
1157 })
1158 }
1159
1160 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1161 self.entries
1162 .iter()
1163 .enumerate()
1164 .rev()
1165 .find_map(|(index, tool_call)| {
1166 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1167 && &tool_call.id == id
1168 {
1169 Some((index, tool_call))
1170 } else {
1171 None
1172 }
1173 })
1174 }
1175
1176 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1177 let project = self.project.clone();
1178 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1179 return;
1180 };
1181 let task = tool_call.resolve_locations(project, cx);
1182 cx.spawn(async move |this, cx| {
1183 let resolved_locations = task.await;
1184 this.update(cx, |this, cx| {
1185 let project = this.project.clone();
1186 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1187 return;
1188 };
1189 if let Some(Some(location)) = resolved_locations.last() {
1190 project.update(cx, |project, cx| {
1191 if let Some(agent_location) = project.agent_location() {
1192 let should_ignore = agent_location.buffer == location.buffer
1193 && location
1194 .buffer
1195 .update(cx, |buffer, _| {
1196 let snapshot = buffer.snapshot();
1197 let old_position =
1198 agent_location.position.to_point(&snapshot);
1199 let new_position = location.position.to_point(&snapshot);
1200 // ignore this so that when we get updates from the edit tool
1201 // the position doesn't reset to the startof line
1202 old_position.row == new_position.row
1203 && old_position.column > new_position.column
1204 })
1205 .ok()
1206 .unwrap_or_default();
1207 if !should_ignore {
1208 project.set_agent_location(Some(location.clone()), cx);
1209 }
1210 }
1211 });
1212 }
1213 if tool_call.resolved_locations != resolved_locations {
1214 tool_call.resolved_locations = resolved_locations;
1215 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1216 }
1217 })
1218 })
1219 .detach();
1220 }
1221
1222 pub fn request_tool_call_authorization(
1223 &mut self,
1224 tool_call: acp::ToolCallUpdate,
1225 options: Vec<acp::PermissionOption>,
1226 cx: &mut Context<Self>,
1227 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1228 let (tx, rx) = oneshot::channel();
1229
1230 let status = ToolCallStatus::WaitingForConfirmation {
1231 options,
1232 respond_tx: tx,
1233 };
1234
1235 self.upsert_tool_call_inner(tool_call, status, cx)?;
1236 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1237 Ok(rx)
1238 }
1239
1240 pub fn authorize_tool_call(
1241 &mut self,
1242 id: acp::ToolCallId,
1243 option_id: acp::PermissionOptionId,
1244 option_kind: acp::PermissionOptionKind,
1245 cx: &mut Context<Self>,
1246 ) {
1247 let Some((ix, call)) = self.tool_call_mut(&id) else {
1248 return;
1249 };
1250
1251 let new_status = match option_kind {
1252 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1253 ToolCallStatus::Rejected
1254 }
1255 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1256 ToolCallStatus::InProgress
1257 }
1258 };
1259
1260 let curr_status = mem::replace(&mut call.status, new_status);
1261
1262 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1263 respond_tx.send(option_id).log_err();
1264 } else if cfg!(debug_assertions) {
1265 panic!("tried to authorize an already authorized tool call");
1266 }
1267
1268 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1269 }
1270
1271 /// Returns true if the last turn is awaiting tool authorization
1272 pub fn waiting_for_tool_confirmation(&self) -> bool {
1273 for entry in self.entries.iter().rev() {
1274 match &entry {
1275 AgentThreadEntry::ToolCall(call) => match call.status {
1276 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1277 ToolCallStatus::Pending
1278 | ToolCallStatus::InProgress
1279 | ToolCallStatus::Completed
1280 | ToolCallStatus::Failed
1281 | ToolCallStatus::Rejected
1282 | ToolCallStatus::Canceled => continue,
1283 },
1284 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1285 // Reached the beginning of the turn
1286 return false;
1287 }
1288 }
1289 }
1290 false
1291 }
1292
1293 pub fn plan(&self) -> &Plan {
1294 &self.plan
1295 }
1296
1297 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1298 let new_entries_len = request.entries.len();
1299 let mut new_entries = request.entries.into_iter();
1300
1301 // Reuse existing markdown to prevent flickering
1302 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1303 let PlanEntry {
1304 content,
1305 priority,
1306 status,
1307 } = old;
1308 content.update(cx, |old, cx| {
1309 old.replace(new.content, cx);
1310 });
1311 *priority = new.priority;
1312 *status = new.status;
1313 }
1314 for new in new_entries {
1315 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1316 }
1317 self.plan.entries.truncate(new_entries_len);
1318
1319 cx.notify();
1320 }
1321
1322 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1323 self.plan
1324 .entries
1325 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1326 cx.notify();
1327 }
1328
1329 #[cfg(any(test, feature = "test-support"))]
1330 pub fn send_raw(
1331 &mut self,
1332 message: &str,
1333 cx: &mut Context<Self>,
1334 ) -> BoxFuture<'static, Result<()>> {
1335 self.send(
1336 vec![acp::ContentBlock::Text(acp::TextContent {
1337 text: message.to_string(),
1338 annotations: None,
1339 })],
1340 cx,
1341 )
1342 }
1343
1344 pub fn send(
1345 &mut self,
1346 message: Vec<acp::ContentBlock>,
1347 cx: &mut Context<Self>,
1348 ) -> BoxFuture<'static, Result<()>> {
1349 let block = ContentBlock::new_combined(
1350 message.clone(),
1351 self.project.read(cx).languages().clone(),
1352 cx,
1353 );
1354 let request = acp::PromptRequest {
1355 prompt: message.clone(),
1356 session_id: self.session_id.clone(),
1357 };
1358 let git_store = self.project.read(cx).git_store().clone();
1359
1360 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1361 Some(UserMessageId::new())
1362 } else {
1363 None
1364 };
1365
1366 self.run_turn(cx, async move |this, cx| {
1367 this.update(cx, |this, cx| {
1368 this.push_entry(
1369 AgentThreadEntry::UserMessage(UserMessage {
1370 id: message_id.clone(),
1371 content: block,
1372 chunks: message,
1373 checkpoint: None,
1374 }),
1375 cx,
1376 );
1377 })
1378 .ok();
1379
1380 let old_checkpoint = git_store
1381 .update(cx, |git, cx| git.checkpoint(cx))?
1382 .await
1383 .context("failed to get old checkpoint")
1384 .log_err();
1385 this.update(cx, |this, cx| {
1386 if let Some((_ix, message)) = this.last_user_message() {
1387 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1388 git_checkpoint,
1389 show: false,
1390 });
1391 }
1392 this.connection.prompt(message_id, request, cx)
1393 })?
1394 .await
1395 })
1396 }
1397
1398 pub fn can_resume(&self, cx: &App) -> bool {
1399 self.connection.resume(&self.session_id, cx).is_some()
1400 }
1401
1402 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1403 self.run_turn(cx, async move |this, cx| {
1404 this.update(cx, |this, cx| {
1405 this.connection
1406 .resume(&this.session_id, cx)
1407 .map(|resume| resume.run(cx))
1408 })?
1409 .context("resuming a session is not supported")?
1410 .await
1411 })
1412 }
1413
    /// Runs one agent turn produced by `f` and returns a future for its outcome.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight turn is
    /// canceled. `f` waits until that canceled turn's task has finished. `f` runs inside
    /// `send_task`, and its result is forwarded over a oneshot channel to the returned
    /// future. That future then:
    /// - refreshes the last user message's checkpoint;
    /// - clears the project's agent location;
    /// - emits `Error` if `f` failed, or `Stopped` otherwise;
    /// - removes the last user message and everything after it if the agent refused.
    ///
    /// # Errors
    /// Returns the error produced by `f`. Also fails if the thread entity is gone or
    /// refreshing the checkpoint fails. If `send_task` is dropped before `f` produces a
    /// result, the future still resolves to `Ok(())` and emits `Stopped`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `tx` carries the turn's result from `send_task` to the future returned below.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous (now canceled) turn finish before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means `send_task` was dropped before `f` produced a result;
            // that case falls through to the `result =>` arm below.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1480
1481 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1482 let Some(send_task) = self.send_task.take() else {
1483 return Task::ready(());
1484 };
1485
1486 for entry in self.entries.iter_mut() {
1487 if let AgentThreadEntry::ToolCall(call) = entry {
1488 let cancel = matches!(
1489 call.status,
1490 ToolCallStatus::Pending
1491 | ToolCallStatus::WaitingForConfirmation { .. }
1492 | ToolCallStatus::InProgress
1493 );
1494
1495 if cancel {
1496 call.status = ToolCallStatus::Canceled;
1497 }
1498 }
1499 }
1500
1501 self.connection.cancel(&self.session_id, cx);
1502
1503 // Wait for the send task to complete
1504 cx.foreground_executor().spawn(send_task)
1505 }
1506
1507 /// Rewinds this thread to before the entry at `index`, removing it and all
1508 /// subsequent entries while reverting any changes made from that point.
1509 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1510 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1511 return Task::ready(Err(anyhow!("not supported")));
1512 };
1513 let Some(message) = self.user_message(&id) else {
1514 return Task::ready(Err(anyhow!("message not found")));
1515 };
1516
1517 let checkpoint = message
1518 .checkpoint
1519 .as_ref()
1520 .map(|c| c.git_checkpoint.clone());
1521
1522 let git_store = self.project.read(cx).git_store().clone();
1523 cx.spawn(async move |this, cx| {
1524 if let Some(checkpoint) = checkpoint {
1525 git_store
1526 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1527 .await?;
1528 }
1529
1530 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1531 this.update(cx, |this, cx| {
1532 if let Some((ix, _)) = this.user_message_mut(&id) {
1533 let range = ix..this.entries.len();
1534 this.entries.truncate(ix);
1535 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1536 }
1537 })
1538 })
1539 }
1540
1541 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1542 let git_store = self.project.read(cx).git_store().clone();
1543
1544 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1545 if let Some(checkpoint) = message.checkpoint.as_ref() {
1546 checkpoint.git_checkpoint.clone()
1547 } else {
1548 return Task::ready(Ok(()));
1549 }
1550 } else {
1551 return Task::ready(Ok(()));
1552 };
1553
1554 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1555 cx.spawn(async move |this, cx| {
1556 let new_checkpoint = new_checkpoint
1557 .await
1558 .context("failed to get new checkpoint")
1559 .log_err();
1560 if let Some(new_checkpoint) = new_checkpoint {
1561 let equal = git_store
1562 .update(cx, |git, cx| {
1563 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1564 })?
1565 .await
1566 .unwrap_or(true);
1567 this.update(cx, |this, cx| {
1568 let (ix, message) = this.last_user_message().context("no user message")?;
1569 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1570 checkpoint.show = !equal;
1571 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1572 anyhow::Ok(())
1573 })??;
1574 }
1575
1576 Ok(())
1577 })
1578 }
1579
1580 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1581 self.entries
1582 .iter_mut()
1583 .enumerate()
1584 .rev()
1585 .find_map(|(ix, entry)| {
1586 if let AgentThreadEntry::UserMessage(message) = entry {
1587 Some((ix, message))
1588 } else {
1589 None
1590 }
1591 })
1592 }
1593
1594 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1595 self.entries.iter().find_map(|entry| {
1596 if let AgentThreadEntry::UserMessage(message) = entry {
1597 if message.id.as_ref() == Some(id) {
1598 Some(message)
1599 } else {
1600 None
1601 }
1602 } else {
1603 None
1604 }
1605 })
1606 }
1607
1608 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1609 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1610 if let AgentThreadEntry::UserMessage(message) = entry {
1611 if message.id.as_ref() == Some(id) {
1612 Some((ix, message))
1613 } else {
1614 None
1615 }
1616 } else {
1617 None
1618 }
1619 })
1620 }
1621
1622 pub fn read_text_file(
1623 &self,
1624 path: PathBuf,
1625 line: Option<u32>,
1626 limit: Option<u32>,
1627 reuse_shared_snapshot: bool,
1628 cx: &mut Context<Self>,
1629 ) -> Task<Result<String>> {
1630 let project = self.project.clone();
1631 let action_log = self.action_log.clone();
1632 cx.spawn(async move |this, cx| {
1633 let load = project.update(cx, |project, cx| {
1634 let path = project
1635 .project_path_for_absolute_path(&path, cx)
1636 .context("invalid path")?;
1637 anyhow::Ok(project.open_buffer(path, cx))
1638 });
1639 let buffer = load??.await?;
1640
1641 let snapshot = if reuse_shared_snapshot {
1642 this.read_with(cx, |this, _| {
1643 this.shared_buffers.get(&buffer.clone()).cloned()
1644 })
1645 .log_err()
1646 .flatten()
1647 } else {
1648 None
1649 };
1650
1651 let snapshot = if let Some(snapshot) = snapshot {
1652 snapshot
1653 } else {
1654 action_log.update(cx, |action_log, cx| {
1655 action_log.buffer_read(buffer.clone(), cx);
1656 })?;
1657 project.update(cx, |project, cx| {
1658 let position = buffer
1659 .read(cx)
1660 .snapshot()
1661 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1662 project.set_agent_location(
1663 Some(AgentLocation {
1664 buffer: buffer.downgrade(),
1665 position,
1666 }),
1667 cx,
1668 );
1669 })?;
1670
1671 buffer.update(cx, |buffer, _| buffer.snapshot())?
1672 };
1673
1674 this.update(cx, |this, _| {
1675 let text = snapshot.text();
1676 this.shared_buffers.insert(buffer.clone(), snapshot);
1677 if line.is_none() && limit.is_none() {
1678 return Ok(text);
1679 }
1680 let limit = limit.unwrap_or(u32::MAX) as usize;
1681 let Some(line) = line else {
1682 return Ok(text.lines().take(limit).collect::<String>());
1683 };
1684
1685 let count = text.lines().count();
1686 if count < line as usize {
1687 anyhow::bail!("There are only {} lines", count);
1688 }
1689 Ok(text
1690 .lines()
1691 .skip(line as usize + 1)
1692 .take(limit)
1693 .collect::<String>())
1694 })?
1695 })
1696 }
1697
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// The steps run in this order:
    /// 1. `content` is diffed against a base snapshot: the one previously shared with the
    ///    agent via `shared_buffers` if present, otherwise the buffer's current snapshot.
    ///    The resulting edits are anchored to that base snapshot.
    /// 2. The agent location moves to the end of the last edit, or the buffer start if
    ///    nothing changed.
    /// 3. The edits are applied and recorded in the action log.
    /// 4. The buffer is formatted if the language's `format_on_save` setting is not `Off`.
    ///    Formatting errors are only logged.
    /// 5. The buffer is saved through the project.
    ///
    /// # Errors
    /// Fails if the path is not in the project, the buffer cannot be opened, the thread or
    /// project entity is gone, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, falling back to the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Computing the diff may be expensive, so it runs on the background executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // NOTE(review): the read is recorded before the edit, presumably so the
                // action log treats the following edit as the agent's — confirm against
                // `ActionLog`.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged, the write still proceeds.
                format_task.await.log_err();

                // Record the formatter's changes as part of the agent's edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1794
1795 pub fn to_markdown(&self, cx: &App) -> String {
1796 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1797 }
1798
1799 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1800 cx.emit(AcpThreadEvent::LoadError(error));
1801 }
1802}
1803
1804fn markdown_for_raw_output(
1805 raw_output: &serde_json::Value,
1806 language_registry: &Arc<LanguageRegistry>,
1807 cx: &mut App,
1808) -> Option<Entity<Markdown>> {
1809 match raw_output {
1810 serde_json::Value::Null => None,
1811 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1812 Markdown::new(
1813 value.to_string().into(),
1814 Some(language_registry.clone()),
1815 None,
1816 cx,
1817 )
1818 })),
1819 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1820 Markdown::new(
1821 value.to_string().into(),
1822 Some(language_registry.clone()),
1823 None,
1824 cx,
1825 )
1826 })),
1827 serde_json::Value::String(value) => Some(cx.new(|cx| {
1828 Markdown::new(
1829 value.clone().into(),
1830 Some(language_registry.clone()),
1831 None,
1832 cx,
1833 )
1834 })),
1835 value => Some(cx.new(|cx| {
1836 Markdown::new(
1837 format!("```json\n{}\n```", value).into(),
1838 Some(language_registry.clone()),
1839 None,
1840 cx,
1841 )
1842 })),
1843 }
1844}
1845
1846#[cfg(test)]
1847mod tests {
1848 use super::*;
1849 use anyhow::anyhow;
1850 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1851 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1852 use indoc::indoc;
1853 use project::{FakeFs, Fs};
1854 use rand::Rng as _;
1855 use serde_json::json;
1856 use settings::SettingsStore;
1857 use smol::stream::StreamExt as _;
1858 use std::{
1859 any::Any,
1860 cell::RefCell,
1861 path::Path,
1862 rc::Rc,
1863 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1864 time::Duration,
1865 };
1866 use util::path;
1867
1868 fn init_test(cx: &mut TestAppContext) {
1869 env_logger::try_init().ok();
1870 cx.update(|cx| {
1871 let settings_store = SettingsStore::test(cx);
1872 cx.set_global(settings_store);
1873 Project::init_settings(cx);
1874 language::init(cx);
1875 });
1876 }
1877
1878 #[gpui::test]
1879 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1880 init_test(cx);
1881
1882 let fs = FakeFs::new(cx.executor());
1883 let project = Project::test(fs, [], cx).await;
1884 let connection = Rc::new(FakeAgentConnection::new());
1885 let thread = cx
1886 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1887 .await
1888 .unwrap();
1889
1890 // Test creating a new user message
1891 thread.update(cx, |thread, cx| {
1892 thread.push_user_content_block(
1893 None,
1894 acp::ContentBlock::Text(acp::TextContent {
1895 annotations: None,
1896 text: "Hello, ".to_string(),
1897 }),
1898 cx,
1899 );
1900 });
1901
1902 thread.update(cx, |thread, cx| {
1903 assert_eq!(thread.entries.len(), 1);
1904 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1905 assert_eq!(user_msg.id, None);
1906 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1907 } else {
1908 panic!("Expected UserMessage");
1909 }
1910 });
1911
1912 // Test appending to existing user message
1913 let message_1_id = UserMessageId::new();
1914 thread.update(cx, |thread, cx| {
1915 thread.push_user_content_block(
1916 Some(message_1_id.clone()),
1917 acp::ContentBlock::Text(acp::TextContent {
1918 annotations: None,
1919 text: "world!".to_string(),
1920 }),
1921 cx,
1922 );
1923 });
1924
1925 thread.update(cx, |thread, cx| {
1926 assert_eq!(thread.entries.len(), 1);
1927 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1928 assert_eq!(user_msg.id, Some(message_1_id));
1929 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1930 } else {
1931 panic!("Expected UserMessage");
1932 }
1933 });
1934
1935 // Test creating new user message after assistant message
1936 thread.update(cx, |thread, cx| {
1937 thread.push_assistant_content_block(
1938 acp::ContentBlock::Text(acp::TextContent {
1939 annotations: None,
1940 text: "Assistant response".to_string(),
1941 }),
1942 false,
1943 cx,
1944 );
1945 });
1946
1947 let message_2_id = UserMessageId::new();
1948 thread.update(cx, |thread, cx| {
1949 thread.push_user_content_block(
1950 Some(message_2_id.clone()),
1951 acp::ContentBlock::Text(acp::TextContent {
1952 annotations: None,
1953 text: "New user message".to_string(),
1954 }),
1955 cx,
1956 );
1957 });
1958
1959 thread.update(cx, |thread, cx| {
1960 assert_eq!(thread.entries.len(), 3);
1961 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1962 assert_eq!(user_msg.id, Some(message_2_id));
1963 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1964 } else {
1965 panic!("Expected UserMessage at index 2");
1966 }
1967 });
1968 }
1969
1970 #[gpui::test]
1971 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1972 init_test(cx);
1973
1974 let fs = FakeFs::new(cx.executor());
1975 let project = Project::test(fs, [], cx).await;
1976 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1977 |_, thread, mut cx| {
1978 async move {
1979 thread.update(&mut cx, |thread, cx| {
1980 thread
1981 .handle_session_update(
1982 acp::SessionUpdate::AgentThoughtChunk {
1983 content: "Thinking ".into(),
1984 },
1985 cx,
1986 )
1987 .unwrap();
1988 thread
1989 .handle_session_update(
1990 acp::SessionUpdate::AgentThoughtChunk {
1991 content: "hard!".into(),
1992 },
1993 cx,
1994 )
1995 .unwrap();
1996 })?;
1997 Ok(acp::PromptResponse {
1998 stop_reason: acp::StopReason::EndTurn,
1999 })
2000 }
2001 .boxed_local()
2002 },
2003 ));
2004
2005 let thread = cx
2006 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2007 .await
2008 .unwrap();
2009
2010 thread
2011 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2012 .await
2013 .unwrap();
2014
2015 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2016 assert_eq!(
2017 output,
2018 indoc! {r#"
2019 ## User
2020
2021 Hello from Zed!
2022
2023 ## Assistant
2024
2025 <thinking>
2026 Thinking hard!
2027 </thinking>
2028
2029 "#}
2030 );
2031 }
2032
2033 #[gpui::test]
2034 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2035 init_test(cx);
2036
2037 let fs = FakeFs::new(cx.executor());
2038 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2039 .await;
2040 let project = Project::test(fs.clone(), [], cx).await;
2041 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2042 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2043 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2044 move |_, thread, mut cx| {
2045 let read_file_tx = read_file_tx.clone();
2046 async move {
2047 let content = thread
2048 .update(&mut cx, |thread, cx| {
2049 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2050 })
2051 .unwrap()
2052 .await
2053 .unwrap();
2054 assert_eq!(content, "one\ntwo\nthree\n");
2055 read_file_tx.take().unwrap().send(()).unwrap();
2056 thread
2057 .update(&mut cx, |thread, cx| {
2058 thread.write_text_file(
2059 path!("/tmp/foo").into(),
2060 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2061 cx,
2062 )
2063 })
2064 .unwrap()
2065 .await
2066 .unwrap();
2067 Ok(acp::PromptResponse {
2068 stop_reason: acp::StopReason::EndTurn,
2069 })
2070 }
2071 .boxed_local()
2072 },
2073 ));
2074
2075 let (worktree, pathbuf) = project
2076 .update(cx, |project, cx| {
2077 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2078 })
2079 .await
2080 .unwrap();
2081 let buffer = project
2082 .update(cx, |project, cx| {
2083 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2084 })
2085 .await
2086 .unwrap();
2087
2088 let thread = cx
2089 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2090 .await
2091 .unwrap();
2092
2093 let request = thread.update(cx, |thread, cx| {
2094 thread.send_raw("Extend the count in /tmp/foo", cx)
2095 });
2096 read_file_rx.await.ok();
2097 buffer.update(cx, |buffer, cx| {
2098 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2099 });
2100 cx.run_until_parked();
2101 assert_eq!(
2102 buffer.read_with(cx, |buffer, _| buffer.text()),
2103 "zero\none\ntwo\nthree\nfour\nfive\n"
2104 );
2105 assert_eq!(
2106 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2107 "zero\none\ntwo\nthree\nfour\nfive\n"
2108 );
2109 request.await.unwrap();
2110 }
2111
2112 #[gpui::test]
2113 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2114 init_test(cx);
2115
2116 let fs = FakeFs::new(cx.executor());
2117 let project = Project::test(fs, [], cx).await;
2118 let id = acp::ToolCallId("test".into());
2119
2120 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2121 let id = id.clone();
2122 move |_, thread, mut cx| {
2123 let id = id.clone();
2124 async move {
2125 thread
2126 .update(&mut cx, |thread, cx| {
2127 thread.handle_session_update(
2128 acp::SessionUpdate::ToolCall(acp::ToolCall {
2129 id: id.clone(),
2130 title: "Label".into(),
2131 kind: acp::ToolKind::Fetch,
2132 status: acp::ToolCallStatus::InProgress,
2133 content: vec![],
2134 locations: vec![],
2135 raw_input: None,
2136 raw_output: None,
2137 }),
2138 cx,
2139 )
2140 })
2141 .unwrap()
2142 .unwrap();
2143 Ok(acp::PromptResponse {
2144 stop_reason: acp::StopReason::EndTurn,
2145 })
2146 }
2147 .boxed_local()
2148 }
2149 }));
2150
2151 let thread = cx
2152 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2153 .await
2154 .unwrap();
2155
2156 let request = thread.update(cx, |thread, cx| {
2157 thread.send_raw("Fetch https://example.com", cx)
2158 });
2159
2160 run_until_first_tool_call(&thread, cx).await;
2161
2162 thread.read_with(cx, |thread, _| {
2163 assert!(matches!(
2164 thread.entries[1],
2165 AgentThreadEntry::ToolCall(ToolCall {
2166 status: ToolCallStatus::InProgress,
2167 ..
2168 })
2169 ));
2170 });
2171
2172 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2173
2174 thread.read_with(cx, |thread, _| {
2175 assert!(matches!(
2176 &thread.entries[1],
2177 AgentThreadEntry::ToolCall(ToolCall {
2178 status: ToolCallStatus::Canceled,
2179 ..
2180 })
2181 ));
2182 });
2183
2184 thread
2185 .update(cx, |thread, cx| {
2186 thread.handle_session_update(
2187 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2188 id,
2189 fields: acp::ToolCallUpdateFields {
2190 status: Some(acp::ToolCallStatus::Completed),
2191 ..Default::default()
2192 },
2193 }),
2194 cx,
2195 )
2196 })
2197 .unwrap();
2198
2199 request.await.unwrap();
2200
2201 thread.read_with(cx, |thread, _| {
2202 assert!(matches!(
2203 thread.entries[1],
2204 AgentThreadEntry::ToolCall(ToolCall {
2205 status: ToolCallStatus::Completed,
2206 ..
2207 })
2208 ));
2209 });
2210 }
2211
2212 #[gpui::test]
2213 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2214 init_test(cx);
2215 let fs = FakeFs::new(cx.background_executor.clone());
2216 fs.insert_tree(path!("/test"), json!({})).await;
2217 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2218
2219 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2220 move |_, thread, mut cx| {
2221 async move {
2222 thread
2223 .update(&mut cx, |thread, cx| {
2224 thread.handle_session_update(
2225 acp::SessionUpdate::ToolCall(acp::ToolCall {
2226 id: acp::ToolCallId("test".into()),
2227 title: "Label".into(),
2228 kind: acp::ToolKind::Edit,
2229 status: acp::ToolCallStatus::Completed,
2230 content: vec![acp::ToolCallContent::Diff {
2231 diff: acp::Diff {
2232 path: "/test/test.txt".into(),
2233 old_text: None,
2234 new_text: "foo".into(),
2235 },
2236 }],
2237 locations: vec![],
2238 raw_input: None,
2239 raw_output: None,
2240 }),
2241 cx,
2242 )
2243 })
2244 .unwrap()
2245 .unwrap();
2246 Ok(acp::PromptResponse {
2247 stop_reason: acp::StopReason::EndTurn,
2248 })
2249 }
2250 .boxed_local()
2251 }
2252 }));
2253
2254 let thread = cx
2255 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2256 .await
2257 .unwrap();
2258
2259 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2260 .await
2261 .unwrap();
2262
2263 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2264 }
2265
2266 #[gpui::test(iterations = 10)]
2267 async fn test_checkpoints(cx: &mut TestAppContext) {
2268 init_test(cx);
2269 let fs = FakeFs::new(cx.background_executor.clone());
2270 fs.insert_tree(
2271 path!("/test"),
2272 json!({
2273 ".git": {}
2274 }),
2275 )
2276 .await;
2277 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2278
2279 let simulate_changes = Arc::new(AtomicBool::new(true));
2280 let next_filename = Arc::new(AtomicUsize::new(0));
2281 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2282 let simulate_changes = simulate_changes.clone();
2283 let next_filename = next_filename.clone();
2284 let fs = fs.clone();
2285 move |request, thread, mut cx| {
2286 let fs = fs.clone();
2287 let simulate_changes = simulate_changes.clone();
2288 let next_filename = next_filename.clone();
2289 async move {
2290 if simulate_changes.load(SeqCst) {
2291 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2292 fs.write(Path::new(&filename), b"").await?;
2293 }
2294
2295 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2296 panic!("expected text content block");
2297 };
2298 thread.update(&mut cx, |thread, cx| {
2299 thread
2300 .handle_session_update(
2301 acp::SessionUpdate::AgentMessageChunk {
2302 content: content.text.to_uppercase().into(),
2303 },
2304 cx,
2305 )
2306 .unwrap();
2307 })?;
2308 Ok(acp::PromptResponse {
2309 stop_reason: acp::StopReason::EndTurn,
2310 })
2311 }
2312 .boxed_local()
2313 }
2314 }));
2315 let thread = cx
2316 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2317 .await
2318 .unwrap();
2319
2320 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2321 .await
2322 .unwrap();
2323 thread.read_with(cx, |thread, cx| {
2324 assert_eq!(
2325 thread.to_markdown(cx),
2326 indoc! {"
2327 ## User (checkpoint)
2328
2329 Lorem
2330
2331 ## Assistant
2332
2333 LOREM
2334
2335 "}
2336 );
2337 });
2338 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2339
2340 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2341 .await
2342 .unwrap();
2343 thread.read_with(cx, |thread, cx| {
2344 assert_eq!(
2345 thread.to_markdown(cx),
2346 indoc! {"
2347 ## User (checkpoint)
2348
2349 Lorem
2350
2351 ## Assistant
2352
2353 LOREM
2354
2355 ## User (checkpoint)
2356
2357 ipsum
2358
2359 ## Assistant
2360
2361 IPSUM
2362
2363 "}
2364 );
2365 });
2366 assert_eq!(
2367 fs.files(),
2368 vec![
2369 Path::new(path!("/test/file-0")),
2370 Path::new(path!("/test/file-1"))
2371 ]
2372 );
2373
2374 // Checkpoint isn't stored when there are no changes.
2375 simulate_changes.store(false, SeqCst);
2376 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2377 .await
2378 .unwrap();
2379 thread.read_with(cx, |thread, cx| {
2380 assert_eq!(
2381 thread.to_markdown(cx),
2382 indoc! {"
2383 ## User (checkpoint)
2384
2385 Lorem
2386
2387 ## Assistant
2388
2389 LOREM
2390
2391 ## User (checkpoint)
2392
2393 ipsum
2394
2395 ## Assistant
2396
2397 IPSUM
2398
2399 ## User
2400
2401 dolor
2402
2403 ## Assistant
2404
2405 DOLOR
2406
2407 "}
2408 );
2409 });
2410 assert_eq!(
2411 fs.files(),
2412 vec![
2413 Path::new(path!("/test/file-0")),
2414 Path::new(path!("/test/file-1"))
2415 ]
2416 );
2417
2418 // Rewinding the conversation truncates the history and restores the checkpoint.
2419 thread
2420 .update(cx, |thread, cx| {
2421 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2422 panic!("unexpected entries {:?}", thread.entries)
2423 };
2424 thread.rewind(message.id.clone().unwrap(), cx)
2425 })
2426 .await
2427 .unwrap();
2428 thread.read_with(cx, |thread, cx| {
2429 assert_eq!(
2430 thread.to_markdown(cx),
2431 indoc! {"
2432 ## User (checkpoint)
2433
2434 Lorem
2435
2436 ## Assistant
2437
2438 LOREM
2439
2440 "}
2441 );
2442 });
2443 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2444 }
2445
2446 #[gpui::test]
2447 async fn test_refusal(cx: &mut TestAppContext) {
2448 init_test(cx);
2449 let fs = FakeFs::new(cx.background_executor.clone());
2450 fs.insert_tree(path!("/"), json!({})).await;
2451 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2452
2453 let refuse_next = Arc::new(AtomicBool::new(false));
2454 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2455 let refuse_next = refuse_next.clone();
2456 move |request, thread, mut cx| {
2457 let refuse_next = refuse_next.clone();
2458 async move {
2459 if refuse_next.load(SeqCst) {
2460 return Ok(acp::PromptResponse {
2461 stop_reason: acp::StopReason::Refusal,
2462 });
2463 }
2464
2465 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2466 panic!("expected text content block");
2467 };
2468 thread.update(&mut cx, |thread, cx| {
2469 thread
2470 .handle_session_update(
2471 acp::SessionUpdate::AgentMessageChunk {
2472 content: content.text.to_uppercase().into(),
2473 },
2474 cx,
2475 )
2476 .unwrap();
2477 })?;
2478 Ok(acp::PromptResponse {
2479 stop_reason: acp::StopReason::EndTurn,
2480 })
2481 }
2482 .boxed_local()
2483 }
2484 }));
2485 let thread = cx
2486 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2487 .await
2488 .unwrap();
2489
2490 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2491 .await
2492 .unwrap();
2493 thread.read_with(cx, |thread, cx| {
2494 assert_eq!(
2495 thread.to_markdown(cx),
2496 indoc! {"
2497 ## User
2498
2499 hello
2500
2501 ## Assistant
2502
2503 HELLO
2504
2505 "}
2506 );
2507 });
2508
2509 // Simulate refusing the second message, ensuring the conversation gets
2510 // truncated to before sending it.
2511 refuse_next.store(true, SeqCst);
2512 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2513 .await
2514 .unwrap();
2515 thread.read_with(cx, |thread, cx| {
2516 assert_eq!(
2517 thread.to_markdown(cx),
2518 indoc! {"
2519 ## User
2520
2521 hello
2522
2523 ## Assistant
2524
2525 HELLO
2526
2527 "}
2528 );
2529 });
2530 }
2531
2532 async fn run_until_first_tool_call(
2533 thread: &Entity<AcpThread>,
2534 cx: &mut TestAppContext,
2535 ) -> usize {
2536 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2537
2538 let subscription = cx.update(|cx| {
2539 cx.subscribe(thread, move |thread, _, cx| {
2540 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2541 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2542 return tx.try_send(ix).unwrap();
2543 }
2544 }
2545 })
2546 });
2547
2548 select! {
2549 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2550 panic!("Timeout waiting for tool call")
2551 }
2552 ix = rx.next().fuse() => {
2553 drop(subscription);
2554 ix.unwrap()
2555 }
2556 }
2557 }
2558
2559 #[derive(Clone, Default)]
2560 struct FakeAgentConnection {
2561 auth_methods: Vec<acp::AuthMethod>,
2562 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2563 on_user_message: Option<
2564 Rc<
2565 dyn Fn(
2566 acp::PromptRequest,
2567 WeakEntity<AcpThread>,
2568 AsyncApp,
2569 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2570 + 'static,
2571 >,
2572 >,
2573 }
2574
2575 impl FakeAgentConnection {
2576 fn new() -> Self {
2577 Self {
2578 auth_methods: Vec::new(),
2579 on_user_message: None,
2580 sessions: Arc::default(),
2581 }
2582 }
2583
2584 #[expect(unused)]
2585 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2586 self.auth_methods = auth_methods;
2587 self
2588 }
2589
2590 fn on_user_message(
2591 mut self,
2592 handler: impl Fn(
2593 acp::PromptRequest,
2594 WeakEntity<AcpThread>,
2595 AsyncApp,
2596 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2597 + 'static,
2598 ) -> Self {
2599 self.on_user_message.replace(Rc::new(handler));
2600 self
2601 }
2602 }
2603
2604 impl AgentConnection for FakeAgentConnection {
2605 fn auth_methods(&self) -> &[acp::AuthMethod] {
2606 &self.auth_methods
2607 }
2608
2609 fn new_thread(
2610 self: Rc<Self>,
2611 project: Entity<Project>,
2612 _cwd: &Path,
2613 cx: &mut App,
2614 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2615 let session_id = acp::SessionId(
2616 rand::thread_rng()
2617 .sample_iter(&rand::distributions::Alphanumeric)
2618 .take(7)
2619 .map(char::from)
2620 .collect::<String>()
2621 .into(),
2622 );
2623 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2624 let thread = cx.new(|cx| {
2625 AcpThread::new(
2626 "Test",
2627 self.clone(),
2628 project,
2629 action_log,
2630 session_id.clone(),
2631 watch::Receiver::constant(acp::PromptCapabilities {
2632 image: true,
2633 audio: true,
2634 embedded_context: true,
2635 }),
2636 cx,
2637 )
2638 });
2639 self.sessions.lock().insert(session_id, thread.downgrade());
2640 Task::ready(Ok(thread))
2641 }
2642
2643 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2644 if self.auth_methods().iter().any(|m| m.id == method) {
2645 Task::ready(Ok(()))
2646 } else {
2647 Task::ready(Err(anyhow!("Invalid Auth Method")))
2648 }
2649 }
2650
2651 fn prompt(
2652 &self,
2653 _id: Option<UserMessageId>,
2654 params: acp::PromptRequest,
2655 cx: &mut App,
2656 ) -> Task<gpui::Result<acp::PromptResponse>> {
2657 let sessions = self.sessions.lock();
2658 let thread = sessions.get(¶ms.session_id).unwrap();
2659 if let Some(handler) = &self.on_user_message {
2660 let handler = handler.clone();
2661 let thread = thread.clone();
2662 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2663 } else {
2664 Task::ready(Ok(acp::PromptResponse {
2665 stop_reason: acp::StopReason::EndTurn,
2666 }))
2667 }
2668 }
2669
2670 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2671 let sessions = self.sessions.lock();
2672 let thread = sessions.get(session_id).unwrap().clone();
2673
2674 cx.spawn(async move |cx| {
2675 thread
2676 .update(cx, |thread, cx| thread.cancel(cx))
2677 .unwrap()
2678 .await
2679 })
2680 .detach();
2681 }
2682
2683 fn truncate(
2684 &self,
2685 session_id: &acp::SessionId,
2686 _cx: &App,
2687 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2688 Some(Rc::new(FakeAgentSessionEditor {
2689 _session_id: session_id.clone(),
2690 }))
2691 }
2692
2693 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2694 self
2695 }
2696 }
2697
/// [`AgentSessionTruncate`] handed out by `FakeAgentConnection::truncate`.
/// It accepts every truncation request without doing any work.
struct FakeAgentSessionEditor {
    // Session this editor was created for; stored but never read.
    _session_id: acp::SessionId,
}

impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// Does nothing and immediately resolves to `Ok(())`.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
2707}