1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
/// A user turn in the thread, accumulated from one or more protocol content chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one has been assigned.
    pub id: Option<UserMessageId>,
    /// Rendered content: every chunk in `chunks` appended into a single block.
    pub content: ContentBlock,
    /// The raw protocol chunks, in the order they were received.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
43
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when set, the message's Markdown export is
    /// headed `## User (checkpoint)`.
    pub show: bool,
}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
/// An assistant turn, made of consecutive message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in the order they were streamed.
    pub chunks: Vec<AssistantMessageChunk>,
}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
/// One contiguous piece of an assistant response.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular response text.
    Message { block: ContentBlock },
    /// Reasoning output; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
/// A single entry in an agent thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// Content sent by the user.
    UserMessage(UserMessage),
    /// Message and thought chunks produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
/// A tool invocation reported by the agent, together with its rendered state.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol id used to route updates to this call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the title (only its first line, followed by `…`, for
    /// multi-line titles).
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content items (text blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions resolved from `locations`, index for index; `None` where a
    /// location could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw tool input, as last reported.
    pub raw_input: Option<serde_json::Value>,
    /// Raw tool output, as last reported.
    pub raw_output: Option<serde_json::Value>,
}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
187 first_line.to_owned() + "…"
188 } else {
189 tool_call.title
190 };
191 Self {
192 id: tool_call.id,
193 label: cx
194 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
195 kind: tool_call.kind,
196 content: tool_call
197 .content
198 .into_iter()
199 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
200 .collect(),
201 locations: tool_call.locations,
202 resolved_locations: Vec::default(),
203 status,
204 raw_input: tool_call.raw_input,
205 raw_output: tool_call.raw_output,
206 }
207 }
208
209 fn update_fields(
210 &mut self,
211 fields: acp::ToolCallUpdateFields,
212 language_registry: Arc<LanguageRegistry>,
213 cx: &mut App,
214 ) {
215 let acp::ToolCallUpdateFields {
216 kind,
217 status,
218 title,
219 content,
220 locations,
221 raw_input,
222 raw_output,
223 } = fields;
224
225 if let Some(kind) = kind {
226 self.kind = kind;
227 }
228
229 if let Some(status) = status {
230 self.status = status.into();
231 }
232
233 if let Some(title) = title {
234 self.label.update(cx, |label, cx| {
235 if let Some((first_line, _)) = title.split_once("\n") {
236 label.replace(first_line.to_owned() + "…", cx)
237 } else {
238 label.replace(title, cx);
239 }
240 });
241 }
242
243 if let Some(content) = content {
244 let new_content_len = content.len();
245 let mut content = content.into_iter();
246
247 // Reuse existing content if we can
248 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
249 old.update_from_acp(new, language_registry.clone(), cx);
250 }
251 for new in content {
252 self.content.push(ToolCallContent::from_acp(
253 new,
254 language_registry.clone(),
255 cx,
256 ))
257 }
258 self.content.truncate(new_content_len);
259 }
260
261 if let Some(locations) = locations {
262 self.locations = locations;
263 }
264
265 if let Some(raw_input) = raw_input {
266 self.raw_input = Some(raw_input);
267 }
268
269 if let Some(raw_output) = raw_output {
270 if self.content.is_empty()
271 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
272 {
273 self.content
274 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
275 markdown,
276 }));
277 }
278 self.raw_output = Some(raw_output);
279 }
280 }
281
282 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
283 self.content.iter().filter_map(|content| match content {
284 ToolCallContent::Diff(diff) => Some(diff),
285 ToolCallContent::ContentBlock(_) => None,
286 ToolCallContent::Terminal(_) => None,
287 })
288 }
289
290 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
291 self.content.iter().filter_map(|content| match content {
292 ToolCallContent::Terminal(terminal) => Some(terminal),
293 ToolCallContent::ContentBlock(_) => None,
294 ToolCallContent::Diff(_) => None,
295 })
296 }
297
298 fn to_markdown(&self, cx: &App) -> String {
299 let mut markdown = format!(
300 "**Tool Call: {}**\nStatus: {}\n\n",
301 self.label.read(cx).source(),
302 self.status
303 );
304 for content in &self.content {
305 markdown.push_str(content.to_markdown(cx).as_str());
306 markdown.push_str("\n\n");
307 }
308 markdown
309 }
310
311 async fn resolve_location(
312 location: acp::ToolCallLocation,
313 project: WeakEntity<Project>,
314 cx: &mut AsyncApp,
315 ) -> Option<AgentLocation> {
316 let buffer = project
317 .update(cx, |project, cx| {
318 project
319 .project_path_for_absolute_path(&location.path, cx)
320 .map(|path| project.open_buffer(path, cx))
321 })
322 .ok()??;
323 let buffer = buffer.await.log_err()?;
324 let position = buffer
325 .update(cx, |buffer, _| {
326 if let Some(row) = location.line {
327 let snapshot = buffer.snapshot();
328 let column = snapshot.indent_size_for_line(row).len;
329 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
330 snapshot.anchor_before(point)
331 } else {
332 Anchor::MIN
333 }
334 })
335 .ok()?;
336
337 Some(AgentLocation {
338 buffer: buffer.downgrade(),
339 position,
340 })
341 }
342
343 fn resolve_locations(
344 &self,
345 project: Entity<Project>,
346 cx: &mut App,
347 ) -> Task<Vec<Option<AgentLocation>>> {
348 let locations = self.locations.clone();
349 project.update(cx, |_, cx| {
350 cx.spawn(async move |project, cx| {
351 let mut new_locations = Vec::new();
352 for location in locations {
353 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
354 }
355 new_locations
356 })
357 })
358 }
359}
360
361#[derive(Debug)]
362pub enum ToolCallStatus {
363 /// The tool call hasn't started running yet, but we start showing it to
364 /// the user.
365 Pending,
366 /// The tool call is waiting for confirmation from the user.
367 WaitingForConfirmation {
368 options: Vec<acp::PermissionOption>,
369 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
370 },
371 /// The tool call is currently running.
372 InProgress,
373 /// The tool call completed successfully.
374 Completed,
375 /// The tool call failed.
376 Failed,
377 /// The user rejected the tool call.
378 Rejected,
379 /// The user canceled generation so the tool call was canceled.
380 Canceled,
381}
382
383impl From<acp::ToolCallStatus> for ToolCallStatus {
384 fn from(status: acp::ToolCallStatus) -> Self {
385 match status {
386 acp::ToolCallStatus::Pending => Self::Pending,
387 acp::ToolCallStatus::InProgress => Self::InProgress,
388 acp::ToolCallStatus::Completed => Self::Completed,
389 acp::ToolCallStatus::Failed => Self::Failed,
390 }
391 }
392}
393
394impl Display for ToolCallStatus {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 write!(
397 f,
398 "{}",
399 match self {
400 ToolCallStatus::Pending => "Pending",
401 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
402 ToolCallStatus::InProgress => "In Progress",
403 ToolCallStatus::Completed => "Completed",
404 ToolCallStatus::Failed => "Failed",
405 ToolCallStatus::Rejected => "Rejected",
406 ToolCallStatus::Canceled => "Canceled",
407 }
408 )
409 }
410}
411
/// Renderable content accumulated from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text built from the appended blocks.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link, kept as-is while it is the only block received.
    ResourceLink { resource_link: acp::ResourceLink },
}
418
419impl ContentBlock {
420 pub fn new(
421 block: acp::ContentBlock,
422 language_registry: &Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let mut this = Self::Empty;
426 this.append(block, language_registry, cx);
427 this
428 }
429
430 pub fn new_combined(
431 blocks: impl IntoIterator<Item = acp::ContentBlock>,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 let mut this = Self::Empty;
436 for block in blocks {
437 this.append(block, &language_registry, cx);
438 }
439 this
440 }
441
442 pub fn append(
443 &mut self,
444 block: acp::ContentBlock,
445 language_registry: &Arc<LanguageRegistry>,
446 cx: &mut App,
447 ) {
448 if matches!(self, ContentBlock::Empty)
449 && let acp::ContentBlock::ResourceLink(resource_link) = block
450 {
451 *self = ContentBlock::ResourceLink { resource_link };
452 return;
453 }
454
455 let new_content = self.block_string_contents(block);
456
457 match self {
458 ContentBlock::Empty => {
459 *self = Self::create_markdown_block(new_content, language_registry, cx);
460 }
461 ContentBlock::Markdown { markdown } => {
462 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
463 }
464 ContentBlock::ResourceLink { resource_link } => {
465 let existing_content = Self::resource_link_md(&resource_link.uri);
466 let combined = format!("{}\n{}", existing_content, new_content);
467
468 *self = Self::create_markdown_block(combined, language_registry, cx);
469 }
470 }
471 }
472
473 fn create_markdown_block(
474 content: String,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) -> ContentBlock {
478 ContentBlock::Markdown {
479 markdown: cx
480 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
481 }
482 }
483
484 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
485 match block {
486 acp::ContentBlock::Text(text_content) => text_content.text,
487 acp::ContentBlock::ResourceLink(resource_link) => {
488 Self::resource_link_md(&resource_link.uri)
489 }
490 acp::ContentBlock::Resource(acp::EmbeddedResource {
491 resource:
492 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
493 uri,
494 ..
495 }),
496 ..
497 }) => Self::resource_link_md(&uri),
498 acp::ContentBlock::Image(image) => Self::image_md(&image),
499 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
500 }
501 }
502
503 fn resource_link_md(uri: &str) -> String {
504 if let Some(uri) = MentionUri::parse(uri).log_err() {
505 uri.as_link().to_string()
506 } else {
507 uri.to_string()
508 }
509 }
510
511 fn image_md(_image: &acp::ImageContent) -> String {
512 "`Image`".into()
513 }
514
515 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
516 match self {
517 ContentBlock::Empty => "",
518 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
519 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
520 }
521 }
522
523 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
524 match self {
525 ContentBlock::Empty => None,
526 ContentBlock::Markdown { markdown } => Some(markdown),
527 ContentBlock::ResourceLink { .. } => None,
528 }
529 }
530
531 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
532 match self {
533 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
534 _ => None,
535 }
536 }
537}
538
/// One item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, link, or other block content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal attached to the call.
    Terminal(Entity<Terminal>),
}
545
546impl ToolCallContent {
547 pub fn from_acp(
548 content: acp::ToolCallContent,
549 language_registry: Arc<LanguageRegistry>,
550 cx: &mut App,
551 ) -> Self {
552 match content {
553 acp::ToolCallContent::Content { content } => {
554 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
555 }
556 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
557 Diff::finalized(
558 diff.path,
559 diff.old_text,
560 diff.new_text,
561 language_registry,
562 cx,
563 )
564 })),
565 }
566 }
567
568 pub fn update_from_acp(
569 &mut self,
570 new: acp::ToolCallContent,
571 language_registry: Arc<LanguageRegistry>,
572 cx: &mut App,
573 ) {
574 let needs_update = match (&self, &new) {
575 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
576 old_diff.read(cx).needs_update(
577 new_diff.old_text.as_deref().unwrap_or(""),
578 &new_diff.new_text,
579 cx,
580 )
581 }
582 _ => true,
583 };
584
585 if needs_update {
586 *self = Self::from_acp(new, language_registry, cx);
587 }
588 }
589
590 pub fn to_markdown(&self, cx: &App) -> String {
591 match self {
592 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
593 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
594 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
595 }
596 }
597}
598
599#[derive(Debug, PartialEq)]
600pub enum ToolCallUpdate {
601 UpdateFields(acp::ToolCallUpdate),
602 UpdateDiff(ToolCallUpdateDiff),
603 UpdateTerminal(ToolCallUpdateTerminal),
604}
605
606impl ToolCallUpdate {
607 fn id(&self) -> &acp::ToolCallId {
608 match self {
609 Self::UpdateFields(update) => &update.id,
610 Self::UpdateDiff(diff) => &diff.id,
611 Self::UpdateTerminal(terminal) => &terminal.id,
612 }
613 }
614}
615
616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
617 fn from(update: acp::ToolCallUpdate) -> Self {
618 Self::UpdateFields(update)
619 }
620}
621
622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
623 fn from(diff: ToolCallUpdateDiff) -> Self {
624 Self::UpdateDiff(diff)
625 }
626}
627
628#[derive(Debug, PartialEq)]
629pub struct ToolCallUpdateDiff {
630 pub id: acp::ToolCallId,
631 pub diff: Entity<Diff>,
632}
633
634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
635 fn from(terminal: ToolCallUpdateTerminal) -> Self {
636 Self::UpdateTerminal(terminal)
637 }
638}
639
640#[derive(Debug, PartialEq)]
641pub struct ToolCallUpdateTerminal {
642 pub id: acp::ToolCallId,
643 pub terminal: Entity<Terminal>,
644}
645
/// The agent's current plan.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in the order received.
    pub entries: Vec<PlanEntry>,
}

