1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
/// A user turn in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one has been supplied. A later chunk that carries an id
    /// replaces it; a chunk without one keeps the existing id.
    pub id: Option<UserMessageId>,
    /// Rendered content accumulated from every chunk received so far.
    pub content: ContentBlock,
    /// The raw ACP content blocks, in the order they were received.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}

/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// When true, the markdown export labels the message heading as "(checkpoint)".
    pub show: bool,
}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier; used to find this call again when updates arrive.
    pub id: acp::ToolCallId,
    /// Markdown rendering of the tool call's title. Only the first line of a multi-line title
    /// is kept, followed by an ellipsis.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content items (text blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Index-aligned with `locations` once resolution has completed. An entry is `None` where
    /// the location could not be mapped to an open project buffer. Empty until the
    /// asynchronous resolution finishes.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input, passed through from the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output, passed through from the agent.
    pub raw_output: Option<serde_json::Value>,
}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
187 first_line.to_owned() + "…"
188 } else {
189 tool_call.title
190 };
191 Self {
192 id: tool_call.id,
193 label: cx
194 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
195 kind: tool_call.kind,
196 content: tool_call
197 .content
198 .into_iter()
199 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
200 .collect(),
201 locations: tool_call.locations,
202 resolved_locations: Vec::default(),
203 status,
204 raw_input: tool_call.raw_input,
205 raw_output: tool_call.raw_output,
206 }
207 }
208
209 fn update_fields(
210 &mut self,
211 fields: acp::ToolCallUpdateFields,
212 language_registry: Arc<LanguageRegistry>,
213 cx: &mut App,
214 ) {
215 let acp::ToolCallUpdateFields {
216 kind,
217 status,
218 title,
219 content,
220 locations,
221 raw_input,
222 raw_output,
223 } = fields;
224
225 if let Some(kind) = kind {
226 self.kind = kind;
227 }
228
229 if let Some(status) = status {
230 self.status = status.into();
231 }
232
233 if let Some(title) = title {
234 self.label.update(cx, |label, cx| {
235 if let Some((first_line, _)) = title.split_once("\n") {
236 label.replace(first_line.to_owned() + "…", cx)
237 } else {
238 label.replace(title, cx);
239 }
240 });
241 }
242
243 if let Some(content) = content {
244 let new_content_len = content.len();
245 let mut content = content.into_iter();
246
247 // Reuse existing content if we can
248 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
249 old.update_from_acp(new, language_registry.clone(), cx);
250 }
251 for new in content {
252 self.content.push(ToolCallContent::from_acp(
253 new,
254 language_registry.clone(),
255 cx,
256 ))
257 }
258 self.content.truncate(new_content_len);
259 }
260
261 if let Some(locations) = locations {
262 self.locations = locations;
263 }
264
265 if let Some(raw_input) = raw_input {
266 self.raw_input = Some(raw_input);
267 }
268
269 if let Some(raw_output) = raw_output {
270 if self.content.is_empty()
271 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
272 {
273 self.content
274 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
275 markdown,
276 }));
277 }
278 self.raw_output = Some(raw_output);
279 }
280 }
281
282 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
283 self.content.iter().filter_map(|content| match content {
284 ToolCallContent::Diff(diff) => Some(diff),
285 ToolCallContent::ContentBlock(_) => None,
286 ToolCallContent::Terminal(_) => None,
287 })
288 }
289
290 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
291 self.content.iter().filter_map(|content| match content {
292 ToolCallContent::Terminal(terminal) => Some(terminal),
293 ToolCallContent::ContentBlock(_) => None,
294 ToolCallContent::Diff(_) => None,
295 })
296 }
297
298 fn to_markdown(&self, cx: &App) -> String {
299 let mut markdown = format!(
300 "**Tool Call: {}**\nStatus: {}\n\n",
301 self.label.read(cx).source(),
302 self.status
303 );
304 for content in &self.content {
305 markdown.push_str(content.to_markdown(cx).as_str());
306 markdown.push_str("\n\n");
307 }
308 markdown
309 }
310
311 async fn resolve_location(
312 location: acp::ToolCallLocation,
313 project: WeakEntity<Project>,
314 cx: &mut AsyncApp,
315 ) -> Option<AgentLocation> {
316 let buffer = project
317 .update(cx, |project, cx| {
318 project
319 .project_path_for_absolute_path(&location.path, cx)
320 .map(|path| project.open_buffer(path, cx))
321 })
322 .ok()??;
323 let buffer = buffer.await.log_err()?;
324 let position = buffer
325 .update(cx, |buffer, _| {
326 if let Some(row) = location.line {
327 let snapshot = buffer.snapshot();
328 let column = snapshot.indent_size_for_line(row).len;
329 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
330 snapshot.anchor_before(point)
331 } else {
332 Anchor::MIN
333 }
334 })
335 .ok()?;
336
337 Some(AgentLocation {
338 buffer: buffer.downgrade(),
339 position,
340 })
341 }
342
343 fn resolve_locations(
344 &self,
345 project: Entity<Project>,
346 cx: &mut App,
347 ) -> Task<Vec<Option<AgentLocation>>> {
348 let locations = self.locations.clone();
349 project.update(cx, |_, cx| {
350 cx.spawn(async move |project, cx| {
351 let mut new_locations = Vec::new();
352 for location in locations {
353 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
354 }
355 new_locations
356 })
357 })
358 }
359}
360
361#[derive(Debug)]
362pub enum ToolCallStatus {
363 /// The tool call hasn't started running yet, but we start showing it to
364 /// the user.
365 Pending,
366 /// The tool call is waiting for confirmation from the user.
367 WaitingForConfirmation {
368 options: Vec<acp::PermissionOption>,
369 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
370 },
371 /// The tool call is currently running.
372 InProgress,
373 /// The tool call completed successfully.
374 Completed,
375 /// The tool call failed.
376 Failed,
377 /// The user rejected the tool call.
378 Rejected,
379 /// The user canceled generation so the tool call was canceled.
380 Canceled,
381}
382
383impl From<acp::ToolCallStatus> for ToolCallStatus {
384 fn from(status: acp::ToolCallStatus) -> Self {
385 match status {
386 acp::ToolCallStatus::Pending => Self::Pending,
387 acp::ToolCallStatus::InProgress => Self::InProgress,
388 acp::ToolCallStatus::Completed => Self::Completed,
389 acp::ToolCallStatus::Failed => Self::Failed,
390 }
391 }
392}
393
394impl Display for ToolCallStatus {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 write!(
397 f,
398 "{}",
399 match self {
400 ToolCallStatus::Pending => "Pending",
401 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
402 ToolCallStatus::InProgress => "In Progress",
403 ToolCallStatus::Completed => "Completed",
404 ToolCallStatus::Failed => "Failed",
405 ToolCallStatus::Rejected => "Rejected",
406 ToolCallStatus::Canceled => "Canceled",
407 }
408 )
409 }
410}
411
412#[derive(Debug, PartialEq, Clone)]
413pub enum ContentBlock {
414 Empty,
415 Markdown { markdown: Entity<Markdown> },
416 ResourceLink { resource_link: acp::ResourceLink },
417}
418
419impl ContentBlock {
420 pub fn new(
421 block: acp::ContentBlock,
422 language_registry: &Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let mut this = Self::Empty;
426 this.append(block, language_registry, cx);
427 this
428 }
429
430 pub fn new_combined(
431 blocks: impl IntoIterator<Item = acp::ContentBlock>,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 let mut this = Self::Empty;
436 for block in blocks {
437 this.append(block, &language_registry, cx);
438 }
439 this
440 }
441
442 pub fn append(
443 &mut self,
444 block: acp::ContentBlock,
445 language_registry: &Arc<LanguageRegistry>,
446 cx: &mut App,
447 ) {
448 if matches!(self, ContentBlock::Empty)
449 && let acp::ContentBlock::ResourceLink(resource_link) = block
450 {
451 *self = ContentBlock::ResourceLink { resource_link };
452 return;
453 }
454
455 let new_content = self.block_string_contents(block);
456
457 match self {
458 ContentBlock::Empty => {
459 *self = Self::create_markdown_block(new_content, language_registry, cx);
460 }
461 ContentBlock::Markdown { markdown } => {
462 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
463 }
464 ContentBlock::ResourceLink { resource_link } => {
465 let existing_content = Self::resource_link_md(&resource_link.uri);
466 let combined = format!("{}\n{}", existing_content, new_content);
467
468 *self = Self::create_markdown_block(combined, language_registry, cx);
469 }
470 }
471 }
472
473 fn create_markdown_block(
474 content: String,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) -> ContentBlock {
478 ContentBlock::Markdown {
479 markdown: cx
480 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
481 }
482 }
483
484 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
485 match block {
486 acp::ContentBlock::Text(text_content) => text_content.text,
487 acp::ContentBlock::ResourceLink(resource_link) => {
488 Self::resource_link_md(&resource_link.uri)
489 }
490 acp::ContentBlock::Resource(acp::EmbeddedResource {
491 resource:
492 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
493 uri,
494 ..
495 }),
496 ..
497 }) => Self::resource_link_md(&uri),
498 acp::ContentBlock::Image(image) => Self::image_md(&image),
499 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
500 }
501 }
502
503 fn resource_link_md(uri: &str) -> String {
504 if let Some(uri) = MentionUri::parse(uri).log_err() {
505 uri.as_link().to_string()
506 } else {
507 uri.to_string()
508 }
509 }
510
511 fn image_md(_image: &acp::ImageContent) -> String {
512 "`Image`".into()
513 }
514
515 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
516 match self {
517 ContentBlock::Empty => "",
518 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
519 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
520 }
521 }
522
523 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
524 match self {
525 ContentBlock::Empty => None,
526 ContentBlock::Markdown { markdown } => Some(markdown),
527 ContentBlock::ResourceLink { .. } => None,
528 }
529 }
530
531 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
532 match self {
533 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
534 _ => None,
535 }
536 }
537}
538
539#[derive(Debug)]
540pub enum ToolCallContent {
541 ContentBlock(ContentBlock),
542 Diff(Entity<Diff>),
543 Terminal(Entity<Terminal>),
544}
545
546impl ToolCallContent {
547 pub fn from_acp(
548 content: acp::ToolCallContent,
549 language_registry: Arc<LanguageRegistry>,
550 cx: &mut App,
551 ) -> Self {
552 match content {
553 acp::ToolCallContent::Content { content } => {
554 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
555 }
556 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
557 Diff::finalized(
558 diff.path,
559 diff.old_text,
560 diff.new_text,
561 language_registry,
562 cx,
563 )
564 })),
565 }
566 }
567
568 pub fn update_from_acp(
569 &mut self,
570 new: acp::ToolCallContent,
571 language_registry: Arc<LanguageRegistry>,
572 cx: &mut App,
573 ) {
574 let needs_update = match (&self, &new) {
575 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
576 old_diff.read(cx).needs_update(
577 new_diff.old_text.as_deref().unwrap_or(""),
578 &new_diff.new_text,
579 cx,
580 )
581 }
582 _ => true,
583 };
584
585 if needs_update {
586 *self = Self::from_acp(new, language_registry, cx);
587 }
588 }
589
590 pub fn to_markdown(&self, cx: &App) -> String {
591 match self {
592 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
593 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
594 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
595 }
596 }
597}
598
599#[derive(Debug, PartialEq)]
600pub enum ToolCallUpdate {
601 UpdateFields(acp::ToolCallUpdate),
602 UpdateDiff(ToolCallUpdateDiff),
603 UpdateTerminal(ToolCallUpdateTerminal),
604}
605
606impl ToolCallUpdate {
607 fn id(&self) -> &acp::ToolCallId {
608 match self {
609 Self::UpdateFields(update) => &update.id,
610 Self::UpdateDiff(diff) => &diff.id,
611 Self::UpdateTerminal(terminal) => &terminal.id,
612 }
613 }
614}
615
616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
617 fn from(update: acp::ToolCallUpdate) -> Self {
618 Self::UpdateFields(update)
619 }
620}
621
622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
623 fn from(diff: ToolCallUpdateDiff) -> Self {
624 Self::UpdateDiff(diff)
625 }
626}
627
628#[derive(Debug, PartialEq)]
629pub struct ToolCallUpdateDiff {
630 pub id: acp::ToolCallId,
631 pub diff: Entity<Diff>,
632}
633
634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
635 fn from(terminal: ToolCallUpdateTerminal) -> Self {
636 Self::UpdateTerminal(terminal)
637 }
638}
639
640#[derive(Debug, PartialEq)]
641pub struct ToolCallUpdateTerminal {
642 pub id: acp::ToolCallId,
643 pub terminal: Entity<Terminal>,
644}
645
646#[derive(Debug, Default)]
647pub struct Plan {
648 pub entries: Vec<PlanEntry>,
649}
650
651#[derive(Debug)]
652pub struct PlanStats<'a> {
653 pub in_progress_entry: Option<&'a PlanEntry>,
654 pub pending: u32,
655 pub completed: u32,
656}
657
658impl Plan {
659 pub fn is_empty(&self) -> bool {
660 self.entries.is_empty()
661 }
662
663 pub fn stats(&self) -> PlanStats<'_> {
664 let mut stats = PlanStats {
665 in_progress_entry: None,
666 pending: 0,
667 completed: 0,
668 };
669
670 for entry in &self.entries {
671 match &entry.status {
672 acp::PlanEntryStatus::Pending => {
673 stats.pending += 1;
674 }
675 acp::PlanEntryStatus::InProgress => {
676 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
677 }
678 acp::PlanEntryStatus::Completed => {
679 stats.completed += 1;
680 }
681 }
682 }
683
684 stats
685 }
686}
687
688#[derive(Debug)]
689pub struct PlanEntry {
690 pub content: Entity<Markdown>,
691 pub priority: acp::PlanEntryPriority,
692 pub status: acp::PlanEntryStatus,
693}
694
695impl PlanEntry {
696 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
697 Self {
698 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
699 priority: entry.priority,
700 status: entry.status,
701 }
702 }
703}
704
705#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
706pub struct TokenUsage {
707 pub max_tokens: u64,
708 pub used_tokens: u64,
709}
710
711impl TokenUsage {
712 pub fn ratio(&self) -> TokenUsageRatio {
713 #[cfg(debug_assertions)]
714 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
715 .unwrap_or("0.8".to_string())
716 .parse()
717 .unwrap();
718 #[cfg(not(debug_assertions))]
719 let warning_threshold: f32 = 0.8;
720
721 // When the maximum is unknown because there is no selected model,
722 // avoid showing the token limit warning.
723 if self.max_tokens == 0 {
724 TokenUsageRatio::Normal
725 } else if self.used_tokens >= self.max_tokens {
726 TokenUsageRatio::Exceeded
727 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
728 TokenUsageRatio::Warning
729 } else {
730 TokenUsageRatio::Normal
731 }
732 }
733}
734
735#[derive(Debug, Clone, PartialEq, Eq)]
736pub enum TokenUsageRatio {
737 Normal,
738 Warning,
739 Exceeded,
740}
741
/// Retry information forwarded to subscribers as [`AcpThreadEvent::Retry`] via
/// `AcpThread::update_retry_status`.
// NOTE(review): field meanings below are inferred from names; confirm against the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number.
    pub attempt: usize,
    /// Presumably the attempt limit.
    pub max_attempts: usize,
    pub started_at: Instant,
    /// NOTE(review): unclear whether this is the delay before the next attempt; confirm.
    pub duration: Duration,
}

