1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
/// A user turn in the thread, built up from one or more protocol content chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one has been provided.
    pub id: Option<UserMessageId>,
    /// Renderable content accumulated from all chunks of the message.
    pub content: ContentBlock,
    /// The raw protocol chunks that make up the message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to the message, if any.
    pub checkpoint: Option<Checkpoint>,
}
43
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is shown; when set, the message's markdown
    /// export is headed "## User (checkpoint)".
    pub show: bool,
}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
/// An assistant turn, made of alternating message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; consecutive chunks of the same kind are merged.
    pub chunks: Vec<AssistantMessageChunk>,
}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
/// One contiguous piece of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular assistant output.
    Message { block: ContentBlock },
    /// Reasoning text; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// Content sent by the user.
    UserMessage(UserMessage),
    /// Content produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
/// A tool invocation reported by the agent, together with its output.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier; updates are matched to the call by this id.
    pub id: acp::ToolCallId,
    /// Title rendered as markdown (multi-line titles are shortened to one line).
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Output items of the call, in order.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolution results indexed like `locations`; `None` where a location
    /// could not be resolved. Empty until resolution has run.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
187 first_line.to_owned() + "…"
188 } else {
189 tool_call.title
190 };
191 Self {
192 id: tool_call.id,
193 label: cx
194 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
195 kind: tool_call.kind,
196 content: tool_call
197 .content
198 .into_iter()
199 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
200 .collect(),
201 locations: tool_call.locations,
202 resolved_locations: Vec::default(),
203 status,
204 raw_input: tool_call.raw_input,
205 raw_output: tool_call.raw_output,
206 }
207 }
208
209 fn update_fields(
210 &mut self,
211 fields: acp::ToolCallUpdateFields,
212 language_registry: Arc<LanguageRegistry>,
213 cx: &mut App,
214 ) {
215 let acp::ToolCallUpdateFields {
216 kind,
217 status,
218 title,
219 content,
220 locations,
221 raw_input,
222 raw_output,
223 } = fields;
224
225 if let Some(kind) = kind {
226 self.kind = kind;
227 }
228
229 if let Some(status) = status {
230 self.status = status.into();
231 }
232
233 if let Some(title) = title {
234 self.label.update(cx, |label, cx| {
235 if let Some((first_line, _)) = title.split_once("\n") {
236 label.replace(first_line.to_owned() + "…", cx)
237 } else {
238 label.replace(title, cx);
239 }
240 });
241 }
242
243 if let Some(content) = content {
244 let new_content_len = content.len();
245 let mut content = content.into_iter();
246
247 // Reuse existing content if we can
248 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
249 old.update_from_acp(new, language_registry.clone(), cx);
250 }
251 for new in content {
252 self.content.push(ToolCallContent::from_acp(
253 new,
254 language_registry.clone(),
255 cx,
256 ))
257 }
258 self.content.truncate(new_content_len);
259 }
260
261 if let Some(locations) = locations {
262 self.locations = locations;
263 }
264
265 if let Some(raw_input) = raw_input {
266 self.raw_input = Some(raw_input);
267 }
268
269 if let Some(raw_output) = raw_output {
270 if self.content.is_empty()
271 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
272 {
273 self.content
274 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
275 markdown,
276 }));
277 }
278 self.raw_output = Some(raw_output);
279 }
280 }
281
282 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
283 self.content.iter().filter_map(|content| match content {
284 ToolCallContent::Diff(diff) => Some(diff),
285 ToolCallContent::ContentBlock(_) => None,
286 ToolCallContent::Terminal(_) => None,
287 })
288 }
289
290 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
291 self.content.iter().filter_map(|content| match content {
292 ToolCallContent::Terminal(terminal) => Some(terminal),
293 ToolCallContent::ContentBlock(_) => None,
294 ToolCallContent::Diff(_) => None,
295 })
296 }
297
298 fn to_markdown(&self, cx: &App) -> String {
299 let mut markdown = format!(
300 "**Tool Call: {}**\nStatus: {}\n\n",
301 self.label.read(cx).source(),
302 self.status
303 );
304 for content in &self.content {
305 markdown.push_str(content.to_markdown(cx).as_str());
306 markdown.push_str("\n\n");
307 }
308 markdown
309 }
310
311 async fn resolve_location(
312 location: acp::ToolCallLocation,
313 project: WeakEntity<Project>,
314 cx: &mut AsyncApp,
315 ) -> Option<AgentLocation> {
316 let buffer = project
317 .update(cx, |project, cx| {
318 project
319 .project_path_for_absolute_path(&location.path, cx)
320 .map(|path| project.open_buffer(path, cx))
321 })
322 .ok()??;
323 let buffer = buffer.await.log_err()?;
324 let position = buffer
325 .update(cx, |buffer, _| {
326 if let Some(row) = location.line {
327 let snapshot = buffer.snapshot();
328 let column = snapshot.indent_size_for_line(row).len;
329 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
330 snapshot.anchor_before(point)
331 } else {
332 Anchor::MIN
333 }
334 })
335 .ok()?;
336
337 Some(AgentLocation {
338 buffer: buffer.downgrade(),
339 position,
340 })
341 }
342
343 fn resolve_locations(
344 &self,
345 project: Entity<Project>,
346 cx: &mut App,
347 ) -> Task<Vec<Option<AgentLocation>>> {
348 let locations = self.locations.clone();
349 project.update(cx, |_, cx| {
350 cx.spawn(async move |project, cx| {
351 let mut new_locations = Vec::new();
352 for location in locations {
353 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
354 }
355 new_locations
356 })
357 })
358 }
359}
360
/// Local lifecycle state of a tool call.
///
/// Covers the protocol's statuses (see the `From<acp::ToolCallStatus>` impl)
/// plus states for user confirmation, rejection and cancellation.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        // Receives the id of the option the user picks.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
382
383impl From<acp::ToolCallStatus> for ToolCallStatus {
384 fn from(status: acp::ToolCallStatus) -> Self {
385 match status {
386 acp::ToolCallStatus::Pending => Self::Pending,
387 acp::ToolCallStatus::InProgress => Self::InProgress,
388 acp::ToolCallStatus::Completed => Self::Completed,
389 acp::ToolCallStatus::Failed => Self::Failed,
390 }
391 }
392}
393
394impl Display for ToolCallStatus {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 write!(
397 f,
398 "{}",
399 match self {
400 ToolCallStatus::Pending => "Pending",
401 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
402 ToolCallStatus::InProgress => "In Progress",
403 ToolCallStatus::Completed => "Completed",
404 ToolCallStatus::Failed => "Failed",
405 ToolCallStatus::Rejected => "Rejected",
406 ToolCallStatus::Canceled => "Canceled",
407 }
408 )
409 }
410}
411
/// Renderable content accumulated from one or more protocol content chunks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link, kept as-is until more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
418
419impl ContentBlock {
420 pub fn new(
421 block: acp::ContentBlock,
422 language_registry: &Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let mut this = Self::Empty;
426 this.append(block, language_registry, cx);
427 this
428 }
429
430 pub fn new_combined(
431 blocks: impl IntoIterator<Item = acp::ContentBlock>,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 let mut this = Self::Empty;
436 for block in blocks {
437 this.append(block, &language_registry, cx);
438 }
439 this
440 }
441
442 pub fn append(
443 &mut self,
444 block: acp::ContentBlock,
445 language_registry: &Arc<LanguageRegistry>,
446 cx: &mut App,
447 ) {
448 if matches!(self, ContentBlock::Empty)
449 && let acp::ContentBlock::ResourceLink(resource_link) = block
450 {
451 *self = ContentBlock::ResourceLink { resource_link };
452 return;
453 }
454
455 let new_content = self.block_string_contents(block);
456
457 match self {
458 ContentBlock::Empty => {
459 *self = Self::create_markdown_block(new_content, language_registry, cx);
460 }
461 ContentBlock::Markdown { markdown } => {
462 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
463 }
464 ContentBlock::ResourceLink { resource_link } => {
465 let existing_content = Self::resource_link_md(&resource_link.uri);
466 let combined = format!("{}\n{}", existing_content, new_content);
467
468 *self = Self::create_markdown_block(combined, language_registry, cx);
469 }
470 }
471 }
472
473 fn create_markdown_block(
474 content: String,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) -> ContentBlock {
478 ContentBlock::Markdown {
479 markdown: cx
480 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
481 }
482 }
483
484 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
485 match block {
486 acp::ContentBlock::Text(text_content) => text_content.text,
487 acp::ContentBlock::ResourceLink(resource_link) => {
488 Self::resource_link_md(&resource_link.uri)
489 }
490 acp::ContentBlock::Resource(acp::EmbeddedResource {
491 resource:
492 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
493 uri,
494 ..
495 }),
496 ..
497 }) => Self::resource_link_md(&uri),
498 acp::ContentBlock::Image(image) => Self::image_md(&image),
499 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
500 }
501 }
502
503 fn resource_link_md(uri: &str) -> String {
504 if let Some(uri) = MentionUri::parse(uri).log_err() {
505 uri.as_link().to_string()
506 } else {
507 uri.to_string()
508 }
509 }
510
511 fn image_md(_image: &acp::ImageContent) -> String {
512 "`Image`".into()
513 }
514
515 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
516 match self {
517 ContentBlock::Empty => "",
518 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
519 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
520 }
521 }
522
523 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
524 match self {
525 ContentBlock::Empty => None,
526 ContentBlock::Markdown { markdown } => Some(markdown),
527 ContentBlock::ResourceLink { .. } => None,
528 }
529 }
530
531 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
532 match self {
533 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
534 _ => None,
535 }
536 }
537}
538
/// One output item of a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown or resource-link content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal attached to the tool call.
    Terminal(Entity<Terminal>),
}
545
546impl ToolCallContent {
547 pub fn from_acp(
548 content: acp::ToolCallContent,
549 language_registry: Arc<LanguageRegistry>,
550 cx: &mut App,
551 ) -> Self {
552 match content {
553 acp::ToolCallContent::Content { content } => {
554 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
555 }
556 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
557 Diff::finalized(
558 diff.path,
559 diff.old_text,
560 diff.new_text,
561 language_registry,
562 cx,
563 )
564 })),
565 }
566 }
567
568 pub fn update_from_acp(
569 &mut self,
570 new: acp::ToolCallContent,
571 language_registry: Arc<LanguageRegistry>,
572 cx: &mut App,
573 ) {
574 let needs_update = match (&self, &new) {
575 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
576 old_diff.read(cx).needs_update(
577 new_diff.old_text.as_deref().unwrap_or(""),
578 &new_diff.new_text,
579 cx,
580 )
581 }
582 _ => true,
583 };
584
585 if needs_update {
586 *self = Self::from_acp(new, language_registry, cx);
587 }
588 }
589
590 pub fn to_markdown(&self, cx: &App) -> String {
591 match self {
592 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
593 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
594 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
595 }
596 }
597}
598
/// An update aimed at an existing tool call (see `AcpThread::update_tool_call`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Protocol field update; only the fields that are set are applied.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
605
606impl ToolCallUpdate {
607 fn id(&self) -> &acp::ToolCallId {
608 match self {
609 Self::UpdateFields(update) => &update.id,
610 Self::UpdateDiff(diff) => &diff.id,
611 Self::UpdateTerminal(terminal) => &terminal.id,
612 }
613 }
614}
615
616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
617 fn from(update: acp::ToolCallUpdate) -> Self {
618 Self::UpdateFields(update)
619 }
620}
621
622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
623 fn from(diff: ToolCallUpdateDiff) -> Self {
624 Self::UpdateDiff(diff)
625 }
626}
627
/// Replaces a tool call's content with a single diff.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content.
    pub diff: Entity<Diff>,
}
633
634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
635 fn from(terminal: ToolCallUpdateTerminal) -> Self {
636 Self::UpdateTerminal(terminal)
637 }
638}
639
/// Replaces a tool call's content with a single terminal.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content.
    pub terminal: Entity<Terminal>,
}
645
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
650
/// Summary counts of a plan, as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
657
658impl Plan {
659 pub fn is_empty(&self) -> bool {
660 self.entries.is_empty()
661 }
662
663 pub fn stats(&self) -> PlanStats<'_> {
664 let mut stats = PlanStats {
665 in_progress_entry: None,
666 pending: 0,
667 completed: 0,
668 };
669
670 for entry in &self.entries {
671 match &entry.status {
672 acp::PlanEntryStatus::Pending => {
673 stats.pending += 1;
674 }
675 acp::PlanEntryStatus::InProgress => {
676 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
677 }
678 acp::PlanEntryStatus::Completed => {
679 stats.completed += 1;
680 }
681 }
682 }
683
684 stats
685 }
686}
687
/// A single plan step.
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description rendered as markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
694
695impl PlanEntry {
696 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
697 Self {
698 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
699 priority: entry.priority,
700 status: entry.status,
701 }
702 }
703}
704
/// Token consumption of the thread relative to the model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means the limit is unknown.
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
710
711impl TokenUsage {
712 pub fn ratio(&self) -> TokenUsageRatio {
713 #[cfg(debug_assertions)]
714 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
715 .unwrap_or("0.8".to_string())
716 .parse()
717 .unwrap();
718 #[cfg(not(debug_assertions))]
719 let warning_threshold: f32 = 0.8;
720
721 // When the maximum is unknown because there is no selected model,
722 // avoid showing the token limit warning.
723 if self.max_tokens == 0 {
724 TokenUsageRatio::Normal
725 } else if self.used_tokens >= self.max_tokens {
726 TokenUsageRatio::Exceeded
727 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
728 TokenUsageRatio::Warning
729 } else {
730 TokenUsageRatio::Normal
731 }
732 }
733}
734
/// Classification returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
741
/// Information about a retried request, emitted via `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    // NOTE(review): presumably when this retry cycle started — confirm with the producer.
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm with the producer.
    pub duration: Duration,
}
750
/// State of one agent session: its entries, plan and title, plus the
/// connection used to talk to the agent.
pub struct AcpThread {
    title: SharedString,
    /// Messages and tool calls, oldest first.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in this part of the file; presumably snapshots of
    // buffers shared with the agent — confirm against its users.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// When `Some`, `status` reports the thread as generating (or waiting for
    /// tool confirmation); `None` means idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest token usage reported for the session, if any.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the agent.
    prompt_capabilities: acp::PromptCapabilities,
    /// Keeps `prompt_capabilities` in sync with the watch channel given to `new`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
765
/// Events emitted by an `AcpThread`.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// Token usage was updated.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The agent's prompt capabilities changed.
    PromptCapabilitiesUpdated,
}
780
// Allows `AcpThread` to emit `AcpThreadEvent`s via `cx.emit` to subscribers.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
782
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call awaits user confirmation.
    WaitingForToolConfirmation,
    /// A send task is running.
    Generating,
}
789
/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent version is older than the minimum supported one.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The agent process exited with this status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by the message.
    Other(SharedString),
}
803
804impl Display for LoadError {
805 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
806 match self {
807 LoadError::Unsupported {
808 command: path,
809 current_version,
810 minimum_version,
811 } => {
812 write!(
813 f,
814 "version {current_version} from {path} is not supported (need at least {minimum_version})"
815 )
816 }
817 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
818 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
819 LoadError::Other(msg) => write!(f, "{msg}"),
820 }
821 }
822}
823
// Standard error marker impl: the message comes from `Display`, and no source error is chained.
impl Error for LoadError {}
825
826impl AcpThread {
827 pub fn new(
828 title: impl Into<SharedString>,
829 connection: Rc<dyn AgentConnection>,
830 project: Entity<Project>,
831 action_log: Entity<ActionLog>,
832 session_id: acp::SessionId,
833 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
834 cx: &mut Context<Self>,
835 ) -> Self {
836 let prompt_capabilities = *prompt_capabilities_rx.borrow();
837 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
838 loop {
839 let caps = prompt_capabilities_rx.recv().await?;
840 this.update(cx, |this, cx| {
841 this.prompt_capabilities = caps;
842 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
843 })?;
844 }
845 });
846
847 Self {
848 action_log,
849 shared_buffers: Default::default(),
850 entries: Default::default(),
851 plan: Default::default(),
852 title: title.into(),
853 project,
854 send_task: None,
855 connection,
856 session_id,
857 token_usage: None,
858 prompt_capabilities,
859 _observe_prompt_capabilities: task,
860 }
861 }
862
863 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
864 self.prompt_capabilities
865 }
866
867 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
868 &self.connection
869 }
870
871 pub fn action_log(&self) -> &Entity<ActionLog> {
872 &self.action_log
873 }
874
875 pub fn project(&self) -> &Entity<Project> {
876 &self.project
877 }
878
879 pub fn title(&self) -> SharedString {
880 self.title.clone()
881 }
882
883 pub fn entries(&self) -> &[AgentThreadEntry] {
884 &self.entries
885 }
886
887 pub fn session_id(&self) -> &acp::SessionId {
888 &self.session_id
889 }
890
891 pub fn status(&self) -> ThreadStatus {
892 if self.send_task.is_some() {
893 if self.waiting_for_tool_confirmation() {
894 ThreadStatus::WaitingForToolConfirmation
895 } else {
896 ThreadStatus::Generating
897 }
898 } else {
899 ThreadStatus::Idle
900 }
901 }
902
903 pub fn token_usage(&self) -> Option<&TokenUsage> {
904 self.token_usage.as_ref()
905 }
906
907 pub fn has_pending_edit_tool_calls(&self) -> bool {
908 for entry in self.entries.iter().rev() {
909 match entry {
910 AgentThreadEntry::UserMessage(_) => return false,
911 AgentThreadEntry::ToolCall(
912 call @ ToolCall {
913 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
914 ..
915 },
916 ) if call.diffs().next().is_some() => {
917 return true;
918 }
919 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
920 }
921 }
922
923 false
924 }
925
926 pub fn used_tools_since_last_user_message(&self) -> bool {
927 for entry in self.entries.iter().rev() {
928 match entry {
929 AgentThreadEntry::UserMessage(..) => return false,
930 AgentThreadEntry::AssistantMessage(..) => continue,
931 AgentThreadEntry::ToolCall(..) => return true,
932 }
933 }
934
935 false
936 }
937
938 pub fn handle_session_update(
939 &mut self,
940 update: acp::SessionUpdate,
941 cx: &mut Context<Self>,
942 ) -> Result<(), acp::Error> {
943 match update {
944 acp::SessionUpdate::UserMessageChunk { content } => {
945 self.push_user_content_block(None, content, cx);
946 }
947 acp::SessionUpdate::AgentMessageChunk { content } => {
948 self.push_assistant_content_block(content, false, cx);
949 }
950 acp::SessionUpdate::AgentThoughtChunk { content } => {
951 self.push_assistant_content_block(content, true, cx);
952 }
953 acp::SessionUpdate::ToolCall(tool_call) => {
954 self.upsert_tool_call(tool_call, cx)?;
955 }
956 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
957 self.update_tool_call(tool_call_update, cx)?;
958 }
959 acp::SessionUpdate::Plan(plan) => {
960 self.update_plan(plan, cx);
961 }
962 }
963 Ok(())
964 }
965
966 pub fn push_user_content_block(
967 &mut self,
968 message_id: Option<UserMessageId>,
969 chunk: acp::ContentBlock,
970 cx: &mut Context<Self>,
971 ) {
972 let language_registry = self.project.read(cx).languages().clone();
973 let entries_len = self.entries.len();
974
975 if let Some(last_entry) = self.entries.last_mut()
976 && let AgentThreadEntry::UserMessage(UserMessage {
977 id,
978 content,
979 chunks,
980 ..
981 }) = last_entry
982 {
983 *id = message_id.or(id.take());
984 content.append(chunk.clone(), &language_registry, cx);
985 chunks.push(chunk);
986 let idx = entries_len - 1;
987 cx.emit(AcpThreadEvent::EntryUpdated(idx));
988 } else {
989 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
990 self.push_entry(
991 AgentThreadEntry::UserMessage(UserMessage {
992 id: message_id,
993 content,
994 chunks: vec![chunk],
995 checkpoint: None,
996 }),
997 cx,
998 );
999 }
1000 }
1001
1002 pub fn push_assistant_content_block(
1003 &mut self,
1004 chunk: acp::ContentBlock,
1005 is_thought: bool,
1006 cx: &mut Context<Self>,
1007 ) {
1008 let language_registry = self.project.read(cx).languages().clone();
1009 let entries_len = self.entries.len();
1010 if let Some(last_entry) = self.entries.last_mut()
1011 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1012 {
1013 let idx = entries_len - 1;
1014 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1015 match (chunks.last_mut(), is_thought) {
1016 (Some(AssistantMessageChunk::Message { block }), false)
1017 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1018 block.append(chunk, &language_registry, cx)
1019 }
1020 _ => {
1021 let block = ContentBlock::new(chunk, &language_registry, cx);
1022 if is_thought {
1023 chunks.push(AssistantMessageChunk::Thought { block })
1024 } else {
1025 chunks.push(AssistantMessageChunk::Message { block })
1026 }
1027 }
1028 }
1029 } else {
1030 let block = ContentBlock::new(chunk, &language_registry, cx);
1031 let chunk = if is_thought {
1032 AssistantMessageChunk::Thought { block }
1033 } else {
1034 AssistantMessageChunk::Message { block }
1035 };
1036
1037 self.push_entry(
1038 AgentThreadEntry::AssistantMessage(AssistantMessage {
1039 chunks: vec![chunk],
1040 }),
1041 cx,
1042 );
1043 }
1044 }
1045
1046 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1047 self.entries.push(entry);
1048 cx.emit(AcpThreadEvent::NewEntry);
1049 }
1050
1051 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1052 self.connection.set_title(&self.session_id, cx).is_some()
1053 }
1054
1055 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1056 if title != self.title {
1057 self.title = title.clone();
1058 cx.emit(AcpThreadEvent::TitleUpdated);
1059 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1060 return set_title.run(title, cx);
1061 }
1062 }
1063 Task::ready(Ok(()))
1064 }
1065
1066 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1067 self.token_usage = usage;
1068 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1069 }
1070
1071 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1072 cx.emit(AcpThreadEvent::Retry(status));
1073 }
1074
1075 pub fn update_tool_call(
1076 &mut self,
1077 update: impl Into<ToolCallUpdate>,
1078 cx: &mut Context<Self>,
1079 ) -> Result<()> {
1080 let update = update.into();
1081 let languages = self.project.read(cx).languages().clone();
1082
1083 let (ix, current_call) = self
1084 .tool_call_mut(update.id())
1085 .context("Tool call not found")?;
1086 match update {
1087 ToolCallUpdate::UpdateFields(update) => {
1088 let location_updated = update.fields.locations.is_some();
1089 current_call.update_fields(update.fields, languages, cx);
1090 if location_updated {
1091 self.resolve_locations(update.id, cx);
1092 }
1093 }
1094 ToolCallUpdate::UpdateDiff(update) => {
1095 current_call.content.clear();
1096 current_call
1097 .content
1098 .push(ToolCallContent::Diff(update.diff));
1099 }
1100 ToolCallUpdate::UpdateTerminal(update) => {
1101 current_call.content.clear();
1102 current_call
1103 .content
1104 .push(ToolCallContent::Terminal(update.terminal));
1105 }
1106 }
1107
1108 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1109
1110 Ok(())
1111 }
1112
1113 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1114 pub fn upsert_tool_call(
1115 &mut self,
1116 tool_call: acp::ToolCall,
1117 cx: &mut Context<Self>,
1118 ) -> Result<(), acp::Error> {
1119 let status = tool_call.status.into();
1120 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1121 }
1122
1123 /// Fails if id does not match an existing entry.
1124 pub fn upsert_tool_call_inner(
1125 &mut self,
1126 tool_call_update: acp::ToolCallUpdate,
1127 status: ToolCallStatus,
1128 cx: &mut Context<Self>,
1129 ) -> Result<(), acp::Error> {
1130 let language_registry = self.project.read(cx).languages().clone();
1131 let id = tool_call_update.id.clone();
1132
1133 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1134 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1135 current_call.status = status;
1136
1137 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1138 } else {
1139 let call =
1140 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1141 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1142 };
1143
1144 self.resolve_locations(id, cx);
1145 Ok(())
1146 }
1147
1148 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1149 // The tool call we are looking for is typically the last one, or very close to the end.
1150 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1151 self.entries
1152 .iter_mut()
1153 .enumerate()
1154 .rev()
1155 .find_map(|(index, tool_call)| {
1156 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1157 && &tool_call.id == id
1158 {
1159 Some((index, tool_call))
1160 } else {
1161 None
1162 }
1163 })
1164 }
1165
1166 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1167 self.entries
1168 .iter()
1169 .enumerate()
1170 .rev()
1171 .find_map(|(index, tool_call)| {
1172 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1173 && &tool_call.id == id
1174 {
1175 Some((index, tool_call))
1176 } else {
1177 None
1178 }
1179 })
1180 }
1181
1182 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1183 let project = self.project.clone();
1184 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1185 return;
1186 };
1187 let task = tool_call.resolve_locations(project, cx);
1188 cx.spawn(async move |this, cx| {
1189 let resolved_locations = task.await;
1190 this.update(cx, |this, cx| {
1191 let project = this.project.clone();
1192 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1193 return;
1194 };
1195 if let Some(Some(location)) = resolved_locations.last() {
1196 project.update(cx, |project, cx| {
1197 if let Some(agent_location) = project.agent_location() {
1198 let should_ignore = agent_location.buffer == location.buffer
1199 && location
1200 .buffer
1201 .update(cx, |buffer, _| {
1202 let snapshot = buffer.snapshot();
1203 let old_position =
1204 agent_location.position.to_point(&snapshot);
1205 let new_position = location.position.to_point(&snapshot);
1206 // ignore this so that when we get updates from the edit tool
1207 // the position doesn't reset to the startof line
1208 old_position.row == new_position.row
1209 && old_position.column > new_position.column
1210 })
1211 .ok()
1212 .unwrap_or_default();
1213 if !should_ignore {
1214 project.set_agent_location(Some(location.clone()), cx);
1215 }
1216 }
1217 });
1218 }
1219 if tool_call.resolved_locations != resolved_locations {
1220 tool_call.resolved_locations = resolved_locations;
1221 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1222 }
1223 })
1224 })
1225 .detach();
1226 }
1227
1228 pub fn request_tool_call_authorization(
1229 &mut self,
1230 tool_call: acp::ToolCallUpdate,
1231 options: Vec<acp::PermissionOption>,
1232 cx: &mut Context<Self>,
1233 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1234 let (tx, rx) = oneshot::channel();
1235
1236 let status = ToolCallStatus::WaitingForConfirmation {
1237 options,
1238 respond_tx: tx,
1239 };
1240
1241 self.upsert_tool_call_inner(tool_call, status, cx)?;
1242 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1243 Ok(rx)
1244 }
1245
1246 pub fn authorize_tool_call(
1247 &mut self,
1248 id: acp::ToolCallId,
1249 option_id: acp::PermissionOptionId,
1250 option_kind: acp::PermissionOptionKind,
1251 cx: &mut Context<Self>,
1252 ) {
1253 let Some((ix, call)) = self.tool_call_mut(&id) else {
1254 return;
1255 };
1256
1257 let new_status = match option_kind {
1258 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1259 ToolCallStatus::Rejected
1260 }
1261 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1262 ToolCallStatus::InProgress
1263 }
1264 };
1265
1266 let curr_status = mem::replace(&mut call.status, new_status);
1267
1268 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1269 respond_tx.send(option_id).log_err();
1270 } else if cfg!(debug_assertions) {
1271 panic!("tried to authorize an already authorized tool call");
1272 }
1273
1274 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1275 }
1276
1277 /// Returns true if the last turn is awaiting tool authorization
1278 pub fn waiting_for_tool_confirmation(&self) -> bool {
1279 for entry in self.entries.iter().rev() {
1280 match &entry {
1281 AgentThreadEntry::ToolCall(call) => match call.status {
1282 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1283 ToolCallStatus::Pending
1284 | ToolCallStatus::InProgress
1285 | ToolCallStatus::Completed
1286 | ToolCallStatus::Failed
1287 | ToolCallStatus::Rejected
1288 | ToolCallStatus::Canceled => continue,
1289 },
1290 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1291 // Reached the beginning of the turn
1292 return false;
1293 }
1294 }
1295 }
1296 false
1297 }
1298
    /// Returns the thread's current plan, as last set by `update_plan` (minus any completed
    /// entries cleared when a new turn starts).
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1302
1303 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1304 let new_entries_len = request.entries.len();
1305 let mut new_entries = request.entries.into_iter();
1306
1307 // Reuse existing markdown to prevent flickering
1308 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1309 let PlanEntry {
1310 content,
1311 priority,
1312 status,
1313 } = old;
1314 content.update(cx, |old, cx| {
1315 old.replace(new.content, cx);
1316 });
1317 *priority = new.priority;
1318 *status = new.status;
1319 }
1320 for new in new_entries {
1321 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1322 }
1323 self.plan.entries.truncate(new_entries_len);
1324
1325 cx.notify();
1326 }
1327
1328 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1329 self.plan
1330 .entries
1331 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1332 cx.notify();
1333 }
1334
1335 #[cfg(any(test, feature = "test-support"))]
1336 pub fn send_raw(
1337 &mut self,
1338 message: &str,
1339 cx: &mut Context<Self>,
1340 ) -> BoxFuture<'static, Result<()>> {
1341 self.send(
1342 vec![acp::ContentBlock::Text(acp::TextContent {
1343 text: message.to_string(),
1344 annotations: None,
1345 })],
1346 cx,
1347 )
1348 }
1349
1350 pub fn send(
1351 &mut self,
1352 message: Vec<acp::ContentBlock>,
1353 cx: &mut Context<Self>,
1354 ) -> BoxFuture<'static, Result<()>> {
1355 let block = ContentBlock::new_combined(
1356 message.clone(),
1357 self.project.read(cx).languages().clone(),
1358 cx,
1359 );
1360 let request = acp::PromptRequest {
1361 prompt: message.clone(),
1362 session_id: self.session_id.clone(),
1363 };
1364 let git_store = self.project.read(cx).git_store().clone();
1365
1366 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1367 Some(UserMessageId::new())
1368 } else {
1369 None
1370 };
1371
1372 self.run_turn(cx, async move |this, cx| {
1373 this.update(cx, |this, cx| {
1374 this.push_entry(
1375 AgentThreadEntry::UserMessage(UserMessage {
1376 id: message_id.clone(),
1377 content: block,
1378 chunks: message,
1379 checkpoint: None,
1380 }),
1381 cx,
1382 );
1383 })
1384 .ok();
1385
1386 let old_checkpoint = git_store
1387 .update(cx, |git, cx| git.checkpoint(cx))?
1388 .await
1389 .context("failed to get old checkpoint")
1390 .log_err();
1391 this.update(cx, |this, cx| {
1392 if let Some((_ix, message)) = this.last_user_message() {
1393 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1394 git_checkpoint,
1395 show: false,
1396 });
1397 }
1398 this.connection.prompt(message_id, request, cx)
1399 })?
1400 .await
1401 })
1402 }
1403
    /// Returns true if the connection offers a way to resume this thread's session
    /// (i.e. `connection.resume` returns `Some` for `session_id`).
    pub fn can_resume(&self, cx: &App) -> bool {
        self.connection.resume(&self.session_id, cx).is_some()
    }