/// Summary counts for a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
657
658impl Plan {
659 pub fn is_empty(&self) -> bool {
660 self.entries.is_empty()
661 }
662
663 pub fn stats(&self) -> PlanStats<'_> {
664 let mut stats = PlanStats {
665 in_progress_entry: None,
666 pending: 0,
667 completed: 0,
668 };
669
670 for entry in &self.entries {
671 match &entry.status {
672 acp::PlanEntryStatus::Pending => {
673 stats.pending += 1;
674 }
675 acp::PlanEntryStatus::InProgress => {
676 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
677 }
678 acp::PlanEntryStatus::Completed => {
679 stats.completed += 1;
680 }
681 }
682 }
683
684 stats
685 }
686}
687
688#[derive(Debug)]
689pub struct PlanEntry {
690 pub content: Entity<Markdown>,
691 pub priority: acp::PlanEntryPriority,
692 pub status: acp::PlanEntryStatus,
693}
694
695impl PlanEntry {
696 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
697 Self {
698 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
699 priority: entry.priority,
700 status: entry.status,
701 }
702 }
703}
704
/// Token consumption of a thread relative to its model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means the limit is unknown (e.g. no model selected).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
710
711impl TokenUsage {
712 pub fn ratio(&self) -> TokenUsageRatio {
713 #[cfg(debug_assertions)]
714 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
715 .unwrap_or("0.8".to_string())
716 .parse()
717 .unwrap();
718 #[cfg(not(debug_assertions))]
719 let warning_threshold: f32 = 0.8;
720
721 // When the maximum is unknown because there is no selected model,
722 // avoid showing the token limit warning.
723 if self.max_tokens == 0 {
724 TokenUsageRatio::Normal
725 } else if self.used_tokens >= self.max_tokens {
726 TokenUsageRatio::Exceeded
727 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
728 TokenUsageRatio::Warning
729 } else {
730 TokenUsageRatio::Normal
731 }
732 }
733}
734
/// Classification produced by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}