/// A conversation with an agent over ACP.
///
/// Holds the entries (user messages, assistant messages, tool calls) together with the plan,
/// title, and token usage.
pub struct AcpThread {
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not referenced in the visible part of this file; presumably buffer
    // snapshots shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while a request is in flight; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest capabilities received from the agent.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies capability updates into `prompt_capabilities`; kept alive by being
    /// stored here.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}

/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// Token usage was replaced.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// The entries in this range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is now waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A retry is happening; carries its status.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The agent's prompt capabilities changed.
    PromptCapabilitiesUpdated,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Overall state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No request is in flight.
    Idle,
    /// A request is in flight, but it is blocked on a tool permission prompt.
    WaitingForToolConfirmation,
    /// A request is in flight.
    Generating,
}
789
790#[derive(Debug, Clone)]
791pub enum LoadError {
792 Unsupported {
793 command: SharedString,
794 current_version: SharedString,
795 minimum_version: SharedString,
796 },
797 FailedToInstall(SharedString),
798 Exited {
799 status: ExitStatus,
800 },
801 Other(SharedString),
802}
803
804impl Display for LoadError {
805 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
806 match self {
807 LoadError::Unsupported {
808 command: path,
809 current_version,
810 minimum_version,
811 } => {
812 write!(
813 f,
814 "version {current_version} from {path} is not supported (need at least {minimum_version})"
815 )
816 }
817 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
818 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
819 LoadError::Other(msg) => write!(f, "{msg}"),
820 }
821 }
822}
823
824impl Error for LoadError {}
825
826impl AcpThread {
827 pub fn new(
828 title: impl Into<SharedString>,
829 connection: Rc<dyn AgentConnection>,
830 project: Entity<Project>,
831 action_log: Entity<ActionLog>,
832 session_id: acp::SessionId,
833 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
834 cx: &mut Context<Self>,
835 ) -> Self {
836 let prompt_capabilities = *prompt_capabilities_rx.borrow();
837 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
838 loop {
839 let caps = prompt_capabilities_rx.recv().await?;
840 this.update(cx, |this, cx| {
841 this.prompt_capabilities = caps;
842 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
843 })?;
844 }
845 });
846
847 Self {
848 action_log,
849 shared_buffers: Default::default(),
850 entries: Default::default(),
851 plan: Default::default(),
852 title: title.into(),
853 project,
854 send_task: None,
855 connection,
856 session_id,
857 token_usage: None,
858 prompt_capabilities,
859 _observe_prompt_capabilities: task,
860 }
861 }
862
863 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
864 self.prompt_capabilities
865 }
866
867 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
868 &self.connection
869 }
870
871 /// Returns true if the agent supports custom slash commands.
872 pub fn supports_custom_commands(&self) -> bool {
873 self.prompt_capabilities.supports_custom_commands
874 }
875
876 pub fn action_log(&self) -> &Entity<ActionLog> {
877 &self.action_log
878 }
879
880 pub fn project(&self) -> &Entity<Project> {
881 &self.project
882 }
883
884 pub fn title(&self) -> SharedString {
885 self.title.clone()
886 }
887
888 pub fn entries(&self) -> &[AgentThreadEntry] {
889 &self.entries
890 }
891
892 pub fn session_id(&self) -> &acp::SessionId {
893 &self.session_id
894 }
895
896 pub fn status(&self) -> ThreadStatus {
897 if self.send_task.is_some() {
898 if self.waiting_for_tool_confirmation() {
899 ThreadStatus::WaitingForToolConfirmation
900 } else {
901 ThreadStatus::Generating
902 }
903 } else {
904 ThreadStatus::Idle
905 }
906 }
907
908 pub fn token_usage(&self) -> Option<&TokenUsage> {
909 self.token_usage.as_ref()
910 }
911
912 pub fn has_pending_edit_tool_calls(&self) -> bool {
913 for entry in self.entries.iter().rev() {
914 match entry {
915 AgentThreadEntry::UserMessage(_) => return false,
916 AgentThreadEntry::ToolCall(
917 call @ ToolCall {
918 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
919 ..
920 },
921 ) if call.diffs().next().is_some() => {
922 return true;
923 }
924 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
925 }
926 }
927
928 false
929 }
930
931 pub fn used_tools_since_last_user_message(&self) -> bool {
932 for entry in self.entries.iter().rev() {
933 match entry {
934 AgentThreadEntry::UserMessage(..) => return false,
935 AgentThreadEntry::AssistantMessage(..) => continue,
936 AgentThreadEntry::ToolCall(..) => return true,
937 }
938 }
939
940 false
941 }
942
943 pub fn handle_session_update(
944 &mut self,
945 update: acp::SessionUpdate,
946 cx: &mut Context<Self>,
947 ) -> Result<(), acp::Error> {
948 match update {
949 acp::SessionUpdate::UserMessageChunk { content } => {
950 self.push_user_content_block(None, content, cx);
951 }
952 acp::SessionUpdate::AgentMessageChunk { content } => {
953 self.push_assistant_content_block(content, false, cx);
954 }
955 acp::SessionUpdate::AgentThoughtChunk { content } => {
956 self.push_assistant_content_block(content, true, cx);
957 }
958 acp::SessionUpdate::ToolCall(tool_call) => {
959 self.upsert_tool_call(tool_call, cx)?;
960 }
961 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
962 self.update_tool_call(tool_call_update, cx)?;
963 }
964 acp::SessionUpdate::Plan(plan) => {
965 self.update_plan(plan, cx);
966 }
967 }
968 Ok(())
969 }
970
971 pub fn push_user_content_block(
972 &mut self,
973 message_id: Option<UserMessageId>,
974 chunk: acp::ContentBlock,
975 cx: &mut Context<Self>,
976 ) {
977 let language_registry = self.project.read(cx).languages().clone();
978 let entries_len = self.entries.len();
979
980 if let Some(last_entry) = self.entries.last_mut()
981 && let AgentThreadEntry::UserMessage(UserMessage {
982 id,
983 content,
984 chunks,
985 ..
986 }) = last_entry
987 {
988 *id = message_id.or(id.take());
989 content.append(chunk.clone(), &language_registry, cx);
990 chunks.push(chunk);
991 let idx = entries_len - 1;
992 cx.emit(AcpThreadEvent::EntryUpdated(idx));
993 } else {
994 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
995 self.push_entry(
996 AgentThreadEntry::UserMessage(UserMessage {
997 id: message_id,
998 content,
999 chunks: vec![chunk],
1000 checkpoint: None,
1001 }),
1002 cx,
1003 );
1004 }
1005 }
1006
1007 pub fn push_assistant_content_block(
1008 &mut self,
1009 chunk: acp::ContentBlock,
1010 is_thought: bool,
1011 cx: &mut Context<Self>,
1012 ) {
1013 let language_registry = self.project.read(cx).languages().clone();
1014 let entries_len = self.entries.len();
1015 if let Some(last_entry) = self.entries.last_mut()
1016 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1017 {
1018 let idx = entries_len - 1;
1019 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1020 match (chunks.last_mut(), is_thought) {
1021 (Some(AssistantMessageChunk::Message { block }), false)
1022 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1023 block.append(chunk, &language_registry, cx)
1024 }
1025 _ => {
1026 let block = ContentBlock::new(chunk, &language_registry, cx);
1027 if is_thought {
1028 chunks.push(AssistantMessageChunk::Thought { block })
1029 } else {
1030 chunks.push(AssistantMessageChunk::Message { block })
1031 }
1032 }
1033 }
1034 } else {
1035 let block = ContentBlock::new(chunk, &language_registry, cx);
1036 let chunk = if is_thought {
1037 AssistantMessageChunk::Thought { block }
1038 } else {
1039 AssistantMessageChunk::Message { block }
1040 };
1041
1042 self.push_entry(
1043 AgentThreadEntry::AssistantMessage(AssistantMessage {
1044 chunks: vec![chunk],
1045 }),
1046 cx,
1047 );
1048 }
1049 }
1050
1051 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1052 self.entries.push(entry);
1053 cx.emit(AcpThreadEvent::NewEntry);
1054 }
1055
1056 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1057 self.connection.set_title(&self.session_id, cx).is_some()
1058 }
1059
1060 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1061 if title != self.title {
1062 self.title = title.clone();
1063 cx.emit(AcpThreadEvent::TitleUpdated);
1064 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1065 return set_title.run(title, cx);
1066 }
1067 }
1068 Task::ready(Ok(()))
1069 }
1070
1071 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1072 self.token_usage = usage;
1073 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1074 }
1075
1076 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1077 cx.emit(AcpThreadEvent::Retry(status));
1078 }
1079
1080 pub fn update_tool_call(
1081 &mut self,
1082 update: impl Into<ToolCallUpdate>,
1083 cx: &mut Context<Self>,
1084 ) -> Result<()> {
1085 let update = update.into();
1086 let languages = self.project.read(cx).languages().clone();
1087
1088 let (ix, current_call) = self
1089 .tool_call_mut(update.id())
1090 .context("Tool call not found")?;
1091 match update {
1092 ToolCallUpdate::UpdateFields(update) => {
1093 let location_updated = update.fields.locations.is_some();
1094 current_call.update_fields(update.fields, languages, cx);
1095 if location_updated {
1096 self.resolve_locations(update.id, cx);
1097 }
1098 }
1099 ToolCallUpdate::UpdateDiff(update) => {
1100 current_call.content.clear();
1101 current_call
1102 .content
1103 .push(ToolCallContent::Diff(update.diff));
1104 }
1105 ToolCallUpdate::UpdateTerminal(update) => {
1106 current_call.content.clear();
1107 current_call
1108 .content
1109 .push(ToolCallContent::Terminal(update.terminal));
1110 }
1111 }
1112
1113 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1114
1115 Ok(())
1116 }
1117
1118 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1119 pub fn upsert_tool_call(
1120 &mut self,
1121 tool_call: acp::ToolCall,
1122 cx: &mut Context<Self>,
1123 ) -> Result<(), acp::Error> {
1124 let status = tool_call.status.into();
1125 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1126 }
1127
1128 /// Fails if id does not match an existing entry.
1129 pub fn upsert_tool_call_inner(
1130 &mut self,
1131 tool_call_update: acp::ToolCallUpdate,
1132 status: ToolCallStatus,
1133 cx: &mut Context<Self>,
1134 ) -> Result<(), acp::Error> {
1135 let language_registry = self.project.read(cx).languages().clone();
1136 let id = tool_call_update.id.clone();
1137
1138 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1139 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1140 current_call.status = status;
1141
1142 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1143 } else {
1144 let call =
1145 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1146 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1147 };
1148
1149 self.resolve_locations(id, cx);
1150 Ok(())
1151 }
1152
1153 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1154 // The tool call we are looking for is typically the last one, or very close to the end.
1155 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1156 self.entries
1157 .iter_mut()
1158 .enumerate()
1159 .rev()
1160 .find_map(|(index, tool_call)| {
1161 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1162 && &tool_call.id == id
1163 {
1164 Some((index, tool_call))
1165 } else {
1166 None
1167 }
1168 })
1169 }
1170
1171 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1172 self.entries
1173 .iter()
1174 .enumerate()
1175 .rev()
1176 .find_map(|(index, tool_call)| {
1177 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1178 && &tool_call.id == id
1179 {
1180 Some((index, tool_call))
1181 } else {
1182 None
1183 }
1184 })
1185 }
1186
1187 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1188 let project = self.project.clone();
1189 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1190 return;
1191 };
1192 let task = tool_call.resolve_locations(project, cx);
1193 cx.spawn(async move |this, cx| {
1194 let resolved_locations = task.await;
1195 this.update(cx, |this, cx| {
1196 let project = this.project.clone();
1197 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1198 return;
1199 };
1200 if let Some(Some(location)) = resolved_locations.last() {
1201 project.update(cx, |project, cx| {
1202 if let Some(agent_location) = project.agent_location() {
1203 let should_ignore = agent_location.buffer == location.buffer
1204 && location
1205 .buffer
1206 .update(cx, |buffer, _| {
1207 let snapshot = buffer.snapshot();
1208 let old_position =
1209 agent_location.position.to_point(&snapshot);
1210 let new_position = location.position.to_point(&snapshot);
1211 // ignore this so that when we get updates from the edit tool
1212 // the position doesn't reset to the startof line
1213 old_position.row == new_position.row
1214 && old_position.column > new_position.column
1215 })
1216 .ok()
1217 .unwrap_or_default();
1218 if !should_ignore {
1219 project.set_agent_location(Some(location.clone()), cx);
1220 }
1221 }
1222 });
1223 }
1224 if tool_call.resolved_locations != resolved_locations {
1225 tool_call.resolved_locations = resolved_locations;
1226 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1227 }
1228 })
1229 })
1230 .detach();
1231 }
1232
1233 pub fn request_tool_call_authorization(
1234 &mut self,
1235 tool_call: acp::ToolCallUpdate,
1236 options: Vec<acp::PermissionOption>,
1237 cx: &mut Context<Self>,
1238 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1239 let (tx, rx) = oneshot::channel();
1240
1241 let status = ToolCallStatus::WaitingForConfirmation {
1242 options,
1243 respond_tx: tx,
1244 };
1245
1246 self.upsert_tool_call_inner(tool_call, status, cx)?;
1247 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1248 Ok(rx)
1249 }
1250
1251 pub fn authorize_tool_call(
1252 &mut self,
1253 id: acp::ToolCallId,
1254 option_id: acp::PermissionOptionId,
1255 option_kind: acp::PermissionOptionKind,
1256 cx: &mut Context<Self>,
1257 ) {
1258 let Some((ix, call)) = self.tool_call_mut(&id) else {
1259 return;
1260 };
1261
1262 let new_status = match option_kind {
1263 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1264 ToolCallStatus::Rejected
1265 }
1266 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1267 ToolCallStatus::InProgress
1268 }
1269 };
1270
1271 let curr_status = mem::replace(&mut call.status, new_status);
1272
1273 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1274 respond_tx.send(option_id).log_err();
1275 } else if cfg!(debug_assertions) {
1276 panic!("tried to authorize an already authorized tool call");
1277 }
1278
1279 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1280 }
1281
1282 /// Returns true if the last turn is awaiting tool authorization
1283 pub fn waiting_for_tool_confirmation(&self) -> bool {
1284 for entry in self.entries.iter().rev() {
1285 match &entry {
1286 AgentThreadEntry::ToolCall(call) => match call.status {
1287 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1288 ToolCallStatus::Pending
1289 | ToolCallStatus::InProgress
1290 | ToolCallStatus::Completed
1291 | ToolCallStatus::Failed
1292 | ToolCallStatus::Rejected
1293 | ToolCallStatus::Canceled => continue,
1294 },
1295 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1296 // Reached the beginning of the turn
1297 return false;
1298 }
1299 }
1300 }
1301 false
1302 }
1303
    /// Returns the thread's current plan, as last set by `update_plan`.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1307