1407
1408 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1409 self.run_turn(cx, async move |this, cx| {
1410 this.update(cx, |this, cx| {
1411 this.connection
1412 .resume(&this.session_id, cx)
1413 .map(|resume| resume.run(cx))
1414 })?
1415 .context("resuming a session is not supported")?
1416 .await
1417 })
1418 }
1419
    /// Starts a new turn that runs `f`, after clearing completed plan entries and cancelling
    /// any turn that is still in flight.
    ///
    /// `f` runs inside `send_task` (so `cancel` can find and await it). The returned future
    /// resolves once `f` has finished and the turn has been wrapped up:
    /// - the last user message's checkpoint visibility is refreshed;
    /// - the project's agent location is cleared;
    /// - if `f` returned an error, `AcpThreadEvent::Error` is emitted and the error returned;
    /// - otherwise `AcpThreadEvent::Stopped` is emitted. Before that, a refusal removes the
    ///   last user message and everything after it.
    ///
    /// NOTE(review): if refreshing the checkpoint fails (the `?` after
    /// `update_last_checkpoint`), the error is returned before the agent location is cleared
    /// and before any event is emitted — confirm that is intended.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // Run the turn itself as `send_task`; its result is forwarded through `tx` to the
        // wrap-up future below.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn (if any) finish winding down first.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means `tx` was dropped without sending, e.g. because
            // `send_task` was dropped; that case falls into the non-error arm below.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused: the last user
                        // message and everything after it are removed.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1486