/// Progress of a retried request, reported via `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Number of the current attempt.
    pub attempt: usize,
    /// Maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// When this retry started.
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm with the emitter.
    pub duration: Duration,
}
750
/// A conversation with an agent over an [`AgentConnection`], holding its transcript,
/// plan, and associated project state.
pub struct AcpThread {
    /// Current title; changed via `set_title`.
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's latest plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): usage not visible in this section; presumably snapshots of buffers
    // shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send; its presence makes `status()` report a non-idle state.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest reported token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies capability updates from the connection into this thread.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}

/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The title changed.
    TitleUpdated,
    /// The token usage was replaced.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The prompt capabilities changed.
    PromptCapabilitiesUpdated,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
782
/// Coarse activity state reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call is awaiting the user's confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight.
    Generating,
}
789
790#[derive(Debug, Clone)]
791pub enum LoadError {
792 NotInstalled,
793 Unsupported {
794 command: SharedString,
795 current_version: SharedString,
796 },
797 Exited {
798 status: ExitStatus,
799 },
800 Other(SharedString),
801}
802
803impl Display for LoadError {
804 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
805 match self {
806 LoadError::NotInstalled => write!(f, "not installed"),
807 LoadError::Unsupported {
808 command: path,
809 current_version,
810 } => {
811 write!(f, "version {current_version} from {path} is not supported")
812 }
813 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
814 LoadError::Other(msg) => write!(f, "{}", msg),
815 }
816 }
817}
818
819impl Error for LoadError {}
820
821impl AcpThread {
822 pub fn new(
823 title: impl Into<SharedString>,
824 connection: Rc<dyn AgentConnection>,
825 project: Entity<Project>,
826 action_log: Entity<ActionLog>,
827 session_id: acp::SessionId,
828 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
829 cx: &mut Context<Self>,
830 ) -> Self {
831 let prompt_capabilities = *prompt_capabilities_rx.borrow();
832 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
833 loop {
834 let caps = prompt_capabilities_rx.recv().await?;
835 this.update(cx, |this, cx| {
836 this.prompt_capabilities = caps;
837 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
838 })?;
839 }
840 });
841
842 Self {
843 action_log,
844 shared_buffers: Default::default(),
845 entries: Default::default(),
846 plan: Default::default(),
847 title: title.into(),
848 project,
849 send_task: None,
850 connection,
851 session_id,
852 token_usage: None,
853 prompt_capabilities,
854 _observe_prompt_capabilities: task,
855 }
856 }
857
    /// Most recently received prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }

    /// The agent connection this thread talks through.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }

    /// The action log passed in at construction.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates in.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Current title (a cheap clone of the shared string).
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }

    /// Transcript entries in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }

    /// The protocol session this thread belongs to.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