1308 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1309 let new_entries_len = request.entries.len();
1310 let mut new_entries = request.entries.into_iter();
1311
1312 // Reuse existing markdown to prevent flickering
1313 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1314 let PlanEntry {
1315 content,
1316 priority,
1317 status,
1318 } = old;
1319 content.update(cx, |old, cx| {
1320 old.replace(new.content, cx);
1321 });
1322 *priority = new.priority;
1323 *status = new.status;
1324 }
1325 for new in new_entries {
1326 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1327 }
1328 self.plan.entries.truncate(new_entries_len);
1329
1330 cx.notify();
1331 }
1332
1333 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1334 self.plan
1335 .entries
1336 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1337 cx.notify();
1338 }
1339
1340 #[cfg(any(test, feature = "test-support"))]
1341 pub fn send_raw(
1342 &mut self,
1343 message: &str,
1344 cx: &mut Context<Self>,
1345 ) -> BoxFuture<'static, Result<()>> {
1346 self.send(
1347 vec![acp::ContentBlock::Text(acp::TextContent {
1348 text: message.to_string(),
1349 annotations: None,
1350 })],
1351 cx,
1352 )
1353 }
1354
    /// Sends `message` to the agent as a new user prompt and runs the resulting turn.
    ///
    /// The message is pushed as a [`UserMessage`] entry, a git checkpoint of the
    /// project is taken and attached to it (with `show: false`; `update_last_checkpoint`
    /// later flips `show` if the project changed), and then the connection's `prompt`
    /// is awaited. The returned future resolves when the turn finishes; completion,
    /// errors and cancellation are handled by `run_turn`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Only assign a message id when the connection supports truncation; `rewind`
        // looks messages up by this id and passes it to the truncate handler.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Errors (e.g. the thread was dropped) are ignored here.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint is logged and simply leaves the message without one.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1408