1487 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1488 let Some(send_task) = self.send_task.take() else {
1489 return Task::ready(());
1490 };
1491
1492 for entry in self.entries.iter_mut() {
1493 if let AgentThreadEntry::ToolCall(call) = entry {
1494 let cancel = matches!(
1495 call.status,
1496 ToolCallStatus::Pending
1497 | ToolCallStatus::WaitingForConfirmation { .. }
1498 | ToolCallStatus::InProgress
1499 );
1500
1501 if cancel {
1502 call.status = ToolCallStatus::Canceled;
1503 }
1504 }
1505 }
1506
1507 self.connection.cancel(&self.session_id, cx);
1508
1509 // Wait for the send task to complete
1510 cx.foreground_executor().spawn(send_task)
1511 }
1512
1513 /// Rewinds this thread to before the entry at `index`, removing it and all
1514 /// subsequent entries while reverting any changes made from that point.
1515 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1516 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1517 return Task::ready(Err(anyhow!("not supported")));
1518 };
1519 let Some(message) = self.user_message(&id) else {
1520 return Task::ready(Err(anyhow!("message not found")));
1521 };
1522
1523 let checkpoint = message
1524 .checkpoint
1525 .as_ref()
1526 .map(|c| c.git_checkpoint.clone());
1527
1528 let git_store = self.project.read(cx).git_store().clone();
1529 cx.spawn(async move |this, cx| {
1530 if let Some(checkpoint) = checkpoint {
1531 git_store
1532 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1533 .await?;
1534 }
1535
1536 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1537 this.update(cx, |this, cx| {
1538 if let Some((ix, _)) = this.user_message_mut(&id) {
1539 let range = ix..this.entries.len();
1540 this.entries.truncate(ix);
1541 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1542 }
1543 })
1544 })
1545 }
1546
1547 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1548 let git_store = self.project.read(cx).git_store().clone();
1549
1550 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1551 if let Some(checkpoint) = message.checkpoint.as_ref() {
1552 checkpoint.git_checkpoint.clone()
1553 } else {
1554 return Task::ready(Ok(()));
1555 }
1556 } else {
1557 return Task::ready(Ok(()));
1558 };
1559
1560 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1561 cx.spawn(async move |this, cx| {
1562 let new_checkpoint = new_checkpoint
1563 .await
1564 .context("failed to get new checkpoint")
1565 .log_err();
1566 if let Some(new_checkpoint) = new_checkpoint {
1567 let equal = git_store
1568 .update(cx, |git, cx| {
1569 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1570 })?
1571 .await
1572 .unwrap_or(true);
1573 this.update(cx, |this, cx| {
1574 let (ix, message) = this.last_user_message().context("no user message")?;
1575 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1576 checkpoint.show = !equal;
1577 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1578 anyhow::Ok(())
1579 })??;
1580 }
1581
1582 Ok(())
1583 })
1584 }
1585
1586 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1587 self.entries
1588 .iter_mut()
1589 .enumerate()
1590 .rev()
1591 .find_map(|(ix, entry)| {
1592 if let AgentThreadEntry::UserMessage(message) = entry {
1593 Some((ix, message))
1594 } else {
1595 None
1596 }
1597 })
1598 }
1599
1600 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1601 self.entries.iter().find_map(|entry| {
1602 if let AgentThreadEntry::UserMessage(message) = entry {
1603 if message.id.as_ref() == Some(id) {
1604 Some(message)
1605 } else {
1606 None
1607 }
1608 } else {
1609 None
1610 }
1611 })
1612 }
1613
1614 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1615 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1616 if let AgentThreadEntry::UserMessage(message) = entry {
1617 if message.id.as_ref() == Some(id) {
1618 Some((ix, message))
1619 } else {
1620 None
1621 }
1622 } else {
1623 None
1624 }
1625 })
1626 }
1627
1628 pub fn read_text_file(
1629 &self,
1630 path: PathBuf,
1631 line: Option<u32>,
1632 limit: Option<u32>,
1633 reuse_shared_snapshot: bool,
1634 cx: &mut Context<Self>,
1635 ) -> Task<Result<String>> {
1636 let project = self.project.clone();
1637 let action_log = self.action_log.clone();
1638 cx.spawn(async move |this, cx| {
1639 let load = project.update(cx, |project, cx| {
1640 let path = project
1641 .project_path_for_absolute_path(&path, cx)
1642 .context("invalid path")?;
1643 anyhow::Ok(project.open_buffer(path, cx))
1644 });
1645 let buffer = load??.await?;
1646
1647 let snapshot = if reuse_shared_snapshot {
1648 this.read_with(cx, |this, _| {
1649 this.shared_buffers.get(&buffer.clone()).cloned()
1650 })
1651 .log_err()
1652 .flatten()
1653 } else {
1654 None
1655 };
1656
1657 let snapshot = if let Some(snapshot) = snapshot {
1658 snapshot
1659 } else {
1660 action_log.update(cx, |action_log, cx| {
1661 action_log.buffer_read(buffer.clone(), cx);
1662 })?;
1663 project.update(cx, |project, cx| {
1664 let position = buffer
1665 .read(cx)
1666 .snapshot()
1667 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1668 project.set_agent_location(
1669 Some(AgentLocation {
1670 buffer: buffer.downgrade(),
1671 position,
1672 }),
1673 cx,
1674 );
1675 })?;
1676
1677 buffer.update(cx, |buffer, _| buffer.snapshot())?
1678 };
1679
1680 this.update(cx, |this, _| {
1681 let text = snapshot.text();
1682 this.shared_buffers.insert(buffer.clone(), snapshot);
1683 if line.is_none() && limit.is_none() {
1684 return Ok(text);
1685 }
1686 let limit = limit.unwrap_or(u32::MAX) as usize;
1687 let Some(line) = line else {
1688 return Ok(text.lines().take(limit).collect::<String>());
1689 };
1690
1691 let count = text.lines().count();
1692 if count < line as usize {
1693 anyhow::bail!("There are only {} lines", count);
1694 }
1695 Ok(text
1696 .lines()
1697 .skip(line as usize + 1)
1698 .take(limit)
1699 .collect::<String>())
1700 })?
1701 })
1702 }
1703
    /// Replaces the contents of the file at `path` with `content`, going through the
    /// project's buffer and then saving it.
    ///
    /// The new text is diffed against the snapshot previously shared with the agent for
    /// this buffer (or the buffer's current snapshot if none was shared). Only the differing
    /// ranges are edited. Those ranges are anchors into that snapshot, so edits the user
    /// made to the buffer in the meantime are kept.
    ///
    /// Along the way:
    /// - the agent location moves to the end of the last edit;
    /// - the action log records a read and an edit of the buffer;
    /// - the buffer is formatted when its language settings enable format-on-save.
    ///
    /// # Errors
    /// Fails if the path is not inside the project, the buffer can't be opened or saved, or
    /// the thread/project was dropped. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, so concurrent user edits survive.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread, expressing each range as anchors into
            // `snapshot` so it still applies after intervening edits.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit (or the buffer start
            // when nothing changed).
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Best-effort: a formatting failure must not prevent the save below.
                format_task.await.log_err();

                // Record the formatter's changes as part of the agent's edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1800