885
886 pub fn status(&self) -> ThreadStatus {
887 if self.send_task.is_some() {
888 if self.waiting_for_tool_confirmation() {
889 ThreadStatus::WaitingForToolConfirmation
890 } else {
891 ThreadStatus::Generating
892 }
893 } else {
894 ThreadStatus::Idle
895 }
896 }
897
898 pub fn token_usage(&self) -> Option<&TokenUsage> {
899 self.token_usage.as_ref()
900 }
901
902 pub fn has_pending_edit_tool_calls(&self) -> bool {
903 for entry in self.entries.iter().rev() {
904 match entry {
905 AgentThreadEntry::UserMessage(_) => return false,
906 AgentThreadEntry::ToolCall(
907 call @ ToolCall {
908 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
909 ..
910 },
911 ) if call.diffs().next().is_some() => {
912 return true;
913 }
914 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
915 }
916 }
917
918 false
919 }
920
921 pub fn used_tools_since_last_user_message(&self) -> bool {
922 for entry in self.entries.iter().rev() {
923 match entry {
924 AgentThreadEntry::UserMessage(..) => return false,
925 AgentThreadEntry::AssistantMessage(..) => continue,
926 AgentThreadEntry::ToolCall(..) => return true,
927 }
928 }
929
930 false
931 }
932
933 pub fn handle_session_update(
934 &mut self,
935 update: acp::SessionUpdate,
936 cx: &mut Context<Self>,
937 ) -> Result<(), acp::Error> {
938 match update {
939 acp::SessionUpdate::UserMessageChunk { content } => {
940 self.push_user_content_block(None, content, cx);
941 }
942 acp::SessionUpdate::AgentMessageChunk { content } => {
943 self.push_assistant_content_block(content, false, cx);
944 }
945 acp::SessionUpdate::AgentThoughtChunk { content } => {
946 self.push_assistant_content_block(content, true, cx);
947 }
948 acp::SessionUpdate::ToolCall(tool_call) => {
949 self.upsert_tool_call(tool_call, cx)?;
950 }
951 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
952 self.update_tool_call(tool_call_update, cx)?;
953 }
954 acp::SessionUpdate::Plan(plan) => {
955 self.update_plan(plan, cx);
956 }
957 }
958 Ok(())
959 }
960
961 pub fn push_user_content_block(
962 &mut self,
963 message_id: Option<UserMessageId>,
964 chunk: acp::ContentBlock,
965 cx: &mut Context<Self>,
966 ) {
967 let language_registry = self.project.read(cx).languages().clone();
968 let entries_len = self.entries.len();
969
970 if let Some(last_entry) = self.entries.last_mut()
971 && let AgentThreadEntry::UserMessage(UserMessage {
972 id,
973 content,
974 chunks,
975 ..
976 }) = last_entry
977 {
978 *id = message_id.or(id.take());
979 content.append(chunk.clone(), &language_registry, cx);
980 chunks.push(chunk);
981 let idx = entries_len - 1;
982 cx.emit(AcpThreadEvent::EntryUpdated(idx));
983 } else {
984 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
985 self.push_entry(
986 AgentThreadEntry::UserMessage(UserMessage {
987 id: message_id,
988 content,
989 chunks: vec![chunk],
990 checkpoint: None,
991 }),
992 cx,
993 );
994 }
995 }
996
997 pub fn push_assistant_content_block(
998 &mut self,
999 chunk: acp::ContentBlock,
1000 is_thought: bool,
1001 cx: &mut Context<Self>,
1002 ) {
1003 let language_registry = self.project.read(cx).languages().clone();
1004 let entries_len = self.entries.len();
1005 if let Some(last_entry) = self.entries.last_mut()
1006 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1007 {
1008 let idx = entries_len - 1;
1009 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1010 match (chunks.last_mut(), is_thought) {
1011 (Some(AssistantMessageChunk::Message { block }), false)
1012 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1013 block.append(chunk, &language_registry, cx)
1014 }
1015 _ => {
1016 let block = ContentBlock::new(chunk, &language_registry, cx);
1017 if is_thought {
1018 chunks.push(AssistantMessageChunk::Thought { block })
1019 } else {
1020 chunks.push(AssistantMessageChunk::Message { block })
1021 }
1022 }
1023 }
1024 } else {
1025 let block = ContentBlock::new(chunk, &language_registry, cx);
1026 let chunk = if is_thought {
1027 AssistantMessageChunk::Thought { block }
1028 } else {
1029 AssistantMessageChunk::Message { block }
1030 };
1031
1032 self.push_entry(
1033 AgentThreadEntry::AssistantMessage(AssistantMessage {
1034 chunks: vec![chunk],
1035 }),
1036 cx,
1037 );
1038 }
1039 }
1040
1041 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1042 self.entries.push(entry);
1043 cx.emit(AcpThreadEvent::NewEntry);
1044 }
1045
1046 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1047 self.connection.set_title(&self.session_id, cx).is_some()
1048 }
1049
1050 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1051 if title != self.title {
1052 self.title = title.clone();
1053 cx.emit(AcpThreadEvent::TitleUpdated);
1054 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1055 return set_title.run(title, cx);
1056 }
1057 }
1058 Task::ready(Ok(()))
1059 }
1060
1061 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1062 self.token_usage = usage;
1063 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1064 }
1065
1066 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1067 cx.emit(AcpThreadEvent::Retry(status));
1068 }
1069
1070 pub fn update_tool_call(
1071 &mut self,
1072 update: impl Into<ToolCallUpdate>,
1073 cx: &mut Context<Self>,
1074 ) -> Result<()> {
1075 let update = update.into();
1076 let languages = self.project.read(cx).languages().clone();
1077
1078 let (ix, current_call) = self
1079 .tool_call_mut(update.id())
1080 .context("Tool call not found")?;
1081 match update {
1082 ToolCallUpdate::UpdateFields(update) => {
1083 let location_updated = update.fields.locations.is_some();
1084 current_call.update_fields(update.fields, languages, cx);
1085 if location_updated {
1086 self.resolve_locations(update.id, cx);
1087 }
1088 }
1089 ToolCallUpdate::UpdateDiff(update) => {
1090 current_call.content.clear();
1091 current_call
1092 .content
1093 .push(ToolCallContent::Diff(update.diff));
1094 }
1095 ToolCallUpdate::UpdateTerminal(update) => {
1096 current_call.content.clear();
1097 current_call
1098 .content
1099 .push(ToolCallContent::Terminal(update.terminal));
1100 }
1101 }
1102
1103 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1104
1105 Ok(())
1106 }
1107
1108 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1109 pub fn upsert_tool_call(
1110 &mut self,
1111 tool_call: acp::ToolCall,
1112 cx: &mut Context<Self>,
1113 ) -> Result<(), acp::Error> {
1114 let status = tool_call.status.into();
1115 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1116 }
1117
1118 /// Fails if id does not match an existing entry.
1119 pub fn upsert_tool_call_inner(
1120 &mut self,
1121 tool_call_update: acp::ToolCallUpdate,
1122 status: ToolCallStatus,
1123 cx: &mut Context<Self>,
1124 ) -> Result<(), acp::Error> {
1125 let language_registry = self.project.read(cx).languages().clone();
1126 let id = tool_call_update.id.clone();
1127
1128 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1129 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1130 current_call.status = status;
1131
1132 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1133 } else {
1134 let call =
1135 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1136 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1137 };
1138
1139 self.resolve_locations(id, cx);
1140 Ok(())
1141 }
1142
1143 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1144 // The tool call we are looking for is typically the last one, or very close to the end.
1145 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1146 self.entries
1147 .iter_mut()
1148 .enumerate()
1149 .rev()
1150 .find_map(|(index, tool_call)| {
1151 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1152 && &tool_call.id == id
1153 {
1154 Some((index, tool_call))
1155 } else {
1156 None
1157 }
1158 })
1159 }
1160
1161 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1162 self.entries
1163 .iter()
1164 .enumerate()
1165 .rev()
1166 .find_map(|(index, tool_call)| {
1167 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1168 && &tool_call.id == id
1169 {
1170 Some((index, tool_call))
1171 } else {
1172 None
1173 }
1174 })
1175 }
1176
    /// Resolves the locations of tool call `id` to buffer positions in the background.
    ///
    /// When resolution finishes, the call's `resolved_locations` are replaced, and
    /// [`AcpThreadEvent::EntryUpdated`] is emitted if they changed. If the last location
    /// resolved, the project's agent location is moved to it, unless that would only move
    /// the position backwards within the same row of the same buffer. Does nothing if the
    /// tool call no longer exists.
    // NOTE(review): when the project has no agent location yet, `set_agent_location` is
    // never called here — confirm that is intended.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // The call may have been removed while locations were resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // Ignore moves backwards within the same row so that
                                        // streamed updates from the edit tool don't reset the
                                        // position to the start of the line.
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1222
1223 pub fn request_tool_call_authorization(
1224 &mut self,
1225 tool_call: acp::ToolCallUpdate,
1226 options: Vec<acp::PermissionOption>,
1227 cx: &mut Context<Self>,
1228 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1229 let (tx, rx) = oneshot::channel();
1230
1231 let status = ToolCallStatus::WaitingForConfirmation {
1232 options,
1233 respond_tx: tx,
1234 };
1235
1236 self.upsert_tool_call_inner(tool_call, status, cx)?;
1237 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1238 Ok(rx)
1239 }
1240
1241 pub fn authorize_tool_call(
1242 &mut self,
1243 id: acp::ToolCallId,
1244 option_id: acp::PermissionOptionId,
1245 option_kind: acp::PermissionOptionKind,
1246 cx: &mut Context<Self>,
1247 ) {
1248 let Some((ix, call)) = self.tool_call_mut(&id) else {
1249 return;
1250 };
1251
1252 let new_status = match option_kind {
1253 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1254 ToolCallStatus::Rejected
1255 }
1256 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1257 ToolCallStatus::InProgress
1258 }
1259 };
1260
1261 let curr_status = mem::replace(&mut call.status, new_status);
1262
1263 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1264 respond_tx.send(option_id).log_err();
1265 } else if cfg!(debug_assertions) {
1266 panic!("tried to authorize an already authorized tool call");
1267 }
1268
1269 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1270 }
1271
1272 /// Returns true if the last turn is awaiting tool authorization
1273 pub fn waiting_for_tool_confirmation(&self) -> bool {
1274 for entry in self.entries.iter().rev() {
1275 match &entry {
1276 AgentThreadEntry::ToolCall(call) => match call.status {
1277 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1278 ToolCallStatus::Pending
1279 | ToolCallStatus::InProgress
1280 | ToolCallStatus::Completed
1281 | ToolCallStatus::Failed
1282 | ToolCallStatus::Rejected
1283 | ToolCallStatus::Canceled => continue,
1284 },
1285 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1286 // Reached the beginning of the turn
1287 return false;
1288 }
1289 }
1290 }
1291 false
1292 }
1293
1294 pub fn plan(&self) -> &Plan {
1295 &self.plan
1296 }
1297
1298 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1299 let new_entries_len = request.entries.len();
1300 let mut new_entries = request.entries.into_iter();
1301
1302 // Reuse existing markdown to prevent flickering
1303 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1304 let PlanEntry {
1305 content,
1306 priority,
1307 status,
1308 } = old;
1309 content.update(cx, |old, cx| {
1310 old.replace(new.content, cx);
1311 });
1312 *priority = new.priority;
1313 *status = new.status;
1314 }
1315 for new in new_entries {
1316 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1317 }
1318 self.plan.entries.truncate(new_entries_len);
1319
1320 cx.notify();
1321 }
1322
1323 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1324 self.plan
1325 .entries
1326 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1327 cx.notify();
1328 }
1329
1330 #[cfg(any(test, feature = "test-support"))]
1331 pub fn send_raw(
1332 &mut self,
1333 message: &str,
1334 cx: &mut Context<Self>,
1335 ) -> BoxFuture<'static, Result<()>> {
1336 self.send(
1337 vec![acp::ContentBlock::Text(acp::TextContent {
1338 text: message.to_string(),
1339 annotations: None,
1340 })],
1341 cx,
1342 )
1343 }
1344
    /// Sends `message` as a new user prompt and runs it as a turn (see `run_turn`).
    ///
    /// Once any in-flight turn has been canceled, the user message is pushed into
    /// `entries`. A git checkpoint is then captured and attached to it, and finally the
    /// prompt is forwarded to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Messages only get an id when the connection supports truncation; the id is what
        // `rewind` later uses to find the message and truncate the connection's history.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the repository before the agent acts. A failure is only logged; the
            // message then simply has no checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` finds that the turn changed
                    // the repository.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1398