1409 pub fn can_resume(&self, cx: &App) -> bool {
1410 self.connection.resume(&self.session_id, cx).is_some()
1411 }
1412
1413 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1414 self.run_turn(cx, async move |this, cx| {
1415 this.update(cx, |this, cx| {
1416 this.connection
1417 .resume(&this.session_id, cx)
1418 .map(|resume| resume.run(cx))
1419 })?
1420 .context("resuming a session is not supported")?
1421 .await
1422 })
1423 }
1424
    /// Runs one agent turn driven by `f` and returns a future that resolves when the
    /// turn is over.
    ///
    /// Completed plan entries are cleared and any in-flight turn is canceled first;
    /// `f` only starts after that cancellation has finished. `f` runs inside
    /// `send_task` and its result is forwarded over a oneshot channel to the returned
    /// future, which then refreshes the last user message's checkpoint, clears the
    /// agent location, and emits `Error` or `Stopped`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start the new turn until the previous one has wound down.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err` here means the sender (owned by the send task) was dropped
            // before producing a result.
            let response = rx.await;

            // NOTE(review): an error from the checkpoint update returns early and skips
            // the agent-location reset and the Error/Stopped event below — confirm
            // this is intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    // The turn itself failed: drop the task and surface the error.
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Successful completion (any stop reason) or a dropped sender.
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1491
1492 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1493 let Some(send_task) = self.send_task.take() else {
1494 return Task::ready(());
1495 };
1496
1497 for entry in self.entries.iter_mut() {
1498 if let AgentThreadEntry::ToolCall(call) = entry {
1499 let cancel = matches!(
1500 call.status,
1501 ToolCallStatus::Pending
1502 | ToolCallStatus::WaitingForConfirmation { .. }
1503 | ToolCallStatus::InProgress
1504 );
1505
1506 if cancel {
1507 call.status = ToolCallStatus::Canceled;
1508 }
1509 }
1510 }
1511
1512 self.connection.cancel(&self.session_id, cx);
1513
1514 // Wait for the send task to complete
1515 cx.foreground_executor().spawn(send_task)
1516 }
1517
1518 /// Rewinds this thread to before the entry at `index`, removing it and all
1519 /// subsequent entries while reverting any changes made from that point.
1520 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1521 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1522 return Task::ready(Err(anyhow!("not supported")));
1523 };
1524 let Some(message) = self.user_message(&id) else {
1525 return Task::ready(Err(anyhow!("message not found")));
1526 };
1527
1528 let checkpoint = message
1529 .checkpoint
1530 .as_ref()
1531 .map(|c| c.git_checkpoint.clone());
1532
1533 let git_store = self.project.read(cx).git_store().clone();
1534 cx.spawn(async move |this, cx| {
1535 if let Some(checkpoint) = checkpoint {
1536 git_store
1537 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1538 .await?;
1539 }
1540
1541 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1542 this.update(cx, |this, cx| {
1543 if let Some((ix, _)) = this.user_message_mut(&id) {
1544 let range = ix..this.entries.len();
1545 this.entries.truncate(ix);
1546 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1547 }
1548 })
1549 })
1550 }
1551
1552 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1553 let git_store = self.project.read(cx).git_store().clone();
1554
1555 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1556 if let Some(checkpoint) = message.checkpoint.as_ref() {
1557 checkpoint.git_checkpoint.clone()
1558 } else {
1559 return Task::ready(Ok(()));
1560 }
1561 } else {
1562 return Task::ready(Ok(()));
1563 };
1564
1565 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1566 cx.spawn(async move |this, cx| {
1567 let new_checkpoint = new_checkpoint
1568 .await
1569 .context("failed to get new checkpoint")
1570 .log_err();
1571 if let Some(new_checkpoint) = new_checkpoint {
1572 let equal = git_store
1573 .update(cx, |git, cx| {
1574 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1575 })?
1576 .await
1577 .unwrap_or(true);
1578 this.update(cx, |this, cx| {
1579 let (ix, message) = this.last_user_message().context("no user message")?;
1580 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1581 checkpoint.show = !equal;
1582 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1583 anyhow::Ok(())
1584 })??;
1585 }
1586
1587 Ok(())
1588 })
1589 }
1590
1591 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1592 self.entries
1593 .iter_mut()
1594 .enumerate()
1595 .rev()
1596 .find_map(|(ix, entry)| {
1597 if let AgentThreadEntry::UserMessage(message) = entry {
1598 Some((ix, message))
1599 } else {
1600 None
1601 }
1602 })
1603 }
1604
1605 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1606 self.entries.iter().find_map(|entry| {
1607 if let AgentThreadEntry::UserMessage(message) = entry {
1608 if message.id.as_ref() == Some(id) {
1609 Some(message)
1610 } else {
1611 None
1612 }
1613 } else {
1614 None
1615 }
1616 })
1617 }
1618
1619 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1620 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1621 if let AgentThreadEntry::UserMessage(message) = entry {
1622 if message.id.as_ref() == Some(id) {
1623 Some((ix, message))
1624 } else {
1625 None
1626 }
1627 } else {
1628 None
1629 }
1630 })
1631 }
1632
1633 pub fn read_text_file(
1634 &self,
1635 path: PathBuf,
1636 line: Option<u32>,
1637 limit: Option<u32>,
1638 reuse_shared_snapshot: bool,
1639 cx: &mut Context<Self>,
1640 ) -> Task<Result<String>> {
1641 let project = self.project.clone();
1642 let action_log = self.action_log.clone();
1643 cx.spawn(async move |this, cx| {
1644 let load = project.update(cx, |project, cx| {
1645 let path = project
1646 .project_path_for_absolute_path(&path, cx)
1647 .context("invalid path")?;
1648 anyhow::Ok(project.open_buffer(path, cx))
1649 });
1650 let buffer = load??.await?;
1651
1652 let snapshot = if reuse_shared_snapshot {
1653 this.read_with(cx, |this, _| {
1654 this.shared_buffers.get(&buffer.clone()).cloned()
1655 })
1656 .log_err()
1657 .flatten()
1658 } else {
1659 None
1660 };
1661
1662 let snapshot = if let Some(snapshot) = snapshot {
1663 snapshot
1664 } else {
1665 action_log.update(cx, |action_log, cx| {
1666 action_log.buffer_read(buffer.clone(), cx);
1667 })?;
1668 project.update(cx, |project, cx| {
1669 let position = buffer
1670 .read(cx)
1671 .snapshot()
1672 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1673 project.set_agent_location(
1674 Some(AgentLocation {
1675 buffer: buffer.downgrade(),
1676 position,
1677 }),
1678 cx,
1679 );
1680 })?;
1681
1682 buffer.update(cx, |buffer, _| buffer.snapshot())?
1683 };
1684
1685 this.update(cx, |this, _| {
1686 let text = snapshot.text();
1687 this.shared_buffers.insert(buffer.clone(), snapshot);
1688 if line.is_none() && limit.is_none() {
1689 return Ok(text);
1690 }
1691 let limit = limit.unwrap_or(u32::MAX) as usize;
1692 let Some(line) = line else {
1693 return Ok(text.lines().take(limit).collect::<String>());
1694 };
1695
1696 let count = text.lines().count();
1697 if count < line as usize {
1698 anyhow::bail!("There are only {} lines", count);
1699 }
1700 Ok(text
1701 .lines()
1702 .skip(line as usize + 1)
1703 .take(limit)
1704 .collect::<String>())
1705 })?
1706 })
1707 }
1708
    /// Replaces the contents of the file at the absolute `path` with `content`, then
    /// saves it.
    ///
    /// The edits are computed as a text diff against the snapshot previously shared
    /// with the agent for this buffer (falling back to the buffer's current
    /// snapshot) and anchored in that snapshot, so they are applied relative to what
    /// the agent last saw. The agent location is moved to the end of the last edit,
    /// the edit is recorded in the action log, the buffer is formatted when the
    /// language's `format_on_save` setting is not `Off`, and finally it is saved.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not inside the project, the buffer cannot be opened, the
    /// thread or app was dropped, or saving fails. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last read, if anything was shared.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread; anchors keep the edit ranges valid
            // against later changes to the buffer.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Presumably gives the action log a baseline before the edit is
                // recorded — TODO confirm against `ActionLog::buffer_read`.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged, not propagated.
                format_task.await.log_err();

                // Record the formatter's changes as agent edits too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1805