1801 pub fn to_markdown(&self, cx: &App) -> String {
1802 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1803 }
1804
    /// Emits `AcpThreadEvent::LoadError` carrying `error` to this thread's subscribers.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
1808}
1809
1810fn markdown_for_raw_output(
1811 raw_output: &serde_json::Value,
1812 language_registry: &Arc<LanguageRegistry>,
1813 cx: &mut App,
1814) -> Option<Entity<Markdown>> {
1815 match raw_output {
1816 serde_json::Value::Null => None,
1817 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1818 Markdown::new(
1819 value.to_string().into(),
1820 Some(language_registry.clone()),
1821 None,
1822 cx,
1823 )
1824 })),
1825 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1826 Markdown::new(
1827 value.to_string().into(),
1828 Some(language_registry.clone()),
1829 None,
1830 cx,
1831 )
1832 })),
1833 serde_json::Value::String(value) => Some(cx.new(|cx| {
1834 Markdown::new(
1835 value.clone().into(),
1836 Some(language_registry.clone()),
1837 None,
1838 cx,
1839 )
1840 })),
1841 value => Some(cx.new(|cx| {
1842 Markdown::new(
1843 format!("```json\n{}\n```", value).into(),
1844 Some(language_registry.clone()),
1845 None,
1846 cx,
1847 )
1848 })),
1849 }
1850}
1851
1852#[cfg(test)]
1853mod tests {
1854 use super::*;
1855 use anyhow::anyhow;
1856 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1857 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1858 use indoc::indoc;
1859 use project::{FakeFs, Fs};
1860 use rand::Rng as _;
1861 use serde_json::json;
1862 use settings::SettingsStore;
1863 use smol::stream::StreamExt as _;
1864 use std::{
1865 any::Any,
1866 cell::RefCell,
1867 path::Path,
1868 rc::Rc,
1869 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1870 time::Duration,
1871 };
1872 use util::path;
1873
1874 fn init_test(cx: &mut TestAppContext) {
1875 env_logger::try_init().ok();
1876 cx.update(|cx| {
1877 let settings_store = SettingsStore::test(cx);
1878 cx.set_global(settings_store);
1879 Project::init_settings(cx);
1880 language::init(cx);
1881 });
1882 }
1883
1884 #[gpui::test]
1885 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1886 init_test(cx);
1887
1888 let fs = FakeFs::new(cx.executor());
1889 let project = Project::test(fs, [], cx).await;
1890 let connection = Rc::new(FakeAgentConnection::new());
1891 let thread = cx
1892 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1893 .await
1894 .unwrap();
1895
1896 // Test creating a new user message
1897 thread.update(cx, |thread, cx| {
1898 thread.push_user_content_block(
1899 None,
1900 acp::ContentBlock::Text(acp::TextContent {
1901 annotations: None,
1902 text: "Hello, ".to_string(),
1903 }),
1904 cx,
1905 );
1906 });
1907
1908 thread.update(cx, |thread, cx| {
1909 assert_eq!(thread.entries.len(), 1);
1910 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1911 assert_eq!(user_msg.id, None);
1912 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1913 } else {
1914 panic!("Expected UserMessage");
1915 }
1916 });
1917
1918 // Test appending to existing user message
1919 let message_1_id = UserMessageId::new();
1920 thread.update(cx, |thread, cx| {
1921 thread.push_user_content_block(
1922 Some(message_1_id.clone()),
1923 acp::ContentBlock::Text(acp::TextContent {
1924 annotations: None,
1925 text: "world!".to_string(),
1926 }),
1927 cx,
1928 );
1929 });
1930
1931 thread.update(cx, |thread, cx| {
1932 assert_eq!(thread.entries.len(), 1);
1933 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1934 assert_eq!(user_msg.id, Some(message_1_id));
1935 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1936 } else {
1937 panic!("Expected UserMessage");
1938 }
1939 });
1940
1941 // Test creating new user message after assistant message
1942 thread.update(cx, |thread, cx| {
1943 thread.push_assistant_content_block(
1944 acp::ContentBlock::Text(acp::TextContent {
1945 annotations: None,
1946 text: "Assistant response".to_string(),
1947 }),
1948 false,
1949 cx,
1950 );
1951 });
1952
1953 let message_2_id = UserMessageId::new();
1954 thread.update(cx, |thread, cx| {
1955 thread.push_user_content_block(
1956 Some(message_2_id.clone()),
1957 acp::ContentBlock::Text(acp::TextContent {
1958 annotations: None,
1959 text: "New user message".to_string(),
1960 }),
1961 cx,
1962 );
1963 });
1964
1965 thread.update(cx, |thread, cx| {
1966 assert_eq!(thread.entries.len(), 3);
1967 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1968 assert_eq!(user_msg.id, Some(message_2_id));
1969 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1970 } else {
1971 panic!("Expected UserMessage at index 2");
1972 }
1973 });
1974 }
1975
1976 #[gpui::test]
1977 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1978 init_test(cx);
1979
1980 let fs = FakeFs::new(cx.executor());
1981 let project = Project::test(fs, [], cx).await;
1982 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1983 |_, thread, mut cx| {
1984 async move {
1985 thread.update(&mut cx, |thread, cx| {
1986 thread
1987 .handle_session_update(
1988 acp::SessionUpdate::AgentThoughtChunk {
1989 content: "Thinking ".into(),
1990 },
1991 cx,
1992 )
1993 .unwrap();
1994 thread
1995 .handle_session_update(
1996 acp::SessionUpdate::AgentThoughtChunk {
1997 content: "hard!".into(),
1998 },
1999 cx,
2000 )
2001 .unwrap();
2002 })?;
2003 Ok(acp::PromptResponse {
2004 stop_reason: acp::StopReason::EndTurn,
2005 })
2006 }
2007 .boxed_local()
2008 },
2009 ));
2010
2011 let thread = cx
2012 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2013 .await
2014 .unwrap();
2015
2016 thread
2017 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2018 .await
2019 .unwrap();
2020
2021 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2022 assert_eq!(
2023 output,
2024 indoc! {r#"
2025 ## User
2026
2027 Hello from Zed!
2028
2029 ## Assistant
2030
2031 <thinking>
2032 Thinking hard!
2033 </thinking>
2034
2035 "#}
2036 );
2037 }
2038
2039 #[gpui::test]
2040 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2041 init_test(cx);
2042
2043 let fs = FakeFs::new(cx.executor());
2044 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2045 .await;
2046 let project = Project::test(fs.clone(), [], cx).await;
2047 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2048 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2049 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2050 move |_, thread, mut cx| {
2051 let read_file_tx = read_file_tx.clone();
2052 async move {
2053 let content = thread
2054 .update(&mut cx, |thread, cx| {
2055 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2056 })
2057 .unwrap()
2058 .await
2059 .unwrap();
2060 assert_eq!(content, "one\ntwo\nthree\n");
2061 read_file_tx.take().unwrap().send(()).unwrap();
2062 thread
2063 .update(&mut cx, |thread, cx| {
2064 thread.write_text_file(
2065 path!("/tmp/foo").into(),
2066 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2067 cx,
2068 )
2069 })
2070 .unwrap()
2071 .await
2072 .unwrap();
2073 Ok(acp::PromptResponse {
2074 stop_reason: acp::StopReason::EndTurn,
2075 })
2076 }
2077 .boxed_local()
2078 },
2079 ));
2080
2081 let (worktree, pathbuf) = project
2082 .update(cx, |project, cx| {
2083 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2084 })
2085 .await
2086 .unwrap();
2087 let buffer = project
2088 .update(cx, |project, cx| {
2089 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2090 })
2091 .await
2092 .unwrap();
2093
2094 let thread = cx
2095 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2096 .await
2097 .unwrap();
2098
2099 let request = thread.update(cx, |thread, cx| {
2100 thread.send_raw("Extend the count in /tmp/foo", cx)
2101 });
2102 read_file_rx.await.ok();
2103 buffer.update(cx, |buffer, cx| {
2104 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2105 });
2106 cx.run_until_parked();
2107 assert_eq!(
2108 buffer.read_with(cx, |buffer, _| buffer.text()),
2109 "zero\none\ntwo\nthree\nfour\nfive\n"
2110 );
2111 assert_eq!(
2112 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2113 "zero\none\ntwo\nthree\nfour\nfive\n"
2114 );
2115 request.await.unwrap();
2116 }
2117
2118 #[gpui::test]
2119 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2120 init_test(cx);
2121
2122 let fs = FakeFs::new(cx.executor());
2123 let project = Project::test(fs, [], cx).await;
2124 let id = acp::ToolCallId("test".into());
2125
2126 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2127 let id = id.clone();
2128 move |_, thread, mut cx| {
2129 let id = id.clone();
2130 async move {
2131 thread
2132 .update(&mut cx, |thread, cx| {
2133 thread.handle_session_update(
2134 acp::SessionUpdate::ToolCall(acp::ToolCall {
2135 id: id.clone(),
2136 title: "Label".into(),
2137 kind: acp::ToolKind::Fetch,
2138 status: acp::ToolCallStatus::InProgress,
2139 content: vec![],
2140 locations: vec![],
2141 raw_input: None,
2142 raw_output: None,
2143 }),
2144 cx,
2145 )
2146 })
2147 .unwrap()
2148 .unwrap();
2149 Ok(acp::PromptResponse {
2150 stop_reason: acp::StopReason::EndTurn,
2151 })
2152 }
2153 .boxed_local()
2154 }
2155 }));
2156
2157 let thread = cx
2158 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2159 .await
2160 .unwrap();
2161
2162 let request = thread.update(cx, |thread, cx| {
2163 thread.send_raw("Fetch https://example.com", cx)
2164 });
2165
2166 run_until_first_tool_call(&thread, cx).await;
2167
2168 thread.read_with(cx, |thread, _| {
2169 assert!(matches!(
2170 thread.entries[1],
2171 AgentThreadEntry::ToolCall(ToolCall {
2172 status: ToolCallStatus::InProgress,
2173 ..
2174 })
2175 ));
2176 });
2177
2178 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2179
2180 thread.read_with(cx, |thread, _| {
2181 assert!(matches!(
2182 &thread.entries[1],
2183 AgentThreadEntry::ToolCall(ToolCall {
2184 status: ToolCallStatus::Canceled,
2185 ..
2186 })
2187 ));
2188 });
2189
2190 thread
2191 .update(cx, |thread, cx| {
2192 thread.handle_session_update(
2193 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2194 id,
2195 fields: acp::ToolCallUpdateFields {
2196 status: Some(acp::ToolCallStatus::Completed),
2197 ..Default::default()
2198 },
2199 }),
2200 cx,
2201 )
2202 })
2203 .unwrap();
2204
2205 request.await.unwrap();
2206
2207 thread.read_with(cx, |thread, _| {
2208 assert!(matches!(
2209 thread.entries[1],
2210 AgentThreadEntry::ToolCall(ToolCall {
2211 status: ToolCallStatus::Completed,
2212 ..
2213 })
2214 ));
2215 });
2216 }
2217
2218 #[gpui::test]
2219 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2220 init_test(cx);
2221 let fs = FakeFs::new(cx.background_executor.clone());
2222 fs.insert_tree(path!("/test"), json!({})).await;
2223 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2224
2225 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2226 move |_, thread, mut cx| {
2227 async move {
2228 thread
2229 .update(&mut cx, |thread, cx| {
2230 thread.handle_session_update(
2231 acp::SessionUpdate::ToolCall(acp::ToolCall {
2232 id: acp::ToolCallId("test".into()),
2233 title: "Label".into(),
2234 kind: acp::ToolKind::Edit,
2235 status: acp::ToolCallStatus::Completed,
2236 content: vec![acp::ToolCallContent::Diff {
2237 diff: acp::Diff {
2238 path: "/test/test.txt".into(),
2239 old_text: None,
2240 new_text: "foo".into(),
2241 },
2242 }],
2243 locations: vec![],
2244 raw_input: None,
2245 raw_output: None,
2246 }),
2247 cx,
2248 )
2249 })
2250 .unwrap()
2251 .unwrap();
2252 Ok(acp::PromptResponse {
2253 stop_reason: acp::StopReason::EndTurn,
2254 })
2255 }
2256 .boxed_local()
2257 }
2258 }));
2259
2260 let thread = cx
2261 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2262 .await
2263 .unwrap();
2264
2265 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2266 .await
2267 .unwrap();
2268
2269 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2270 }
2271
2272 #[gpui::test(iterations = 10)]
2273 async fn test_checkpoints(cx: &mut TestAppContext) {
2274 init_test(cx);
2275 let fs = FakeFs::new(cx.background_executor.clone());
2276 fs.insert_tree(
2277 path!("/test"),
2278 json!({
2279 ".git": {}
2280 }),
2281 )
2282 .await;
2283 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2284
2285 let simulate_changes = Arc::new(AtomicBool::new(true));
2286 let next_filename = Arc::new(AtomicUsize::new(0));
2287 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2288 let simulate_changes = simulate_changes.clone();
2289 let next_filename = next_filename.clone();
2290 let fs = fs.clone();
2291 move |request, thread, mut cx| {
2292 let fs = fs.clone();
2293 let simulate_changes = simulate_changes.clone();
2294 let next_filename = next_filename.clone();
2295 async move {
2296 if simulate_changes.load(SeqCst) {
2297 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2298 fs.write(Path::new(&filename), b"").await?;
2299 }
2300
2301 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2302 panic!("expected text content block");
2303 };
2304 thread.update(&mut cx, |thread, cx| {
2305 thread
2306 .handle_session_update(
2307 acp::SessionUpdate::AgentMessageChunk {
2308 content: content.text.to_uppercase().into(),
2309 },
2310 cx,
2311 )
2312 .unwrap();
2313 })?;
2314 Ok(acp::PromptResponse {
2315 stop_reason: acp::StopReason::EndTurn,
2316 })
2317 }
2318 .boxed_local()
2319 }
2320 }));
2321 let thread = cx
2322 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2323 .await
2324 .unwrap();
2325
2326 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2327 .await
2328 .unwrap();
2329 thread.read_with(cx, |thread, cx| {
2330 assert_eq!(
2331 thread.to_markdown(cx),
2332 indoc! {"
2333 ## User (checkpoint)
2334
2335 Lorem
2336
2337 ## Assistant
2338
2339 LOREM
2340
2341 "}
2342 );
2343 });
2344 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2345
2346 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2347 .await
2348 .unwrap();
2349 thread.read_with(cx, |thread, cx| {
2350 assert_eq!(
2351 thread.to_markdown(cx),
2352 indoc! {"
2353 ## User (checkpoint)
2354
2355 Lorem
2356
2357 ## Assistant
2358
2359 LOREM
2360
2361 ## User (checkpoint)
2362
2363 ipsum
2364
2365 ## Assistant
2366
2367 IPSUM
2368
2369 "}
2370 );
2371 });
2372 assert_eq!(
2373 fs.files(),
2374 vec![
2375 Path::new(path!("/test/file-0")),
2376 Path::new(path!("/test/file-1"))
2377 ]
2378 );
2379
2380 // Checkpoint isn't stored when there are no changes.
2381 simulate_changes.store(false, SeqCst);
2382 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2383 .await
2384 .unwrap();
2385 thread.read_with(cx, |thread, cx| {
2386 assert_eq!(
2387 thread.to_markdown(cx),
2388 indoc! {"
2389 ## User (checkpoint)
2390
2391 Lorem
2392
2393 ## Assistant
2394
2395 LOREM
2396
2397 ## User (checkpoint)
2398
2399 ipsum
2400
2401 ## Assistant
2402
2403 IPSUM
2404
2405 ## User
2406
2407 dolor
2408
2409 ## Assistant
2410
2411 DOLOR
2412
2413 "}
2414 );
2415 });
2416 assert_eq!(
2417 fs.files(),
2418 vec![
2419 Path::new(path!("/test/file-0")),
2420 Path::new(path!("/test/file-1"))
2421 ]
2422 );
2423
2424 // Rewinding the conversation truncates the history and restores the checkpoint.
2425 thread
2426 .update(cx, |thread, cx| {
2427 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2428 panic!("unexpected entries {:?}", thread.entries)
2429 };
2430 thread.rewind(message.id.clone().unwrap(), cx)
2431 })
2432 .await
2433 .unwrap();
2434 thread.read_with(cx, |thread, cx| {
2435 assert_eq!(
2436 thread.to_markdown(cx),
2437 indoc! {"
2438 ## User (checkpoint)
2439
2440 Lorem
2441
2442 ## Assistant
2443
2444 LOREM
2445
2446 "}
2447 );
2448 });
2449 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2450 }
2451
2452 #[gpui::test]
2453 async fn test_refusal(cx: &mut TestAppContext) {
2454 init_test(cx);
2455 let fs = FakeFs::new(cx.background_executor.clone());
2456 fs.insert_tree(path!("/"), json!({})).await;
2457 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2458
2459 let refuse_next = Arc::new(AtomicBool::new(false));
2460 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2461 let refuse_next = refuse_next.clone();
2462 move |request, thread, mut cx| {
2463 let refuse_next = refuse_next.clone();
2464 async move {
2465 if refuse_next.load(SeqCst) {
2466 return Ok(acp::PromptResponse {
2467 stop_reason: acp::StopReason::Refusal,
2468 });
2469 }
2470
2471 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2472 panic!("expected text content block");
2473 };
2474 thread.update(&mut cx, |thread, cx| {
2475 thread
2476 .handle_session_update(
2477 acp::SessionUpdate::AgentMessageChunk {
2478 content: content.text.to_uppercase().into(),
2479 },
2480 cx,
2481 )
2482 .unwrap();
2483 })?;
2484 Ok(acp::PromptResponse {
2485 stop_reason: acp::StopReason::EndTurn,
2486 })
2487 }
2488 .boxed_local()
2489 }
2490 }));
2491 let thread = cx
2492 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2493 .await
2494 .unwrap();
2495
2496 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2497 .await
2498 .unwrap();
2499 thread.read_with(cx, |thread, cx| {
2500 assert_eq!(
2501 thread.to_markdown(cx),
2502 indoc! {"
2503 ## User
2504
2505 hello
2506
2507 ## Assistant
2508
2509 HELLO
2510
2511 "}
2512 );
2513 });
2514
2515 // Simulate refusing the second message, ensuring the conversation gets
2516 // truncated to before sending it.
2517 refuse_next.store(true, SeqCst);
2518 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2519 .await
2520 .unwrap();
2521 thread.read_with(cx, |thread, cx| {
2522 assert_eq!(
2523 thread.to_markdown(cx),
2524 indoc! {"
2525 ## User
2526
2527 hello
2528
2529 ## Assistant
2530
2531 HELLO
2532
2533 "}
2534 );
2535 });
2536 }
2537
2538 async fn run_until_first_tool_call(
2539 thread: &Entity<AcpThread>,
2540 cx: &mut TestAppContext,
2541 ) -> usize {
2542 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2543
2544 let subscription = cx.update(|cx| {
2545 cx.subscribe(thread, move |thread, _, cx| {
2546 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2547 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2548 return tx.try_send(ix).unwrap();
2549 }
2550 }
2551 })
2552 });
2553
2554 select! {
2555 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2556 panic!("Timeout waiting for tool call")
2557 }
2558 ix = rx.next().fuse() => {
2559 drop(subscription);
2560 ix.unwrap()
2561 }
2562 }
2563 }
2564
2565 #[derive(Clone, Default)]
2566 struct FakeAgentConnection {
2567 auth_methods: Vec<acp::AuthMethod>,
2568 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2569 on_user_message: Option<
2570 Rc<
2571 dyn Fn(
2572 acp::PromptRequest,
2573 WeakEntity<AcpThread>,
2574 AsyncApp,
2575 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2576 + 'static,
2577 >,
2578 >,
2579 }
2580
2581 impl FakeAgentConnection {
2582 fn new() -> Self {
2583 Self {
2584 auth_methods: Vec::new(),
2585 on_user_message: None,
2586 sessions: Arc::default(),
2587 }
2588 }
2589
2590 #[expect(unused)]
2591 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2592 self.auth_methods = auth_methods;
2593 self
2594 }
2595
2596 fn on_user_message(
2597 mut self,
2598 handler: impl Fn(
2599 acp::PromptRequest,
2600 WeakEntity<AcpThread>,
2601 AsyncApp,
2602 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2603 + 'static,
2604 ) -> Self {
2605 self.on_user_message.replace(Rc::new(handler));
2606 self
2607 }
2608 }
2609
2610 impl AgentConnection for FakeAgentConnection {
2611 fn auth_methods(&self) -> &[acp::AuthMethod] {
2612 &self.auth_methods
2613 }
2614
2615 fn new_thread(
2616 self: Rc<Self>,
2617 project: Entity<Project>,
2618 _cwd: &Path,
2619 cx: &mut App,
2620 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2621 let session_id = acp::SessionId(
2622 rand::thread_rng()
2623 .sample_iter(&rand::distributions::Alphanumeric)
2624 .take(7)
2625 .map(char::from)
2626 .collect::<String>()
2627 .into(),
2628 );
2629 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2630 let thread = cx.new(|cx| {
2631 AcpThread::new(
2632 "Test",
2633 self.clone(),
2634 project,
2635 action_log,
2636 session_id.clone(),
2637 watch::Receiver::constant(acp::PromptCapabilities {
2638 image: true,
2639 audio: true,
2640 embedded_context: true,
2641 }),
2642 cx,
2643 )
2644 });
2645 self.sessions.lock().insert(session_id, thread.downgrade());
2646 Task::ready(Ok(thread))
2647 }
2648
2649 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2650 if self.auth_methods().iter().any(|m| m.id == method) {
2651 Task::ready(Ok(()))
2652 } else {
2653 Task::ready(Err(anyhow!("Invalid Auth Method")))
2654 }
2655 }
2656
2657 fn prompt(
2658 &self,
2659 _id: Option<UserMessageId>,
2660 params: acp::PromptRequest,
2661 cx: &mut App,
2662 ) -> Task<gpui::Result<acp::PromptResponse>> {
2663 let sessions = self.sessions.lock();
2664 let thread = sessions.get(¶ms.session_id).unwrap();
2665 if let Some(handler) = &self.on_user_message {
2666 let handler = handler.clone();
2667 let thread = thread.clone();
2668 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2669 } else {
2670 Task::ready(Ok(acp::PromptResponse {
2671 stop_reason: acp::StopReason::EndTurn,
2672 }))
2673 }
2674 }
2675
2676 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2677 let sessions = self.sessions.lock();
2678 let thread = sessions.get(session_id).unwrap().clone();
2679
2680 cx.spawn(async move |cx| {
2681 thread
2682 .update(cx, |thread, cx| thread.cancel(cx))
2683 .unwrap()
2684 .await
2685 })
2686 .detach();
2687 }
2688
2689 fn truncate(
2690 &self,
2691 session_id: &acp::SessionId,
2692 _cx: &App,
2693 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2694 Some(Rc::new(FakeAgentSessionEditor {
2695 _session_id: session_id.clone(),
2696 }))
2697 }
2698
2699 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2700 self
2701 }
2702 }
2703
2704 struct FakeAgentSessionEditor {
2705 _session_id: acp::SessionId,
2706 }
2707
2708 impl AgentSessionTruncate for FakeAgentSessionEditor {
2709 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2710 Task::ready(Ok(()))
2711 }
2712 }
2713}