1399 pub fn can_resume(&self, cx: &App) -> bool {
1400 self.connection.resume(&self.session_id, cx).is_some()
1401 }
1402
1403 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1404 self.run_turn(cx, async move |this, cx| {
1405 this.update(cx, |this, cx| {
1406 this.connection
1407 .resume(&this.session_id, cx)
1408 .map(|resume| resume.run(cx))
1409 })?
1410 .context("resuming a session is not supported")?
1411 .await
1412 })
1413 }
1414
    /// Runs one conversation turn driven by `f`.
    ///
    /// Completed plan entries are cleared and any in-flight turn is canceled. `f` is then
    /// run inside `send_task`, after that cancellation finishes. The returned future
    /// resolves after `f` completes and the last user message's checkpoint has been
    /// refreshed:
    /// - if `f` returned an error, that error is returned and `AcpThreadEvent::Error` is
    ///   emitted;
    /// - otherwise it returns `Ok(())` after emitting `AcpThreadEvent::Stopped`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `rx` resolves to `Err` if `tx` is dropped without sending (e.g. `send_task`
            // was dropped); that case falls into the non-error arm below.
            let response = rx.await;

            // NOTE(review): an error here returns early and skips the cleanup below
            // (agent location reset, `send_task` release, `Stopped`). Confirm intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused: the refused user
                        // message and everything after it are removed.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1481
1482 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1483 let Some(send_task) = self.send_task.take() else {
1484 return Task::ready(());
1485 };
1486
1487 for entry in self.entries.iter_mut() {
1488 if let AgentThreadEntry::ToolCall(call) = entry {
1489 let cancel = matches!(
1490 call.status,
1491 ToolCallStatus::Pending
1492 | ToolCallStatus::WaitingForConfirmation { .. }
1493 | ToolCallStatus::InProgress
1494 );
1495
1496 if cancel {
1497 call.status = ToolCallStatus::Canceled;
1498 }
1499 }
1500 }
1501
1502 self.connection.cancel(&self.session_id, cx);
1503
1504 // Wait for the send task to complete
1505 cx.foreground_executor().spawn(send_task)
1506 }
1507
    /// Rewinds this thread to before the user message with the given `id`, removing it and
    /// all subsequent entries while reverting any changes made from that point.
    ///
    /// Fails with "not supported" when the connection can't truncate its history, and with
    /// "message not found" when no user message has `id`. If the message has a git
    /// checkpoint, the repository is restored to it first. The connection's history is
    /// truncated next, and the local entries are truncated last.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            // The message is looked up again because entries may have changed while the
            // restore and truncation were awaited.
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1541
1542 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1543 let git_store = self.project.read(cx).git_store().clone();
1544
1545 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1546 if let Some(checkpoint) = message.checkpoint.as_ref() {
1547 checkpoint.git_checkpoint.clone()
1548 } else {
1549 return Task::ready(Ok(()));
1550 }
1551 } else {
1552 return Task::ready(Ok(()));
1553 };
1554
1555 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1556 cx.spawn(async move |this, cx| {
1557 let new_checkpoint = new_checkpoint
1558 .await
1559 .context("failed to get new checkpoint")
1560 .log_err();
1561 if let Some(new_checkpoint) = new_checkpoint {
1562 let equal = git_store
1563 .update(cx, |git, cx| {
1564 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1565 })?
1566 .await
1567 .unwrap_or(true);
1568 this.update(cx, |this, cx| {
1569 let (ix, message) = this.last_user_message().context("no user message")?;
1570 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1571 checkpoint.show = !equal;
1572 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1573 anyhow::Ok(())
1574 })??;
1575 }
1576
1577 Ok(())
1578 })
1579 }
1580
1581 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1582 self.entries
1583 .iter_mut()
1584 .enumerate()
1585 .rev()
1586 .find_map(|(ix, entry)| {
1587 if let AgentThreadEntry::UserMessage(message) = entry {
1588 Some((ix, message))
1589 } else {
1590 None
1591 }
1592 })
1593 }
1594
1595 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1596 self.entries.iter().find_map(|entry| {
1597 if let AgentThreadEntry::UserMessage(message) = entry {
1598 if message.id.as_ref() == Some(id) {
1599 Some(message)
1600 } else {
1601 None
1602 }
1603 } else {
1604 None
1605 }
1606 })
1607 }
1608
1609 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1610 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1611 if let AgentThreadEntry::UserMessage(message) = entry {
1612 if message.id.as_ref() == Some(id) {
1613 Some((ix, message))
1614 } else {
1615 None
1616 }
1617 } else {
1618 None
1619 }
1620 })
1621 }
1622
1623 pub fn read_text_file(
1624 &self,
1625 path: PathBuf,
1626 line: Option<u32>,
1627 limit: Option<u32>,
1628 reuse_shared_snapshot: bool,
1629 cx: &mut Context<Self>,
1630 ) -> Task<Result<String>> {
1631 let project = self.project.clone();
1632 let action_log = self.action_log.clone();
1633 cx.spawn(async move |this, cx| {
1634 let load = project.update(cx, |project, cx| {
1635 let path = project
1636 .project_path_for_absolute_path(&path, cx)
1637 .context("invalid path")?;
1638 anyhow::Ok(project.open_buffer(path, cx))
1639 });
1640 let buffer = load??.await?;
1641
1642 let snapshot = if reuse_shared_snapshot {
1643 this.read_with(cx, |this, _| {
1644 this.shared_buffers.get(&buffer.clone()).cloned()
1645 })
1646 .log_err()
1647 .flatten()
1648 } else {
1649 None
1650 };
1651
1652 let snapshot = if let Some(snapshot) = snapshot {
1653 snapshot
1654 } else {
1655 action_log.update(cx, |action_log, cx| {
1656 action_log.buffer_read(buffer.clone(), cx);
1657 })?;
1658 project.update(cx, |project, cx| {
1659 let position = buffer
1660 .read(cx)
1661 .snapshot()
1662 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1663 project.set_agent_location(
1664 Some(AgentLocation {
1665 buffer: buffer.downgrade(),
1666 position,
1667 }),
1668 cx,
1669 );
1670 })?;
1671
1672 buffer.update(cx, |buffer, _| buffer.snapshot())?
1673 };
1674
1675 this.update(cx, |this, _| {
1676 let text = snapshot.text();
1677 this.shared_buffers.insert(buffer.clone(), snapshot);
1678 if line.is_none() && limit.is_none() {
1679 return Ok(text);
1680 }
1681 let limit = limit.unwrap_or(u32::MAX) as usize;
1682 let Some(line) = line else {
1683 return Ok(text.lines().take(limit).collect::<String>());
1684 };
1685
1686 let count = text.lines().count();
1687 if count < line as usize {
1688 anyhow::bail!("There are only {} lines", count);
1689 }
1690 Ok(text
1691 .lines()
1692 .skip(line as usize + 1)
1693 .take(limit)
1694 .collect::<String>())
1695 })?
1696 })
1697 }
1698
    /// Replaces the contents of the file at the absolute `path` with `content` and saves it.
    ///
    /// The new text is diffed against the snapshot the agent was last given for this buffer
    /// (from `shared_buffers`), falling back to the buffer's current contents. Only the
    /// resulting hunks are applied, so edits the user made in the meantime are preserved.
    /// The edits are recorded in the action log, and the buffer is formatted when
    /// format-on-save is enabled for it. Finally the buffer is saved; the save's result is
    /// returned.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent saw when it read the file, if it did.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread. The ranges are turned into anchors of
            // that snapshot rather than raw offsets, so they still point at the intended text
            // if the buffer has changed since.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Move the agent location to the end of the last edit (`Anchor::MIN` when the
            // diff is empty).
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // NOTE(review): presumably marks the pre-edit state as read so the action log
                // attributes the following edit to the agent. Confirm against ActionLog.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged and the save still happens.
                format_task.await.log_err();

                // Record the formatter's changes as part of the agent's edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1795