1806 pub fn to_markdown(&self, cx: &App) -> String {
1807 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1808 }
1809
    /// Emits an [`AcpThreadEvent::LoadError`] carrying `error` to subscribers.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
1813}
1814
1815fn markdown_for_raw_output(
1816 raw_output: &serde_json::Value,
1817 language_registry: &Arc<LanguageRegistry>,
1818 cx: &mut App,
1819) -> Option<Entity<Markdown>> {
1820 match raw_output {
1821 serde_json::Value::Null => None,
1822 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1823 Markdown::new(
1824 value.to_string().into(),
1825 Some(language_registry.clone()),
1826 None,
1827 cx,
1828 )
1829 })),
1830 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1831 Markdown::new(
1832 value.to_string().into(),
1833 Some(language_registry.clone()),
1834 None,
1835 cx,
1836 )
1837 })),
1838 serde_json::Value::String(value) => Some(cx.new(|cx| {
1839 Markdown::new(
1840 value.clone().into(),
1841 Some(language_registry.clone()),
1842 None,
1843 cx,
1844 )
1845 })),
1846 value => Some(cx.new(|cx| {
1847 Markdown::new(
1848 format!("```json\n{}\n```", value).into(),
1849 Some(language_registry.clone()),
1850 None,
1851 cx,
1852 )
1853 })),
1854 }
1855}
1856
1857#[cfg(test)]
1858mod tests {
1859 use super::*;
1860 use anyhow::anyhow;
1861 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1862 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1863 use indoc::indoc;
1864 use project::{FakeFs, Fs};
1865 use rand::Rng as _;
1866 use serde_json::json;
1867 use settings::SettingsStore;
1868 use smol::stream::StreamExt as _;
1869 use std::{
1870 any::Any,
1871 cell::RefCell,
1872 path::Path,
1873 rc::Rc,
1874 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1875 time::Duration,
1876 };
1877 use util::path;
1878
1879 fn init_test(cx: &mut TestAppContext) {
1880 env_logger::try_init().ok();
1881 cx.update(|cx| {
1882 let settings_store = SettingsStore::test(cx);
1883 cx.set_global(settings_store);
1884 Project::init_settings(cx);
1885 language::init(cx);
1886 });
1887 }
1888
1889 #[gpui::test]
1890 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1891 init_test(cx);
1892
1893 let fs = FakeFs::new(cx.executor());
1894 let project = Project::test(fs, [], cx).await;
1895 let connection = Rc::new(FakeAgentConnection::new());
1896 let thread = cx
1897 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1898 .await
1899 .unwrap();
1900
1901 // Test creating a new user message
1902 thread.update(cx, |thread, cx| {
1903 thread.push_user_content_block(
1904 None,
1905 acp::ContentBlock::Text(acp::TextContent {
1906 annotations: None,
1907 text: "Hello, ".to_string(),
1908 }),
1909 cx,
1910 );
1911 });
1912
1913 thread.update(cx, |thread, cx| {
1914 assert_eq!(thread.entries.len(), 1);
1915 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1916 assert_eq!(user_msg.id, None);
1917 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1918 } else {
1919 panic!("Expected UserMessage");
1920 }
1921 });
1922
1923 // Test appending to existing user message
1924 let message_1_id = UserMessageId::new();
1925 thread.update(cx, |thread, cx| {
1926 thread.push_user_content_block(
1927 Some(message_1_id.clone()),
1928 acp::ContentBlock::Text(acp::TextContent {
1929 annotations: None,
1930 text: "world!".to_string(),
1931 }),
1932 cx,
1933 );
1934 });
1935
1936 thread.update(cx, |thread, cx| {
1937 assert_eq!(thread.entries.len(), 1);
1938 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1939 assert_eq!(user_msg.id, Some(message_1_id));
1940 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1941 } else {
1942 panic!("Expected UserMessage");
1943 }
1944 });
1945
1946 // Test creating new user message after assistant message
1947 thread.update(cx, |thread, cx| {
1948 thread.push_assistant_content_block(
1949 acp::ContentBlock::Text(acp::TextContent {
1950 annotations: None,
1951 text: "Assistant response".to_string(),
1952 }),
1953 false,
1954 cx,
1955 );
1956 });
1957
1958 let message_2_id = UserMessageId::new();
1959 thread.update(cx, |thread, cx| {
1960 thread.push_user_content_block(
1961 Some(message_2_id.clone()),
1962 acp::ContentBlock::Text(acp::TextContent {
1963 annotations: None,
1964 text: "New user message".to_string(),
1965 }),
1966 cx,
1967 );
1968 });
1969
1970 thread.update(cx, |thread, cx| {
1971 assert_eq!(thread.entries.len(), 3);
1972 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1973 assert_eq!(user_msg.id, Some(message_2_id));
1974 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1975 } else {
1976 panic!("Expected UserMessage at index 2");
1977 }
1978 });
1979 }
1980
1981 #[gpui::test]
1982 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1983 init_test(cx);
1984
1985 let fs = FakeFs::new(cx.executor());
1986 let project = Project::test(fs, [], cx).await;
1987 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1988 |_, thread, mut cx| {
1989 async move {
1990 thread.update(&mut cx, |thread, cx| {
1991 thread
1992 .handle_session_update(
1993 acp::SessionUpdate::AgentThoughtChunk {
1994 content: "Thinking ".into(),
1995 },
1996 cx,
1997 )
1998 .unwrap();
1999 thread
2000 .handle_session_update(
2001 acp::SessionUpdate::AgentThoughtChunk {
2002 content: "hard!".into(),
2003 },
2004 cx,
2005 )
2006 .unwrap();
2007 })?;
2008 Ok(acp::PromptResponse {
2009 stop_reason: acp::StopReason::EndTurn,
2010 })
2011 }
2012 .boxed_local()
2013 },
2014 ));
2015
2016 let thread = cx
2017 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2018 .await
2019 .unwrap();
2020
2021 thread
2022 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2023 .await
2024 .unwrap();
2025
2026 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2027 assert_eq!(
2028 output,
2029 indoc! {r#"
2030 ## User
2031
2032 Hello from Zed!
2033
2034 ## Assistant
2035
2036 <thinking>
2037 Thinking hard!
2038 </thinking>
2039
2040 "#}
2041 );
2042 }
2043
2044 #[gpui::test]
2045 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2046 init_test(cx);
2047
2048 let fs = FakeFs::new(cx.executor());
2049 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2050 .await;
2051 let project = Project::test(fs.clone(), [], cx).await;
2052 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2053 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2054 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2055 move |_, thread, mut cx| {
2056 let read_file_tx = read_file_tx.clone();
2057 async move {
2058 let content = thread
2059 .update(&mut cx, |thread, cx| {
2060 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2061 })
2062 .unwrap()
2063 .await
2064 .unwrap();
2065 assert_eq!(content, "one\ntwo\nthree\n");
2066 read_file_tx.take().unwrap().send(()).unwrap();
2067 thread
2068 .update(&mut cx, |thread, cx| {
2069 thread.write_text_file(
2070 path!("/tmp/foo").into(),
2071 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2072 cx,
2073 )
2074 })
2075 .unwrap()
2076 .await
2077 .unwrap();
2078 Ok(acp::PromptResponse {
2079 stop_reason: acp::StopReason::EndTurn,
2080 })
2081 }
2082 .boxed_local()
2083 },
2084 ));
2085
2086 let (worktree, pathbuf) = project
2087 .update(cx, |project, cx| {
2088 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2089 })
2090 .await
2091 .unwrap();
2092 let buffer = project
2093 .update(cx, |project, cx| {
2094 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2095 })
2096 .await
2097 .unwrap();
2098
2099 let thread = cx
2100 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2101 .await
2102 .unwrap();
2103
2104 let request = thread.update(cx, |thread, cx| {
2105 thread.send_raw("Extend the count in /tmp/foo", cx)
2106 });
2107 read_file_rx.await.ok();
2108 buffer.update(cx, |buffer, cx| {
2109 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2110 });
2111 cx.run_until_parked();
2112 assert_eq!(
2113 buffer.read_with(cx, |buffer, _| buffer.text()),
2114 "zero\none\ntwo\nthree\nfour\nfive\n"
2115 );
2116 assert_eq!(
2117 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2118 "zero\none\ntwo\nthree\nfour\nfive\n"
2119 );
2120 request.await.unwrap();
2121 }
2122
2123 #[gpui::test]
2124 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2125 init_test(cx);
2126
2127 let fs = FakeFs::new(cx.executor());
2128 let project = Project::test(fs, [], cx).await;
2129 let id = acp::ToolCallId("test".into());
2130
2131 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2132 let id = id.clone();
2133 move |_, thread, mut cx| {
2134 let id = id.clone();
2135 async move {
2136 thread
2137 .update(&mut cx, |thread, cx| {
2138 thread.handle_session_update(
2139 acp::SessionUpdate::ToolCall(acp::ToolCall {
2140 id: id.clone(),
2141 title: "Label".into(),
2142 kind: acp::ToolKind::Fetch,
2143 status: acp::ToolCallStatus::InProgress,
2144 content: vec![],
2145 locations: vec![],
2146 raw_input: None,
2147 raw_output: None,
2148 }),
2149 cx,
2150 )
2151 })
2152 .unwrap()
2153 .unwrap();
2154 Ok(acp::PromptResponse {
2155 stop_reason: acp::StopReason::EndTurn,
2156 })
2157 }
2158 .boxed_local()
2159 }
2160 }));
2161
2162 let thread = cx
2163 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2164 .await
2165 .unwrap();
2166
2167 let request = thread.update(cx, |thread, cx| {
2168 thread.send_raw("Fetch https://example.com", cx)
2169 });
2170
2171 run_until_first_tool_call(&thread, cx).await;
2172
2173 thread.read_with(cx, |thread, _| {
2174 assert!(matches!(
2175 thread.entries[1],
2176 AgentThreadEntry::ToolCall(ToolCall {
2177 status: ToolCallStatus::InProgress,
2178 ..
2179 })
2180 ));
2181 });
2182
2183 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2184
2185 thread.read_with(cx, |thread, _| {
2186 assert!(matches!(
2187 &thread.entries[1],
2188 AgentThreadEntry::ToolCall(ToolCall {
2189 status: ToolCallStatus::Canceled,
2190 ..
2191 })
2192 ));
2193 });
2194
2195 thread
2196 .update(cx, |thread, cx| {
2197 thread.handle_session_update(
2198 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2199 id,
2200 fields: acp::ToolCallUpdateFields {
2201 status: Some(acp::ToolCallStatus::Completed),
2202 ..Default::default()
2203 },
2204 }),
2205 cx,
2206 )
2207 })
2208 .unwrap();
2209
2210 request.await.unwrap();
2211
2212 thread.read_with(cx, |thread, _| {
2213 assert!(matches!(
2214 thread.entries[1],
2215 AgentThreadEntry::ToolCall(ToolCall {
2216 status: ToolCallStatus::Completed,
2217 ..
2218 })
2219 ));
2220 });
2221 }
2222
2223 #[gpui::test]
2224 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2225 init_test(cx);
2226 let fs = FakeFs::new(cx.background_executor.clone());
2227 fs.insert_tree(path!("/test"), json!({})).await;
2228 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2229
2230 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2231 move |_, thread, mut cx| {
2232 async move {
2233 thread
2234 .update(&mut cx, |thread, cx| {
2235 thread.handle_session_update(
2236 acp::SessionUpdate::ToolCall(acp::ToolCall {
2237 id: acp::ToolCallId("test".into()),
2238 title: "Label".into(),
2239 kind: acp::ToolKind::Edit,
2240 status: acp::ToolCallStatus::Completed,
2241 content: vec![acp::ToolCallContent::Diff {
2242 diff: acp::Diff {
2243 path: "/test/test.txt".into(),
2244 old_text: None,
2245 new_text: "foo".into(),
2246 },
2247 }],
2248 locations: vec![],
2249 raw_input: None,
2250 raw_output: None,
2251 }),
2252 cx,
2253 )
2254 })
2255 .unwrap()
2256 .unwrap();
2257 Ok(acp::PromptResponse {
2258 stop_reason: acp::StopReason::EndTurn,
2259 })
2260 }
2261 .boxed_local()
2262 }
2263 }));
2264
2265 let thread = cx
2266 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2267 .await
2268 .unwrap();
2269
2270 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2271 .await
2272 .unwrap();
2273
2274 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2275 }
2276
2277 #[gpui::test(iterations = 10)]
2278 async fn test_checkpoints(cx: &mut TestAppContext) {
2279 init_test(cx);
2280 let fs = FakeFs::new(cx.background_executor.clone());
2281 fs.insert_tree(
2282 path!("/test"),
2283 json!({
2284 ".git": {}
2285 }),
2286 )
2287 .await;
2288 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2289
2290 let simulate_changes = Arc::new(AtomicBool::new(true));
2291 let next_filename = Arc::new(AtomicUsize::new(0));
2292 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2293 let simulate_changes = simulate_changes.clone();
2294 let next_filename = next_filename.clone();
2295 let fs = fs.clone();
2296 move |request, thread, mut cx| {
2297 let fs = fs.clone();
2298 let simulate_changes = simulate_changes.clone();
2299 let next_filename = next_filename.clone();
2300 async move {
2301 if simulate_changes.load(SeqCst) {
2302 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2303 fs.write(Path::new(&filename), b"").await?;
2304 }
2305
2306 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2307 panic!("expected text content block");
2308 };
2309 thread.update(&mut cx, |thread, cx| {
2310 thread
2311 .handle_session_update(
2312 acp::SessionUpdate::AgentMessageChunk {
2313 content: content.text.to_uppercase().into(),
2314 },
2315 cx,
2316 )
2317 .unwrap();
2318 })?;
2319 Ok(acp::PromptResponse {
2320 stop_reason: acp::StopReason::EndTurn,
2321 })
2322 }
2323 .boxed_local()
2324 }
2325 }));
2326 let thread = cx
2327 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2328 .await
2329 .unwrap();
2330
2331 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2332 .await
2333 .unwrap();
2334 thread.read_with(cx, |thread, cx| {
2335 assert_eq!(
2336 thread.to_markdown(cx),
2337 indoc! {"
2338 ## User (checkpoint)
2339
2340 Lorem
2341
2342 ## Assistant
2343
2344 LOREM
2345
2346 "}
2347 );
2348 });
2349 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2350
2351 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2352 .await
2353 .unwrap();
2354 thread.read_with(cx, |thread, cx| {
2355 assert_eq!(
2356 thread.to_markdown(cx),
2357 indoc! {"
2358 ## User (checkpoint)
2359
2360 Lorem
2361
2362 ## Assistant
2363
2364 LOREM
2365
2366 ## User (checkpoint)
2367
2368 ipsum
2369
2370 ## Assistant
2371
2372 IPSUM
2373
2374 "}
2375 );
2376 });
2377 assert_eq!(
2378 fs.files(),
2379 vec![
2380 Path::new(path!("/test/file-0")),
2381 Path::new(path!("/test/file-1"))
2382 ]
2383 );
2384
2385 // Checkpoint isn't stored when there are no changes.
2386 simulate_changes.store(false, SeqCst);
2387 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2388 .await
2389 .unwrap();
2390 thread.read_with(cx, |thread, cx| {
2391 assert_eq!(
2392 thread.to_markdown(cx),
2393 indoc! {"
2394 ## User (checkpoint)
2395
2396 Lorem
2397
2398 ## Assistant
2399
2400 LOREM
2401
2402 ## User (checkpoint)
2403
2404 ipsum
2405
2406 ## Assistant
2407
2408 IPSUM
2409
2410 ## User
2411
2412 dolor
2413
2414 ## Assistant
2415
2416 DOLOR
2417
2418 "}
2419 );
2420 });
2421 assert_eq!(
2422 fs.files(),
2423 vec![
2424 Path::new(path!("/test/file-0")),
2425 Path::new(path!("/test/file-1"))
2426 ]
2427 );
2428
2429 // Rewinding the conversation truncates the history and restores the checkpoint.
2430 thread
2431 .update(cx, |thread, cx| {
2432 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2433 panic!("unexpected entries {:?}", thread.entries)
2434 };
2435 thread.rewind(message.id.clone().unwrap(), cx)
2436 })
2437 .await
2438 .unwrap();
2439 thread.read_with(cx, |thread, cx| {
2440 assert_eq!(
2441 thread.to_markdown(cx),
2442 indoc! {"
2443 ## User (checkpoint)
2444
2445 Lorem
2446
2447 ## Assistant
2448
2449 LOREM
2450
2451 "}
2452 );
2453 });
2454 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2455 }
2456
2457 #[gpui::test]
2458 async fn test_refusal(cx: &mut TestAppContext) {
2459 init_test(cx);
2460 let fs = FakeFs::new(cx.background_executor.clone());
2461 fs.insert_tree(path!("/"), json!({})).await;
2462 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2463
2464 let refuse_next = Arc::new(AtomicBool::new(false));
2465 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2466 let refuse_next = refuse_next.clone();
2467 move |request, thread, mut cx| {
2468 let refuse_next = refuse_next.clone();
2469 async move {
2470 if refuse_next.load(SeqCst) {
2471 return Ok(acp::PromptResponse {
2472 stop_reason: acp::StopReason::Refusal,
2473 });
2474 }
2475
2476 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2477 panic!("expected text content block");
2478 };
2479 thread.update(&mut cx, |thread, cx| {
2480 thread
2481 .handle_session_update(
2482 acp::SessionUpdate::AgentMessageChunk {
2483 content: content.text.to_uppercase().into(),
2484 },
2485 cx,
2486 )
2487 .unwrap();
2488 })?;
2489 Ok(acp::PromptResponse {
2490 stop_reason: acp::StopReason::EndTurn,
2491 })
2492 }
2493 .boxed_local()
2494 }
2495 }));
2496 let thread = cx
2497 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2498 .await
2499 .unwrap();
2500
2501 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2502 .await
2503 .unwrap();
2504 thread.read_with(cx, |thread, cx| {
2505 assert_eq!(
2506 thread.to_markdown(cx),
2507 indoc! {"
2508 ## User
2509
2510 hello
2511
2512 ## Assistant
2513
2514 HELLO
2515
2516 "}
2517 );
2518 });
2519
2520 // Simulate refusing the second message, ensuring the conversation gets
2521 // truncated to before sending it.
2522 refuse_next.store(true, SeqCst);
2523 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2524 .await
2525 .unwrap();
2526 thread.read_with(cx, |thread, cx| {
2527 assert_eq!(
2528 thread.to_markdown(cx),
2529 indoc! {"
2530 ## User
2531
2532 hello
2533
2534 ## Assistant
2535
2536 HELLO
2537
2538 "}
2539 );
2540 });
2541 }
2542
2543 async fn run_until_first_tool_call(
2544 thread: &Entity<AcpThread>,
2545 cx: &mut TestAppContext,
2546 ) -> usize {
2547 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2548
2549 let subscription = cx.update(|cx| {
2550 cx.subscribe(thread, move |thread, _, cx| {
2551 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2552 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2553 return tx.try_send(ix).unwrap();
2554 }
2555 }
2556 })
2557 });
2558
2559 select! {
2560 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2561 panic!("Timeout waiting for tool call")
2562 }
2563 ix = rx.next().fuse() => {
2564 drop(subscription);
2565 ix.unwrap()
2566 }
2567 }
2568 }
2569
2570 #[derive(Clone, Default)]
2571 struct FakeAgentConnection {
2572 auth_methods: Vec<acp::AuthMethod>,
2573 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2574 on_user_message: Option<
2575 Rc<
2576 dyn Fn(
2577 acp::PromptRequest,
2578 WeakEntity<AcpThread>,
2579 AsyncApp,
2580 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2581 + 'static,
2582 >,
2583 >,
2584 }
2585
2586 impl FakeAgentConnection {
2587 fn new() -> Self {
2588 Self {
2589 auth_methods: Vec::new(),
2590 on_user_message: None,
2591 sessions: Arc::default(),
2592 }
2593 }
2594
2595 #[expect(unused)]
2596 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2597 self.auth_methods = auth_methods;
2598 self
2599 }
2600
2601 fn on_user_message(
2602 mut self,
2603 handler: impl Fn(
2604 acp::PromptRequest,
2605 WeakEntity<AcpThread>,
2606 AsyncApp,
2607 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2608 + 'static,
2609 ) -> Self {
2610 self.on_user_message.replace(Rc::new(handler));
2611 self
2612 }
2613 }
2614
2615 impl AgentConnection for FakeAgentConnection {
2616 fn auth_methods(&self) -> &[acp::AuthMethod] {
2617 &self.auth_methods
2618 }
2619
2620 fn new_thread(
2621 self: Rc<Self>,
2622 project: Entity<Project>,
2623 _cwd: &Path,
2624 cx: &mut App,
2625 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2626 let session_id = acp::SessionId(
2627 rand::thread_rng()
2628 .sample_iter(&rand::distributions::Alphanumeric)
2629 .take(7)
2630 .map(char::from)
2631 .collect::<String>()
2632 .into(),
2633 );
2634 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2635 let thread = cx.new(|cx| {
2636 AcpThread::new(
2637 "Test",
2638 self.clone(),
2639 project,
2640 action_log,
2641 session_id.clone(),
2642 watch::Receiver::constant(acp::PromptCapabilities {
2643 image: true,
2644 audio: true,
2645 embedded_context: true,
2646 }),
2647 cx,
2648 )
2649 });
2650 self.sessions.lock().insert(session_id, thread.downgrade());
2651 Task::ready(Ok(thread))
2652 }
2653
2654 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2655 if self.auth_methods().iter().any(|m| m.id == method) {
2656 Task::ready(Ok(()))
2657 } else {
2658 Task::ready(Err(anyhow!("Invalid Auth Method")))
2659 }
2660 }
2661
2662 fn prompt(
2663 &self,
2664 _id: Option<UserMessageId>,
2665 params: acp::PromptRequest,
2666 cx: &mut App,
2667 ) -> Task<gpui::Result<acp::PromptResponse>> {
2668 let sessions = self.sessions.lock();
2669 let thread = sessions.get(¶ms.session_id).unwrap();
2670 if let Some(handler) = &self.on_user_message {
2671 let handler = handler.clone();
2672 let thread = thread.clone();
2673 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2674 } else {
2675 Task::ready(Ok(acp::PromptResponse {
2676 stop_reason: acp::StopReason::EndTurn,
2677 }))
2678 }
2679 }
2680
2681 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2682 let sessions = self.sessions.lock();
2683 let thread = sessions.get(session_id).unwrap().clone();
2684
2685 cx.spawn(async move |cx| {
2686 thread
2687 .update(cx, |thread, cx| thread.cancel(cx))
2688 .unwrap()
2689 .await
2690 })
2691 .detach();
2692 }
2693
2694 fn truncate(
2695 &self,
2696 session_id: &acp::SessionId,
2697 _cx: &App,
2698 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2699 Some(Rc::new(FakeAgentSessionEditor {
2700 _session_id: session_id.clone(),
2701 }))
2702 }
2703
2704 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2705 self
2706 }
2707 }
2708
2709 struct FakeAgentSessionEditor {
2710 _session_id: acp::SessionId,
2711 }
2712
2713 impl AgentSessionTruncate for FakeAgentSessionEditor {
2714 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2715 Task::ready(Ok(()))
2716 }
2717 }
2718}