1796 pub fn to_markdown(&self, cx: &App) -> String {
1797 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1798 }
1799
1800 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1801 cx.emit(AcpThreadEvent::LoadError(error));
1802 }
1803}
1804
1805fn markdown_for_raw_output(
1806 raw_output: &serde_json::Value,
1807 language_registry: &Arc<LanguageRegistry>,
1808 cx: &mut App,
1809) -> Option<Entity<Markdown>> {
1810 match raw_output {
1811 serde_json::Value::Null => None,
1812 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1813 Markdown::new(
1814 value.to_string().into(),
1815 Some(language_registry.clone()),
1816 None,
1817 cx,
1818 )
1819 })),
1820 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1821 Markdown::new(
1822 value.to_string().into(),
1823 Some(language_registry.clone()),
1824 None,
1825 cx,
1826 )
1827 })),
1828 serde_json::Value::String(value) => Some(cx.new(|cx| {
1829 Markdown::new(
1830 value.clone().into(),
1831 Some(language_registry.clone()),
1832 None,
1833 cx,
1834 )
1835 })),
1836 value => Some(cx.new(|cx| {
1837 Markdown::new(
1838 format!("```json\n{}\n```", value).into(),
1839 Some(language_registry.clone()),
1840 None,
1841 cx,
1842 )
1843 })),
1844 }
1845}
1846
1847#[cfg(test)]
1848mod tests {
1849 use super::*;
1850 use anyhow::anyhow;
1851 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1852 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1853 use indoc::indoc;
1854 use project::{FakeFs, Fs};
1855 use rand::Rng as _;
1856 use serde_json::json;
1857 use settings::SettingsStore;
1858 use smol::stream::StreamExt as _;
1859 use std::{
1860 any::Any,
1861 cell::RefCell,
1862 path::Path,
1863 rc::Rc,
1864 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1865 time::Duration,
1866 };
1867 use util::path;
1868
1869 fn init_test(cx: &mut TestAppContext) {
1870 env_logger::try_init().ok();
1871 cx.update(|cx| {
1872 let settings_store = SettingsStore::test(cx);
1873 cx.set_global(settings_store);
1874 Project::init_settings(cx);
1875 language::init(cx);
1876 });
1877 }
1878
1879 #[gpui::test]
1880 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1881 init_test(cx);
1882
1883 let fs = FakeFs::new(cx.executor());
1884 let project = Project::test(fs, [], cx).await;
1885 let connection = Rc::new(FakeAgentConnection::new());
1886 let thread = cx
1887 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1888 .await
1889 .unwrap();
1890
1891 // Test creating a new user message
1892 thread.update(cx, |thread, cx| {
1893 thread.push_user_content_block(
1894 None,
1895 acp::ContentBlock::Text(acp::TextContent {
1896 annotations: None,
1897 text: "Hello, ".to_string(),
1898 }),
1899 cx,
1900 );
1901 });
1902
1903 thread.update(cx, |thread, cx| {
1904 assert_eq!(thread.entries.len(), 1);
1905 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1906 assert_eq!(user_msg.id, None);
1907 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1908 } else {
1909 panic!("Expected UserMessage");
1910 }
1911 });
1912
1913 // Test appending to existing user message
1914 let message_1_id = UserMessageId::new();
1915 thread.update(cx, |thread, cx| {
1916 thread.push_user_content_block(
1917 Some(message_1_id.clone()),
1918 acp::ContentBlock::Text(acp::TextContent {
1919 annotations: None,
1920 text: "world!".to_string(),
1921 }),
1922 cx,
1923 );
1924 });
1925
1926 thread.update(cx, |thread, cx| {
1927 assert_eq!(thread.entries.len(), 1);
1928 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1929 assert_eq!(user_msg.id, Some(message_1_id));
1930 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1931 } else {
1932 panic!("Expected UserMessage");
1933 }
1934 });
1935
1936 // Test creating new user message after assistant message
1937 thread.update(cx, |thread, cx| {
1938 thread.push_assistant_content_block(
1939 acp::ContentBlock::Text(acp::TextContent {
1940 annotations: None,
1941 text: "Assistant response".to_string(),
1942 }),
1943 false,
1944 cx,
1945 );
1946 });
1947
1948 let message_2_id = UserMessageId::new();
1949 thread.update(cx, |thread, cx| {
1950 thread.push_user_content_block(
1951 Some(message_2_id.clone()),
1952 acp::ContentBlock::Text(acp::TextContent {
1953 annotations: None,
1954 text: "New user message".to_string(),
1955 }),
1956 cx,
1957 );
1958 });
1959
1960 thread.update(cx, |thread, cx| {
1961 assert_eq!(thread.entries.len(), 3);
1962 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1963 assert_eq!(user_msg.id, Some(message_2_id));
1964 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1965 } else {
1966 panic!("Expected UserMessage at index 2");
1967 }
1968 });
1969 }
1970
1971 #[gpui::test]
1972 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1973 init_test(cx);
1974
1975 let fs = FakeFs::new(cx.executor());
1976 let project = Project::test(fs, [], cx).await;
1977 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1978 |_, thread, mut cx| {
1979 async move {
1980 thread.update(&mut cx, |thread, cx| {
1981 thread
1982 .handle_session_update(
1983 acp::SessionUpdate::AgentThoughtChunk {
1984 content: "Thinking ".into(),
1985 },
1986 cx,
1987 )
1988 .unwrap();
1989 thread
1990 .handle_session_update(
1991 acp::SessionUpdate::AgentThoughtChunk {
1992 content: "hard!".into(),
1993 },
1994 cx,
1995 )
1996 .unwrap();
1997 })?;
1998 Ok(acp::PromptResponse {
1999 stop_reason: acp::StopReason::EndTurn,
2000 })
2001 }
2002 .boxed_local()
2003 },
2004 ));
2005
2006 let thread = cx
2007 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2008 .await
2009 .unwrap();
2010
2011 thread
2012 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2013 .await
2014 .unwrap();
2015
2016 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2017 assert_eq!(
2018 output,
2019 indoc! {r#"
2020 ## User
2021
2022 Hello from Zed!
2023
2024 ## Assistant
2025
2026 <thinking>
2027 Thinking hard!
2028 </thinking>
2029
2030 "#}
2031 );
2032 }
2033
2034 #[gpui::test]
2035 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2036 init_test(cx);
2037
2038 let fs = FakeFs::new(cx.executor());
2039 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2040 .await;
2041 let project = Project::test(fs.clone(), [], cx).await;
2042 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2043 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2044 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2045 move |_, thread, mut cx| {
2046 let read_file_tx = read_file_tx.clone();
2047 async move {
2048 let content = thread
2049 .update(&mut cx, |thread, cx| {
2050 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2051 })
2052 .unwrap()
2053 .await
2054 .unwrap();
2055 assert_eq!(content, "one\ntwo\nthree\n");
2056 read_file_tx.take().unwrap().send(()).unwrap();
2057 thread
2058 .update(&mut cx, |thread, cx| {
2059 thread.write_text_file(
2060 path!("/tmp/foo").into(),
2061 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2062 cx,
2063 )
2064 })
2065 .unwrap()
2066 .await
2067 .unwrap();
2068 Ok(acp::PromptResponse {
2069 stop_reason: acp::StopReason::EndTurn,
2070 })
2071 }
2072 .boxed_local()
2073 },
2074 ));
2075
2076 let (worktree, pathbuf) = project
2077 .update(cx, |project, cx| {
2078 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2079 })
2080 .await
2081 .unwrap();
2082 let buffer = project
2083 .update(cx, |project, cx| {
2084 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2085 })
2086 .await
2087 .unwrap();
2088
2089 let thread = cx
2090 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2091 .await
2092 .unwrap();
2093
2094 let request = thread.update(cx, |thread, cx| {
2095 thread.send_raw("Extend the count in /tmp/foo", cx)
2096 });
2097 read_file_rx.await.ok();
2098 buffer.update(cx, |buffer, cx| {
2099 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2100 });
2101 cx.run_until_parked();
2102 assert_eq!(
2103 buffer.read_with(cx, |buffer, _| buffer.text()),
2104 "zero\none\ntwo\nthree\nfour\nfive\n"
2105 );
2106 assert_eq!(
2107 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2108 "zero\none\ntwo\nthree\nfour\nfive\n"
2109 );
2110 request.await.unwrap();
2111 }
2112
2113 #[gpui::test]
2114 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2115 init_test(cx);
2116
2117 let fs = FakeFs::new(cx.executor());
2118 let project = Project::test(fs, [], cx).await;
2119 let id = acp::ToolCallId("test".into());
2120
2121 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2122 let id = id.clone();
2123 move |_, thread, mut cx| {
2124 let id = id.clone();
2125 async move {
2126 thread
2127 .update(&mut cx, |thread, cx| {
2128 thread.handle_session_update(
2129 acp::SessionUpdate::ToolCall(acp::ToolCall {
2130 id: id.clone(),
2131 title: "Label".into(),
2132 kind: acp::ToolKind::Fetch,
2133 status: acp::ToolCallStatus::InProgress,
2134 content: vec![],
2135 locations: vec![],
2136 raw_input: None,
2137 raw_output: None,
2138 }),
2139 cx,
2140 )
2141 })
2142 .unwrap()
2143 .unwrap();
2144 Ok(acp::PromptResponse {
2145 stop_reason: acp::StopReason::EndTurn,
2146 })
2147 }
2148 .boxed_local()
2149 }
2150 }));
2151
2152 let thread = cx
2153 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2154 .await
2155 .unwrap();
2156
2157 let request = thread.update(cx, |thread, cx| {
2158 thread.send_raw("Fetch https://example.com", cx)
2159 });
2160
2161 run_until_first_tool_call(&thread, cx).await;
2162
2163 thread.read_with(cx, |thread, _| {
2164 assert!(matches!(
2165 thread.entries[1],
2166 AgentThreadEntry::ToolCall(ToolCall {
2167 status: ToolCallStatus::InProgress,
2168 ..
2169 })
2170 ));
2171 });
2172
2173 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2174
2175 thread.read_with(cx, |thread, _| {
2176 assert!(matches!(
2177 &thread.entries[1],
2178 AgentThreadEntry::ToolCall(ToolCall {
2179 status: ToolCallStatus::Canceled,
2180 ..
2181 })
2182 ));
2183 });
2184
2185 thread
2186 .update(cx, |thread, cx| {
2187 thread.handle_session_update(
2188 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2189 id,
2190 fields: acp::ToolCallUpdateFields {
2191 status: Some(acp::ToolCallStatus::Completed),
2192 ..Default::default()
2193 },
2194 }),
2195 cx,
2196 )
2197 })
2198 .unwrap();
2199
2200 request.await.unwrap();
2201
2202 thread.read_with(cx, |thread, _| {
2203 assert!(matches!(
2204 thread.entries[1],
2205 AgentThreadEntry::ToolCall(ToolCall {
2206 status: ToolCallStatus::Completed,
2207 ..
2208 })
2209 ));
2210 });
2211 }
2212
2213 #[gpui::test]
2214 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2215 init_test(cx);
2216 let fs = FakeFs::new(cx.background_executor.clone());
2217 fs.insert_tree(path!("/test"), json!({})).await;
2218 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2219
2220 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2221 move |_, thread, mut cx| {
2222 async move {
2223 thread
2224 .update(&mut cx, |thread, cx| {
2225 thread.handle_session_update(
2226 acp::SessionUpdate::ToolCall(acp::ToolCall {
2227 id: acp::ToolCallId("test".into()),
2228 title: "Label".into(),
2229 kind: acp::ToolKind::Edit,
2230 status: acp::ToolCallStatus::Completed,
2231 content: vec![acp::ToolCallContent::Diff {
2232 diff: acp::Diff {
2233 path: "/test/test.txt".into(),
2234 old_text: None,
2235 new_text: "foo".into(),
2236 },
2237 }],
2238 locations: vec![],
2239 raw_input: None,
2240 raw_output: None,
2241 }),
2242 cx,
2243 )
2244 })
2245 .unwrap()
2246 .unwrap();
2247 Ok(acp::PromptResponse {
2248 stop_reason: acp::StopReason::EndTurn,
2249 })
2250 }
2251 .boxed_local()
2252 }
2253 }));
2254
2255 let thread = cx
2256 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2257 .await
2258 .unwrap();
2259
2260 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2261 .await
2262 .unwrap();
2263
2264 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2265 }
2266
2267 #[gpui::test(iterations = 10)]
2268 async fn test_checkpoints(cx: &mut TestAppContext) {
2269 init_test(cx);
2270 let fs = FakeFs::new(cx.background_executor.clone());
2271 fs.insert_tree(
2272 path!("/test"),
2273 json!({
2274 ".git": {}
2275 }),
2276 )
2277 .await;
2278 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2279
2280 let simulate_changes = Arc::new(AtomicBool::new(true));
2281 let next_filename = Arc::new(AtomicUsize::new(0));
2282 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2283 let simulate_changes = simulate_changes.clone();
2284 let next_filename = next_filename.clone();
2285 let fs = fs.clone();
2286 move |request, thread, mut cx| {
2287 let fs = fs.clone();
2288 let simulate_changes = simulate_changes.clone();
2289 let next_filename = next_filename.clone();
2290 async move {
2291 if simulate_changes.load(SeqCst) {
2292 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2293 fs.write(Path::new(&filename), b"").await?;
2294 }
2295
2296 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2297 panic!("expected text content block");
2298 };
2299 thread.update(&mut cx, |thread, cx| {
2300 thread
2301 .handle_session_update(
2302 acp::SessionUpdate::AgentMessageChunk {
2303 content: content.text.to_uppercase().into(),
2304 },
2305 cx,
2306 )
2307 .unwrap();
2308 })?;
2309 Ok(acp::PromptResponse {
2310 stop_reason: acp::StopReason::EndTurn,
2311 })
2312 }
2313 .boxed_local()
2314 }
2315 }));
2316 let thread = cx
2317 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2318 .await
2319 .unwrap();
2320
2321 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2322 .await
2323 .unwrap();
2324 thread.read_with(cx, |thread, cx| {
2325 assert_eq!(
2326 thread.to_markdown(cx),
2327 indoc! {"
2328 ## User (checkpoint)
2329
2330 Lorem
2331
2332 ## Assistant
2333
2334 LOREM
2335
2336 "}
2337 );
2338 });
2339 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2340
2341 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2342 .await
2343 .unwrap();
2344 thread.read_with(cx, |thread, cx| {
2345 assert_eq!(
2346 thread.to_markdown(cx),
2347 indoc! {"
2348 ## User (checkpoint)
2349
2350 Lorem
2351
2352 ## Assistant
2353
2354 LOREM
2355
2356 ## User (checkpoint)
2357
2358 ipsum
2359
2360 ## Assistant
2361
2362 IPSUM
2363
2364 "}
2365 );
2366 });
2367 assert_eq!(
2368 fs.files(),
2369 vec![
2370 Path::new(path!("/test/file-0")),
2371 Path::new(path!("/test/file-1"))
2372 ]
2373 );
2374
2375 // Checkpoint isn't stored when there are no changes.
2376 simulate_changes.store(false, SeqCst);
2377 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2378 .await
2379 .unwrap();
2380 thread.read_with(cx, |thread, cx| {
2381 assert_eq!(
2382 thread.to_markdown(cx),
2383 indoc! {"
2384 ## User (checkpoint)
2385
2386 Lorem
2387
2388 ## Assistant
2389
2390 LOREM
2391
2392 ## User (checkpoint)
2393
2394 ipsum
2395
2396 ## Assistant
2397
2398 IPSUM
2399
2400 ## User
2401
2402 dolor
2403
2404 ## Assistant
2405
2406 DOLOR
2407
2408 "}
2409 );
2410 });
2411 assert_eq!(
2412 fs.files(),
2413 vec![
2414 Path::new(path!("/test/file-0")),
2415 Path::new(path!("/test/file-1"))
2416 ]
2417 );
2418
2419 // Rewinding the conversation truncates the history and restores the checkpoint.
2420 thread
2421 .update(cx, |thread, cx| {
2422 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2423 panic!("unexpected entries {:?}", thread.entries)
2424 };
2425 thread.rewind(message.id.clone().unwrap(), cx)
2426 })
2427 .await
2428 .unwrap();
2429 thread.read_with(cx, |thread, cx| {
2430 assert_eq!(
2431 thread.to_markdown(cx),
2432 indoc! {"
2433 ## User (checkpoint)
2434
2435 Lorem
2436
2437 ## Assistant
2438
2439 LOREM
2440
2441 "}
2442 );
2443 });
2444 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2445 }
2446
2447 #[gpui::test]
2448 async fn test_refusal(cx: &mut TestAppContext) {
2449 init_test(cx);
2450 let fs = FakeFs::new(cx.background_executor.clone());
2451 fs.insert_tree(path!("/"), json!({})).await;
2452 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2453
2454 let refuse_next = Arc::new(AtomicBool::new(false));
2455 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2456 let refuse_next = refuse_next.clone();
2457 move |request, thread, mut cx| {
2458 let refuse_next = refuse_next.clone();
2459 async move {
2460 if refuse_next.load(SeqCst) {
2461 return Ok(acp::PromptResponse {
2462 stop_reason: acp::StopReason::Refusal,
2463 });
2464 }
2465
2466 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2467 panic!("expected text content block");
2468 };
2469 thread.update(&mut cx, |thread, cx| {
2470 thread
2471 .handle_session_update(
2472 acp::SessionUpdate::AgentMessageChunk {
2473 content: content.text.to_uppercase().into(),
2474 },
2475 cx,
2476 )
2477 .unwrap();
2478 })?;
2479 Ok(acp::PromptResponse {
2480 stop_reason: acp::StopReason::EndTurn,
2481 })
2482 }
2483 .boxed_local()
2484 }
2485 }));
2486 let thread = cx
2487 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2488 .await
2489 .unwrap();
2490
2491 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2492 .await
2493 .unwrap();
2494 thread.read_with(cx, |thread, cx| {
2495 assert_eq!(
2496 thread.to_markdown(cx),
2497 indoc! {"
2498 ## User
2499
2500 hello
2501
2502 ## Assistant
2503
2504 HELLO
2505
2506 "}
2507 );
2508 });
2509
2510 // Simulate refusing the second message, ensuring the conversation gets
2511 // truncated to before sending it.
2512 refuse_next.store(true, SeqCst);
2513 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2514 .await
2515 .unwrap();
2516 thread.read_with(cx, |thread, cx| {
2517 assert_eq!(
2518 thread.to_markdown(cx),
2519 indoc! {"
2520 ## User
2521
2522 hello
2523
2524 ## Assistant
2525
2526 HELLO
2527
2528 "}
2529 );
2530 });
2531 }
2532
2533 async fn run_until_first_tool_call(
2534 thread: &Entity<AcpThread>,
2535 cx: &mut TestAppContext,
2536 ) -> usize {
2537 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2538
2539 let subscription = cx.update(|cx| {
2540 cx.subscribe(thread, move |thread, _, cx| {
2541 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2542 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2543 return tx.try_send(ix).unwrap();
2544 }
2545 }
2546 })
2547 });
2548
2549 select! {
2550 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2551 panic!("Timeout waiting for tool call")
2552 }
2553 ix = rx.next().fuse() => {
2554 drop(subscription);
2555 ix.unwrap()
2556 }
2557 }
2558 }
2559
2560 #[derive(Clone, Default)]
2561 struct FakeAgentConnection {
2562 auth_methods: Vec<acp::AuthMethod>,
2563 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2564 on_user_message: Option<
2565 Rc<
2566 dyn Fn(
2567 acp::PromptRequest,
2568 WeakEntity<AcpThread>,
2569 AsyncApp,
2570 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2571 + 'static,
2572 >,
2573 >,
2574 }
2575
2576 impl FakeAgentConnection {
2577 fn new() -> Self {
2578 Self {
2579 auth_methods: Vec::new(),
2580 on_user_message: None,
2581 sessions: Arc::default(),
2582 }
2583 }
2584
2585 #[expect(unused)]
2586 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2587 self.auth_methods = auth_methods;
2588 self
2589 }
2590
2591 fn on_user_message(
2592 mut self,
2593 handler: impl Fn(
2594 acp::PromptRequest,
2595 WeakEntity<AcpThread>,
2596 AsyncApp,
2597 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2598 + 'static,
2599 ) -> Self {
2600 self.on_user_message.replace(Rc::new(handler));
2601 self
2602 }
2603 }
2604
2605 impl AgentConnection for FakeAgentConnection {
2606 fn auth_methods(&self) -> &[acp::AuthMethod] {
2607 &self.auth_methods
2608 }
2609
2610 fn new_thread(
2611 self: Rc<Self>,
2612 project: Entity<Project>,
2613 _cwd: &Path,
2614 cx: &mut App,
2615 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2616 let session_id = acp::SessionId(
2617 rand::thread_rng()
2618 .sample_iter(&rand::distributions::Alphanumeric)
2619 .take(7)
2620 .map(char::from)
2621 .collect::<String>()
2622 .into(),
2623 );
2624 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2625 let thread = cx.new(|cx| {
2626 AcpThread::new(
2627 "Test",
2628 self.clone(),
2629 project,
2630 action_log,
2631 session_id.clone(),
2632 watch::Receiver::constant(acp::PromptCapabilities {
2633 image: true,
2634 audio: true,
2635 embedded_context: true,
2636 }),
2637 cx,
2638 )
2639 });
2640 self.sessions.lock().insert(session_id, thread.downgrade());
2641 Task::ready(Ok(thread))
2642 }
2643
2644 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2645 if self.auth_methods().iter().any(|m| m.id == method) {
2646 Task::ready(Ok(()))
2647 } else {
2648 Task::ready(Err(anyhow!("Invalid Auth Method")))
2649 }
2650 }
2651
2652 fn prompt(
2653 &self,
2654 _id: Option<UserMessageId>,
2655 params: acp::PromptRequest,
2656 cx: &mut App,
2657 ) -> Task<gpui::Result<acp::PromptResponse>> {
2658 let sessions = self.sessions.lock();
2659 let thread = sessions.get(¶ms.session_id).unwrap();
2660 if let Some(handler) = &self.on_user_message {
2661 let handler = handler.clone();
2662 let thread = thread.clone();
2663 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2664 } else {
2665 Task::ready(Ok(acp::PromptResponse {
2666 stop_reason: acp::StopReason::EndTurn,
2667 }))
2668 }
2669 }
2670
2671 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2672 let sessions = self.sessions.lock();
2673 let thread = sessions.get(session_id).unwrap().clone();
2674
2675 cx.spawn(async move |cx| {
2676 thread
2677 .update(cx, |thread, cx| thread.cancel(cx))
2678 .unwrap()
2679 .await
2680 })
2681 .detach();
2682 }
2683
2684 fn truncate(
2685 &self,
2686 session_id: &acp::SessionId,
2687 _cx: &App,
2688 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2689 Some(Rc::new(FakeAgentSessionEditor {
2690 _session_id: session_id.clone(),
2691 }))
2692 }
2693
2694 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2695 self
2696 }
2697 }
2698
2699 struct FakeAgentSessionEditor {
2700 _session_id: acp::SessionId,
2701 }
2702
2703 impl AgentSessionTruncate for FakeAgentSessionEditor {
2704 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2705 Task::ready(Ok(()))
2706 }
2707 }
2708}