1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
36#[derive(Debug)]
37pub struct UserMessage {
38 pub id: Option<UserMessageId>,
39 pub content: ContentBlock,
40 pub chunks: Vec<acp::ContentBlock>,
41 pub checkpoint: Option<Checkpoint>,
42}
43
44#[derive(Debug)]
45pub struct Checkpoint {
46 git_checkpoint: GitStoreCheckpoint,
47 pub show: bool,
48}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
166#[derive(Debug)]
167pub struct ToolCall {
168 pub id: acp::ToolCallId,
169 pub label: Entity<Markdown>,
170 pub kind: acp::ToolKind,
171 pub content: Vec<ToolCallContent>,
172 pub status: ToolCallStatus,
173 pub locations: Vec<acp::ToolCallLocation>,
174 pub resolved_locations: Vec<Option<AgentLocation>>,
175 pub raw_input: Option<serde_json::Value>,
176 pub raw_output: Option<serde_json::Value>,
177}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
187 first_line.to_owned() + "…"
188 } else {
189 tool_call.title
190 };
191 Self {
192 id: tool_call.id,
193 label: cx
194 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
195 kind: tool_call.kind,
196 content: tool_call
197 .content
198 .into_iter()
199 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
200 .collect(),
201 locations: tool_call.locations,
202 resolved_locations: Vec::default(),
203 status,
204 raw_input: tool_call.raw_input,
205 raw_output: tool_call.raw_output,
206 }
207 }
208
209 fn update_fields(
210 &mut self,
211 fields: acp::ToolCallUpdateFields,
212 language_registry: Arc<LanguageRegistry>,
213 cx: &mut App,
214 ) {
215 let acp::ToolCallUpdateFields {
216 kind,
217 status,
218 title,
219 content,
220 locations,
221 raw_input,
222 raw_output,
223 } = fields;
224
225 if let Some(kind) = kind {
226 self.kind = kind;
227 }
228
229 if let Some(status) = status {
230 self.status = status.into();
231 }
232
233 if let Some(title) = title {
234 self.label.update(cx, |label, cx| {
235 if let Some((first_line, _)) = title.split_once("\n") {
236 label.replace(first_line.to_owned() + "…", cx)
237 } else {
238 label.replace(title, cx);
239 }
240 });
241 }
242
243 if let Some(content) = content {
244 let new_content_len = content.len();
245 let mut content = content.into_iter();
246
247 // Reuse existing content if we can
248 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
249 old.update_from_acp(new, language_registry.clone(), cx);
250 }
251 for new in content {
252 self.content.push(ToolCallContent::from_acp(
253 new,
254 language_registry.clone(),
255 cx,
256 ))
257 }
258 self.content.truncate(new_content_len);
259 }
260
261 if let Some(locations) = locations {
262 self.locations = locations;
263 }
264
265 if let Some(raw_input) = raw_input {
266 self.raw_input = Some(raw_input);
267 }
268
269 if let Some(raw_output) = raw_output {
270 if self.content.is_empty()
271 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
272 {
273 self.content
274 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
275 markdown,
276 }));
277 }
278 self.raw_output = Some(raw_output);
279 }
280 }
281
282 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
283 self.content.iter().filter_map(|content| match content {
284 ToolCallContent::Diff(diff) => Some(diff),
285 ToolCallContent::ContentBlock(_) => None,
286 ToolCallContent::Terminal(_) => None,
287 })
288 }
289
290 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
291 self.content.iter().filter_map(|content| match content {
292 ToolCallContent::Terminal(terminal) => Some(terminal),
293 ToolCallContent::ContentBlock(_) => None,
294 ToolCallContent::Diff(_) => None,
295 })
296 }
297
298 fn to_markdown(&self, cx: &App) -> String {
299 let mut markdown = format!(
300 "**Tool Call: {}**\nStatus: {}\n\n",
301 self.label.read(cx).source(),
302 self.status
303 );
304 for content in &self.content {
305 markdown.push_str(content.to_markdown(cx).as_str());
306 markdown.push_str("\n\n");
307 }
308 markdown
309 }
310
311 async fn resolve_location(
312 location: acp::ToolCallLocation,
313 project: WeakEntity<Project>,
314 cx: &mut AsyncApp,
315 ) -> Option<AgentLocation> {
316 let buffer = project
317 .update(cx, |project, cx| {
318 project
319 .project_path_for_absolute_path(&location.path, cx)
320 .map(|path| project.open_buffer(path, cx))
321 })
322 .ok()??;
323 let buffer = buffer.await.log_err()?;
324 let position = buffer
325 .update(cx, |buffer, _| {
326 if let Some(row) = location.line {
327 let snapshot = buffer.snapshot();
328 let column = snapshot.indent_size_for_line(row).len;
329 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
330 snapshot.anchor_before(point)
331 } else {
332 Anchor::MIN
333 }
334 })
335 .ok()?;
336
337 Some(AgentLocation {
338 buffer: buffer.downgrade(),
339 position,
340 })
341 }
342
343 fn resolve_locations(
344 &self,
345 project: Entity<Project>,
346 cx: &mut App,
347 ) -> Task<Vec<Option<AgentLocation>>> {
348 let locations = self.locations.clone();
349 project.update(cx, |_, cx| {
350 cx.spawn(async move |project, cx| {
351 let mut new_locations = Vec::new();
352 for location in locations {
353 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
354 }
355 new_locations
356 })
357 })
358 }
359}
360
361#[derive(Debug)]
362pub enum ToolCallStatus {
363 /// The tool call hasn't started running yet, but we start showing it to
364 /// the user.
365 Pending,
366 /// The tool call is waiting for confirmation from the user.
367 WaitingForConfirmation {
368 options: Vec<acp::PermissionOption>,
369 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
370 },
371 /// The tool call is currently running.
372 InProgress,
373 /// The tool call completed successfully.
374 Completed,
375 /// The tool call failed.
376 Failed,
377 /// The user rejected the tool call.
378 Rejected,
379 /// The user canceled generation so the tool call was canceled.
380 Canceled,
381}
382
383impl From<acp::ToolCallStatus> for ToolCallStatus {
384 fn from(status: acp::ToolCallStatus) -> Self {
385 match status {
386 acp::ToolCallStatus::Pending => Self::Pending,
387 acp::ToolCallStatus::InProgress => Self::InProgress,
388 acp::ToolCallStatus::Completed => Self::Completed,
389 acp::ToolCallStatus::Failed => Self::Failed,
390 }
391 }
392}
393
394impl Display for ToolCallStatus {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 write!(
397 f,
398 "{}",
399 match self {
400 ToolCallStatus::Pending => "Pending",
401 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
402 ToolCallStatus::InProgress => "In Progress",
403 ToolCallStatus::Completed => "Completed",
404 ToolCallStatus::Failed => "Failed",
405 ToolCallStatus::Rejected => "Rejected",
406 ToolCallStatus::Canceled => "Canceled",
407 }
408 )
409 }
410}
411
412#[derive(Debug, PartialEq, Clone)]
413pub enum ContentBlock {
414 Empty,
415 Markdown { markdown: Entity<Markdown> },
416 ResourceLink { resource_link: acp::ResourceLink },
417}
418
419impl ContentBlock {
420 pub fn new(
421 block: acp::ContentBlock,
422 language_registry: &Arc<LanguageRegistry>,
423 cx: &mut App,
424 ) -> Self {
425 let mut this = Self::Empty;
426 this.append(block, language_registry, cx);
427 this
428 }
429
430 pub fn new_combined(
431 blocks: impl IntoIterator<Item = acp::ContentBlock>,
432 language_registry: Arc<LanguageRegistry>,
433 cx: &mut App,
434 ) -> Self {
435 let mut this = Self::Empty;
436 for block in blocks {
437 this.append(block, &language_registry, cx);
438 }
439 this
440 }
441
442 pub fn append(
443 &mut self,
444 block: acp::ContentBlock,
445 language_registry: &Arc<LanguageRegistry>,
446 cx: &mut App,
447 ) {
448 if matches!(self, ContentBlock::Empty)
449 && let acp::ContentBlock::ResourceLink(resource_link) = block
450 {
451 *self = ContentBlock::ResourceLink { resource_link };
452 return;
453 }
454
455 let new_content = self.block_string_contents(block);
456
457 match self {
458 ContentBlock::Empty => {
459 *self = Self::create_markdown_block(new_content, language_registry, cx);
460 }
461 ContentBlock::Markdown { markdown } => {
462 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
463 }
464 ContentBlock::ResourceLink { resource_link } => {
465 let existing_content = Self::resource_link_md(&resource_link.uri);
466 let combined = format!("{}\n{}", existing_content, new_content);
467
468 *self = Self::create_markdown_block(combined, language_registry, cx);
469 }
470 }
471 }
472
473 fn create_markdown_block(
474 content: String,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) -> ContentBlock {
478 ContentBlock::Markdown {
479 markdown: cx
480 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
481 }
482 }
483
484 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
485 match block {
486 acp::ContentBlock::Text(text_content) => text_content.text,
487 acp::ContentBlock::ResourceLink(resource_link) => {
488 Self::resource_link_md(&resource_link.uri)
489 }
490 acp::ContentBlock::Resource(acp::EmbeddedResource {
491 resource:
492 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
493 uri,
494 ..
495 }),
496 ..
497 }) => Self::resource_link_md(&uri),
498 acp::ContentBlock::Image(image) => Self::image_md(&image),
499 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
500 }
501 }
502
503 fn resource_link_md(uri: &str) -> String {
504 if let Some(uri) = MentionUri::parse(uri).log_err() {
505 uri.as_link().to_string()
506 } else {
507 uri.to_string()
508 }
509 }
510
511 fn image_md(_image: &acp::ImageContent) -> String {
512 "`Image`".into()
513 }
514
515 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
516 match self {
517 ContentBlock::Empty => "",
518 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
519 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
520 }
521 }
522
523 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
524 match self {
525 ContentBlock::Empty => None,
526 ContentBlock::Markdown { markdown } => Some(markdown),
527 ContentBlock::ResourceLink { .. } => None,
528 }
529 }
530
531 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
532 match self {
533 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
534 _ => None,
535 }
536 }
537}
538
539#[derive(Debug)]
540pub enum ToolCallContent {
541 ContentBlock(ContentBlock),
542 Diff(Entity<Diff>),
543 Terminal(Entity<Terminal>),
544}
545
546impl ToolCallContent {
547 pub fn from_acp(
548 content: acp::ToolCallContent,
549 language_registry: Arc<LanguageRegistry>,
550 cx: &mut App,
551 ) -> Self {
552 match content {
553 acp::ToolCallContent::Content { content } => {
554 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
555 }
556 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
557 Diff::finalized(
558 diff.path,
559 diff.old_text,
560 diff.new_text,
561 language_registry,
562 cx,
563 )
564 })),
565 }
566 }
567
568 pub fn update_from_acp(
569 &mut self,
570 new: acp::ToolCallContent,
571 language_registry: Arc<LanguageRegistry>,
572 cx: &mut App,
573 ) {
574 let needs_update = match (&self, &new) {
575 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
576 old_diff.read(cx).needs_update(
577 new_diff.old_text.as_deref().unwrap_or(""),
578 &new_diff.new_text,
579 cx,
580 )
581 }
582 _ => true,
583 };
584
585 if needs_update {
586 *self = Self::from_acp(new, language_registry, cx);
587 }
588 }
589
590 pub fn to_markdown(&self, cx: &App) -> String {
591 match self {
592 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
593 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
594 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
595 }
596 }
597}
598
599#[derive(Debug, PartialEq)]
600pub enum ToolCallUpdate {
601 UpdateFields(acp::ToolCallUpdate),
602 UpdateDiff(ToolCallUpdateDiff),
603 UpdateTerminal(ToolCallUpdateTerminal),
604}
605
606impl ToolCallUpdate {
607 fn id(&self) -> &acp::ToolCallId {
608 match self {
609 Self::UpdateFields(update) => &update.id,
610 Self::UpdateDiff(diff) => &diff.id,
611 Self::UpdateTerminal(terminal) => &terminal.id,
612 }
613 }
614}
615
616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
617 fn from(update: acp::ToolCallUpdate) -> Self {
618 Self::UpdateFields(update)
619 }
620}
621
622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
623 fn from(diff: ToolCallUpdateDiff) -> Self {
624 Self::UpdateDiff(diff)
625 }
626}
627
628#[derive(Debug, PartialEq)]
629pub struct ToolCallUpdateDiff {
630 pub id: acp::ToolCallId,
631 pub diff: Entity<Diff>,
632}
633
634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
635 fn from(terminal: ToolCallUpdateTerminal) -> Self {
636 Self::UpdateTerminal(terminal)
637 }
638}
639
640#[derive(Debug, PartialEq)]
641pub struct ToolCallUpdateTerminal {
642 pub id: acp::ToolCallId,
643 pub terminal: Entity<Terminal>,
644}
645
646#[derive(Debug, Default)]
647pub struct Plan {
648 pub entries: Vec<PlanEntry>,
649}
650
651#[derive(Debug)]
652pub struct PlanStats<'a> {
653 pub in_progress_entry: Option<&'a PlanEntry>,
654 pub pending: u32,
655 pub completed: u32,
656}
657
658impl Plan {
659 pub fn is_empty(&self) -> bool {
660 self.entries.is_empty()
661 }
662
663 pub fn stats(&self) -> PlanStats<'_> {
664 let mut stats = PlanStats {
665 in_progress_entry: None,
666 pending: 0,
667 completed: 0,
668 };
669
670 for entry in &self.entries {
671 match &entry.status {
672 acp::PlanEntryStatus::Pending => {
673 stats.pending += 1;
674 }
675 acp::PlanEntryStatus::InProgress => {
676 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
677 }
678 acp::PlanEntryStatus::Completed => {
679 stats.completed += 1;
680 }
681 }
682 }
683
684 stats
685 }
686}
687
688#[derive(Debug)]
689pub struct PlanEntry {
690 pub content: Entity<Markdown>,
691 pub priority: acp::PlanEntryPriority,
692 pub status: acp::PlanEntryStatus,
693}
694
695impl PlanEntry {
696 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
697 Self {
698 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
699 priority: entry.priority,
700 status: entry.status,
701 }
702 }
703}
704
705#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
706pub struct TokenUsage {
707 pub max_tokens: u64,
708 pub used_tokens: u64,
709}
710
711impl TokenUsage {
712 pub fn ratio(&self) -> TokenUsageRatio {
713 #[cfg(debug_assertions)]
714 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
715 .unwrap_or("0.8".to_string())
716 .parse()
717 .unwrap();
718 #[cfg(not(debug_assertions))]
719 let warning_threshold: f32 = 0.8;
720
721 // When the maximum is unknown because there is no selected model,
722 // avoid showing the token limit warning.
723 if self.max_tokens == 0 {
724 TokenUsageRatio::Normal
725 } else if self.used_tokens >= self.max_tokens {
726 TokenUsageRatio::Exceeded
727 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
728 TokenUsageRatio::Warning
729 } else {
730 TokenUsageRatio::Normal
731 }
732 }
733}
734
/// Classification of token usage relative to the maximum, as produced by
/// `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    Normal,
    /// Usage has reached the warning threshold but not the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
741
742#[derive(Debug, Clone)]
743pub struct RetryStatus {
744 pub last_error: SharedString,
745 pub attempt: usize,
746 pub max_attempts: usize,
747 pub started_at: Instant,
748 pub duration: Duration,
749}
750
751pub struct AcpThread {
752 title: SharedString,
753 entries: Vec<AgentThreadEntry>,
754 plan: Plan,
755 project: Entity<Project>,
756 action_log: Entity<ActionLog>,
757 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
758 send_task: Option<Task<()>>,
759 connection: Rc<dyn AgentConnection>,
760 session_id: acp::SessionId,
761 token_usage: Option<TokenUsage>,
762 prompt_capabilities: acp::PromptCapabilities,
763 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
764}
765
766#[derive(Debug)]
767pub enum AcpThreadEvent {
768 NewEntry,
769 TitleUpdated,
770 TokenUsageUpdated,
771 EntryUpdated(usize),
772 EntriesRemoved(Range<usize>),
773 ToolAuthorizationRequired,
774 Retry(RetryStatus),
775 Stopped,
776 Error,
777 LoadError(LoadError),
778 PromptCapabilitiesUpdated,
779}
780
781impl EventEmitter<AcpThreadEvent> for AcpThread {}
782
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call is waiting for confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is waiting for confirmation.
    Generating,
}
789
790#[derive(Debug, Clone)]
791pub enum LoadError {
792 NotInstalled {
793 error_message: SharedString,
794 install_message: SharedString,
795 install_command: String,
796 },
797 Unsupported {
798 error_message: SharedString,
799 upgrade_message: SharedString,
800 upgrade_command: String,
801 },
802 Exited {
803 status: ExitStatus,
804 },
805 Other(SharedString),
806}
807
808impl Display for LoadError {
809 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
810 match self {
811 LoadError::NotInstalled { error_message, .. }
812 | LoadError::Unsupported { error_message, .. } => {
813 write!(f, "{error_message}")
814 }
815 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
816 LoadError::Other(msg) => write!(f, "{}", msg),
817 }
818 }
819}
820
821impl Error for LoadError {}
822
823impl AcpThread {
824 pub fn new(
825 title: impl Into<SharedString>,
826 connection: Rc<dyn AgentConnection>,
827 project: Entity<Project>,
828 action_log: Entity<ActionLog>,
829 session_id: acp::SessionId,
830 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
831 cx: &mut Context<Self>,
832 ) -> Self {
833 let prompt_capabilities = *prompt_capabilities_rx.borrow();
834 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
835 loop {
836 let caps = prompt_capabilities_rx.recv().await?;
837 this.update(cx, |this, cx| {
838 this.prompt_capabilities = caps;
839 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
840 })?;
841 }
842 });
843
844 Self {
845 action_log,
846 shared_buffers: Default::default(),
847 entries: Default::default(),
848 plan: Default::default(),
849 title: title.into(),
850 project,
851 send_task: None,
852 connection,
853 session_id,
854 token_usage: None,
855 prompt_capabilities,
856 _observe_prompt_capabilities: task,
857 }
858 }
859
860 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
861 self.prompt_capabilities
862 }
863
864 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
865 &self.connection
866 }
867
868 pub fn action_log(&self) -> &Entity<ActionLog> {
869 &self.action_log
870 }
871
872 pub fn project(&self) -> &Entity<Project> {
873 &self.project
874 }
875
876 pub fn title(&self) -> SharedString {
877 self.title.clone()
878 }
879
880 pub fn entries(&self) -> &[AgentThreadEntry] {
881 &self.entries
882 }
883
884 pub fn session_id(&self) -> &acp::SessionId {
885 &self.session_id
886 }
887
888 pub fn status(&self) -> ThreadStatus {
889 if self.send_task.is_some() {
890 if self.waiting_for_tool_confirmation() {
891 ThreadStatus::WaitingForToolConfirmation
892 } else {
893 ThreadStatus::Generating
894 }
895 } else {
896 ThreadStatus::Idle
897 }
898 }
899
900 pub fn token_usage(&self) -> Option<&TokenUsage> {
901 self.token_usage.as_ref()
902 }
903
904 pub fn has_pending_edit_tool_calls(&self) -> bool {
905 for entry in self.entries.iter().rev() {
906 match entry {
907 AgentThreadEntry::UserMessage(_) => return false,
908 AgentThreadEntry::ToolCall(
909 call @ ToolCall {
910 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
911 ..
912 },
913 ) if call.diffs().next().is_some() => {
914 return true;
915 }
916 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
917 }
918 }
919
920 false
921 }
922
923 pub fn used_tools_since_last_user_message(&self) -> bool {
924 for entry in self.entries.iter().rev() {
925 match entry {
926 AgentThreadEntry::UserMessage(..) => return false,
927 AgentThreadEntry::AssistantMessage(..) => continue,
928 AgentThreadEntry::ToolCall(..) => return true,
929 }
930 }
931
932 false
933 }
934
935 pub fn handle_session_update(
936 &mut self,
937 update: acp::SessionUpdate,
938 cx: &mut Context<Self>,
939 ) -> Result<(), acp::Error> {
940 match update {
941 acp::SessionUpdate::UserMessageChunk { content } => {
942 self.push_user_content_block(None, content, cx);
943 }
944 acp::SessionUpdate::AgentMessageChunk { content } => {
945 self.push_assistant_content_block(content, false, cx);
946 }
947 acp::SessionUpdate::AgentThoughtChunk { content } => {
948 self.push_assistant_content_block(content, true, cx);
949 }
950 acp::SessionUpdate::ToolCall(tool_call) => {
951 self.upsert_tool_call(tool_call, cx)?;
952 }
953 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
954 self.update_tool_call(tool_call_update, cx)?;
955 }
956 acp::SessionUpdate::Plan(plan) => {
957 self.update_plan(plan, cx);
958 }
959 }
960 Ok(())
961 }
962
963 pub fn push_user_content_block(
964 &mut self,
965 message_id: Option<UserMessageId>,
966 chunk: acp::ContentBlock,
967 cx: &mut Context<Self>,
968 ) {
969 let language_registry = self.project.read(cx).languages().clone();
970 let entries_len = self.entries.len();
971
972 if let Some(last_entry) = self.entries.last_mut()
973 && let AgentThreadEntry::UserMessage(UserMessage {
974 id,
975 content,
976 chunks,
977 ..
978 }) = last_entry
979 {
980 *id = message_id.or(id.take());
981 content.append(chunk.clone(), &language_registry, cx);
982 chunks.push(chunk);
983 let idx = entries_len - 1;
984 cx.emit(AcpThreadEvent::EntryUpdated(idx));
985 } else {
986 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
987 self.push_entry(
988 AgentThreadEntry::UserMessage(UserMessage {
989 id: message_id,
990 content,
991 chunks: vec![chunk],
992 checkpoint: None,
993 }),
994 cx,
995 );
996 }
997 }
998
999 pub fn push_assistant_content_block(
1000 &mut self,
1001 chunk: acp::ContentBlock,
1002 is_thought: bool,
1003 cx: &mut Context<Self>,
1004 ) {
1005 let language_registry = self.project.read(cx).languages().clone();
1006 let entries_len = self.entries.len();
1007 if let Some(last_entry) = self.entries.last_mut()
1008 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1009 {
1010 let idx = entries_len - 1;
1011 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1012 match (chunks.last_mut(), is_thought) {
1013 (Some(AssistantMessageChunk::Message { block }), false)
1014 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1015 block.append(chunk, &language_registry, cx)
1016 }
1017 _ => {
1018 let block = ContentBlock::new(chunk, &language_registry, cx);
1019 if is_thought {
1020 chunks.push(AssistantMessageChunk::Thought { block })
1021 } else {
1022 chunks.push(AssistantMessageChunk::Message { block })
1023 }
1024 }
1025 }
1026 } else {
1027 let block = ContentBlock::new(chunk, &language_registry, cx);
1028 let chunk = if is_thought {
1029 AssistantMessageChunk::Thought { block }
1030 } else {
1031 AssistantMessageChunk::Message { block }
1032 };
1033
1034 self.push_entry(
1035 AgentThreadEntry::AssistantMessage(AssistantMessage {
1036 chunks: vec![chunk],
1037 }),
1038 cx,
1039 );
1040 }
1041 }
1042
1043 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1044 self.entries.push(entry);
1045 cx.emit(AcpThreadEvent::NewEntry);
1046 }
1047
1048 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1049 self.connection.set_title(&self.session_id, cx).is_some()
1050 }
1051
1052 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1053 if title != self.title {
1054 self.title = title.clone();
1055 cx.emit(AcpThreadEvent::TitleUpdated);
1056 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1057 return set_title.run(title, cx);
1058 }
1059 }
1060 Task::ready(Ok(()))
1061 }
1062
1063 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1064 self.token_usage = usage;
1065 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1066 }
1067
1068 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1069 cx.emit(AcpThreadEvent::Retry(status));
1070 }
1071
1072 pub fn update_tool_call(
1073 &mut self,
1074 update: impl Into<ToolCallUpdate>,
1075 cx: &mut Context<Self>,
1076 ) -> Result<()> {
1077 let update = update.into();
1078 let languages = self.project.read(cx).languages().clone();
1079
1080 let (ix, current_call) = self
1081 .tool_call_mut(update.id())
1082 .context("Tool call not found")?;
1083 match update {
1084 ToolCallUpdate::UpdateFields(update) => {
1085 let location_updated = update.fields.locations.is_some();
1086 current_call.update_fields(update.fields, languages, cx);
1087 if location_updated {
1088 self.resolve_locations(update.id, cx);
1089 }
1090 }
1091 ToolCallUpdate::UpdateDiff(update) => {
1092 current_call.content.clear();
1093 current_call
1094 .content
1095 .push(ToolCallContent::Diff(update.diff));
1096 }
1097 ToolCallUpdate::UpdateTerminal(update) => {
1098 current_call.content.clear();
1099 current_call
1100 .content
1101 .push(ToolCallContent::Terminal(update.terminal));
1102 }
1103 }
1104
1105 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1106
1107 Ok(())
1108 }
1109
1110 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1111 pub fn upsert_tool_call(
1112 &mut self,
1113 tool_call: acp::ToolCall,
1114 cx: &mut Context<Self>,
1115 ) -> Result<(), acp::Error> {
1116 let status = tool_call.status.into();
1117 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1118 }
1119
1120 /// Fails if id does not match an existing entry.
1121 pub fn upsert_tool_call_inner(
1122 &mut self,
1123 tool_call_update: acp::ToolCallUpdate,
1124 status: ToolCallStatus,
1125 cx: &mut Context<Self>,
1126 ) -> Result<(), acp::Error> {
1127 let language_registry = self.project.read(cx).languages().clone();
1128 let id = tool_call_update.id.clone();
1129
1130 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1131 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1132 current_call.status = status;
1133
1134 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1135 } else {
1136 let call =
1137 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1138 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1139 };
1140
1141 self.resolve_locations(id, cx);
1142 Ok(())
1143 }
1144
1145 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1146 // The tool call we are looking for is typically the last one, or very close to the end.
1147 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1148 self.entries
1149 .iter_mut()
1150 .enumerate()
1151 .rev()
1152 .find_map(|(index, tool_call)| {
1153 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1154 && &tool_call.id == id
1155 {
1156 Some((index, tool_call))
1157 } else {
1158 None
1159 }
1160 })
1161 }
1162
1163 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1164 self.entries
1165 .iter()
1166 .enumerate()
1167 .rev()
1168 .find_map(|(index, tool_call)| {
1169 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1170 && &tool_call.id == id
1171 {
1172 Some((index, tool_call))
1173 } else {
1174 None
1175 }
1176 })
1177 }
1178
1179 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1180 let project = self.project.clone();
1181 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1182 return;
1183 };
1184 let task = tool_call.resolve_locations(project, cx);
1185 cx.spawn(async move |this, cx| {
1186 let resolved_locations = task.await;
1187 this.update(cx, |this, cx| {
1188 let project = this.project.clone();
1189 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1190 return;
1191 };
1192 if let Some(Some(location)) = resolved_locations.last() {
1193 project.update(cx, |project, cx| {
1194 if let Some(agent_location) = project.agent_location() {
1195 let should_ignore = agent_location.buffer == location.buffer
1196 && location
1197 .buffer
1198 .update(cx, |buffer, _| {
1199 let snapshot = buffer.snapshot();
1200 let old_position =
1201 agent_location.position.to_point(&snapshot);
1202 let new_position = location.position.to_point(&snapshot);
1203 // ignore this so that when we get updates from the edit tool
1204 // the position doesn't reset to the startof line
1205 old_position.row == new_position.row
1206 && old_position.column > new_position.column
1207 })
1208 .ok()
1209 .unwrap_or_default();
1210 if !should_ignore {
1211 project.set_agent_location(Some(location.clone()), cx);
1212 }
1213 }
1214 });
1215 }
1216 if tool_call.resolved_locations != resolved_locations {
1217 tool_call.resolved_locations = resolved_locations;
1218 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1219 }
1220 })
1221 })
1222 .detach();
1223 }
1224
1225 pub fn request_tool_call_authorization(
1226 &mut self,
1227 tool_call: acp::ToolCallUpdate,
1228 options: Vec<acp::PermissionOption>,
1229 cx: &mut Context<Self>,
1230 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1231 let (tx, rx) = oneshot::channel();
1232
1233 let status = ToolCallStatus::WaitingForConfirmation {
1234 options,
1235 respond_tx: tx,
1236 };
1237
1238 self.upsert_tool_call_inner(tool_call, status, cx)?;
1239 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1240 Ok(rx)
1241 }
1242
1243 pub fn authorize_tool_call(
1244 &mut self,
1245 id: acp::ToolCallId,
1246 option_id: acp::PermissionOptionId,
1247 option_kind: acp::PermissionOptionKind,
1248 cx: &mut Context<Self>,
1249 ) {
1250 let Some((ix, call)) = self.tool_call_mut(&id) else {
1251 return;
1252 };
1253
1254 let new_status = match option_kind {
1255 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1256 ToolCallStatus::Rejected
1257 }
1258 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1259 ToolCallStatus::InProgress
1260 }
1261 };
1262
1263 let curr_status = mem::replace(&mut call.status, new_status);
1264
1265 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1266 respond_tx.send(option_id).log_err();
1267 } else if cfg!(debug_assertions) {
1268 panic!("tried to authorize an already authorized tool call");
1269 }
1270
1271 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1272 }
1273
1274 /// Returns true if the last turn is awaiting tool authorization
1275 pub fn waiting_for_tool_confirmation(&self) -> bool {
1276 for entry in self.entries.iter().rev() {
1277 match &entry {
1278 AgentThreadEntry::ToolCall(call) => match call.status {
1279 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1280 ToolCallStatus::Pending
1281 | ToolCallStatus::InProgress
1282 | ToolCallStatus::Completed
1283 | ToolCallStatus::Failed
1284 | ToolCallStatus::Rejected
1285 | ToolCallStatus::Canceled => continue,
1286 },
1287 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1288 // Reached the beginning of the turn
1289 return false;
1290 }
1291 }
1292 }
1293 false
1294 }
1295
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1299
1300 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1301 let new_entries_len = request.entries.len();
1302 let mut new_entries = request.entries.into_iter();
1303
1304 // Reuse existing markdown to prevent flickering
1305 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1306 let PlanEntry {
1307 content,
1308 priority,
1309 status,
1310 } = old;
1311 content.update(cx, |old, cx| {
1312 old.replace(new.content, cx);
1313 });
1314 *priority = new.priority;
1315 *status = new.status;
1316 }
1317 for new in new_entries {
1318 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1319 }
1320 self.plan.entries.truncate(new_entries_len);
1321
1322 cx.notify();
1323 }
1324
1325 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1326 self.plan
1327 .entries
1328 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1329 cx.notify();
1330 }
1331
1332 #[cfg(any(test, feature = "test-support"))]
1333 pub fn send_raw(
1334 &mut self,
1335 message: &str,
1336 cx: &mut Context<Self>,
1337 ) -> BoxFuture<'static, Result<()>> {
1338 self.send(
1339 vec![acp::ContentBlock::Text(acp::TextContent {
1340 text: message.to_string(),
1341 annotations: None,
1342 })],
1343 cx,
1344 )
1345 }
1346
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Inside the turn started through `run_turn`, the message is first pushed
    /// onto the thread as a `UserMessage` entry, then a git checkpoint is
    /// captured and attached to the last user message (hidden, `show: false`),
    /// and finally the prompt is forwarded to the connection.
    ///
    /// Returns the future produced by `run_turn`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // A message id is only generated when the connection offers `truncate`
        // for this session; `rewind` looks user messages up by this id.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Show the user message right away, before the checkpoint is taken.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint ends up as `None` here, so the prompt is still sent.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1400
    /// Returns true if the connection provides a `resume` handle for this
    /// thread's session.
    pub fn can_resume(&self, cx: &App) -> bool {
        self.connection.resume(&self.session_id, cx).is_some()
    }
1404
1405 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1406 self.run_turn(cx, async move |this, cx| {
1407 this.update(cx, |this, cx| {
1408 this.connection
1409 .resume(&this.session_id, cx)
1410 .map(|resume| resume.run(cx))
1411 })?
1412 .context("resuming a session is not supported")?
1413 .await
1414 })
1415 }
1416
    /// Runs one agent turn described by `f` and returns a future that resolves
    /// once the turn has been fully processed.
    ///
    /// Before the turn starts, completed plan entries are cleared and any turn
    /// already in flight is cancelled. `f` itself runs inside `send_task`; its
    /// result is handed over a oneshot channel to a second task that refreshes
    /// the last checkpoint, clears the agent location and emits the final
    /// `Error` or `Stopped` event.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The new turn only starts after the previous one has finished cancelling.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // The outer `Result` is `Err` when `tx` was dropped without sending,
            // the inner one is the result of `f`.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        // This removes the last user message and everything after it.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1483
1484 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1485 let Some(send_task) = self.send_task.take() else {
1486 return Task::ready(());
1487 };
1488
1489 for entry in self.entries.iter_mut() {
1490 if let AgentThreadEntry::ToolCall(call) = entry {
1491 let cancel = matches!(
1492 call.status,
1493 ToolCallStatus::Pending
1494 | ToolCallStatus::WaitingForConfirmation { .. }
1495 | ToolCallStatus::InProgress
1496 );
1497
1498 if cancel {
1499 call.status = ToolCallStatus::Canceled;
1500 }
1501 }
1502 }
1503
1504 self.connection.cancel(&self.session_id, cx);
1505
1506 // Wait for the send task to complete
1507 cx.foreground_executor().spawn(send_task)
1508 }
1509
    /// Rewinds this thread to before the user message with the given `id`,
    /// removing that message and all subsequent entries while reverting any
    /// changes made from that point.
    ///
    /// If the message carries a git checkpoint it is restored first; then the
    /// connection is asked to truncate the session at `id`, and finally the
    /// entries are removed and `AcpThreadEvent::EntriesRemoved` is emitted.
    ///
    /// # Errors
    ///
    /// Fails with "not supported" when the connection offers no `truncate`
    /// handle for this session, with "message not found" when no user message
    /// has the given `id`, and with any error from restoring the checkpoint or
    /// truncating the session.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            this.update(cx, |this, cx| {
                // Look the message up again: entries may have changed while awaiting.
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1543
    /// Decides whether the checkpoint attached to the last user message should
    /// be shown.
    ///
    /// A fresh checkpoint is taken and compared with the one stored on the last
    /// user message; the stored checkpoint's `show` flag is set to true exactly
    /// when the two differ. Returns immediately when there is no user message
    /// or it has no checkpoint.
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
            if let Some(checkpoint) = message.checkpoint.as_ref() {
                checkpoint.git_checkpoint.clone()
            } else {
                return Task::ready(Ok(()));
            }
        } else {
            return Task::ready(Ok(()));
        };

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            // If no new checkpoint could be taken, the `show` flag is left untouched.
            let new_checkpoint = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err();
            if let Some(new_checkpoint) = new_checkpoint {
                // A failed comparison is treated as "no changes".
                let equal = git_store
                    .update(cx, |git, cx| {
                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                    })?
                    .await
                    .unwrap_or(true);
                this.update(cx, |this, cx| {
                    let (ix, message) = this.last_user_message().context("no user message")?;
                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    anyhow::Ok(())
                })??;
            }

            Ok(())
        })
    }
1582
1583 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1584 self.entries
1585 .iter_mut()
1586 .enumerate()
1587 .rev()
1588 .find_map(|(ix, entry)| {
1589 if let AgentThreadEntry::UserMessage(message) = entry {
1590 Some((ix, message))
1591 } else {
1592 None
1593 }
1594 })
1595 }
1596
1597 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1598 self.entries.iter().find_map(|entry| {
1599 if let AgentThreadEntry::UserMessage(message) = entry {
1600 if message.id.as_ref() == Some(id) {
1601 Some(message)
1602 } else {
1603 None
1604 }
1605 } else {
1606 None
1607 }
1608 })
1609 }
1610
1611 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1612 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1613 if let AgentThreadEntry::UserMessage(message) = entry {
1614 if message.id.as_ref() == Some(id) {
1615 Some((ix, message))
1616 } else {
1617 None
1618 }
1619 } else {
1620 None
1621 }
1622 })
1623 }
1624
1625 pub fn read_text_file(
1626 &self,
1627 path: PathBuf,
1628 line: Option<u32>,
1629 limit: Option<u32>,
1630 reuse_shared_snapshot: bool,
1631 cx: &mut Context<Self>,
1632 ) -> Task<Result<String>> {
1633 let project = self.project.clone();
1634 let action_log = self.action_log.clone();
1635 cx.spawn(async move |this, cx| {
1636 let load = project.update(cx, |project, cx| {
1637 let path = project
1638 .project_path_for_absolute_path(&path, cx)
1639 .context("invalid path")?;
1640 anyhow::Ok(project.open_buffer(path, cx))
1641 });
1642 let buffer = load??.await?;
1643
1644 let snapshot = if reuse_shared_snapshot {
1645 this.read_with(cx, |this, _| {
1646 this.shared_buffers.get(&buffer.clone()).cloned()
1647 })
1648 .log_err()
1649 .flatten()
1650 } else {
1651 None
1652 };
1653
1654 let snapshot = if let Some(snapshot) = snapshot {
1655 snapshot
1656 } else {
1657 action_log.update(cx, |action_log, cx| {
1658 action_log.buffer_read(buffer.clone(), cx);
1659 })?;
1660 project.update(cx, |project, cx| {
1661 let position = buffer
1662 .read(cx)
1663 .snapshot()
1664 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1665 project.set_agent_location(
1666 Some(AgentLocation {
1667 buffer: buffer.downgrade(),
1668 position,
1669 }),
1670 cx,
1671 );
1672 })?;
1673
1674 buffer.update(cx, |buffer, _| buffer.snapshot())?
1675 };
1676
1677 this.update(cx, |this, _| {
1678 let text = snapshot.text();
1679 this.shared_buffers.insert(buffer.clone(), snapshot);
1680 if line.is_none() && limit.is_none() {
1681 return Ok(text);
1682 }
1683 let limit = limit.unwrap_or(u32::MAX) as usize;
1684 let Some(line) = line else {
1685 return Ok(text.lines().take(limit).collect::<String>());
1686 };
1687
1688 let count = text.lines().count();
1689 if count < line as usize {
1690 anyhow::bail!("There are only {} lines", count);
1691 }
1692 Ok(text
1693 .lines()
1694 .skip(line as usize + 1)
1695 .take(limit)
1696 .collect::<String>())
1697 })?
1698 })
1699 }
1700
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// The new content is diffed against the snapshot stored in
    /// `shared_buffers` for this buffer (or the buffer's current snapshot if
    /// none is stored) and the resulting edits are applied to the live buffer.
    /// The edit is recorded in the action log, the agent location is moved to
    /// the end of the last edit, and, if format-on-save is enabled for the
    /// buffer, it is formatted before being saved.
    ///
    /// # Errors
    ///
    /// Fails if `path` does not belong to the project ("invalid path"), if the
    /// buffer cannot be opened, or if saving fails. Formatting errors are only
    /// logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the stored snapshot when there is one, otherwise
            // against the buffer as it is now.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // The diff is computed on the background executor; the edit ranges
            // are expressed as anchors in `snapshot`.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at
            // `Anchor::MIN` when the diff produced no edits.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Report the buffer as edited again after the formatting pass.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1797
    /// Renders the whole thread as markdown by concatenating the markdown of
    /// every entry, in order.
    pub fn to_markdown(&self, cx: &App) -> String {
        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
    }
1801
    /// Emits `AcpThreadEvent::LoadError` carrying `error`.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
1805}
1806
1807fn markdown_for_raw_output(
1808 raw_output: &serde_json::Value,
1809 language_registry: &Arc<LanguageRegistry>,
1810 cx: &mut App,
1811) -> Option<Entity<Markdown>> {
1812 match raw_output {
1813 serde_json::Value::Null => None,
1814 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1815 Markdown::new(
1816 value.to_string().into(),
1817 Some(language_registry.clone()),
1818 None,
1819 cx,
1820 )
1821 })),
1822 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1823 Markdown::new(
1824 value.to_string().into(),
1825 Some(language_registry.clone()),
1826 None,
1827 cx,
1828 )
1829 })),
1830 serde_json::Value::String(value) => Some(cx.new(|cx| {
1831 Markdown::new(
1832 value.clone().into(),
1833 Some(language_registry.clone()),
1834 None,
1835 cx,
1836 )
1837 })),
1838 value => Some(cx.new(|cx| {
1839 Markdown::new(
1840 format!("```json\n{}\n```", value).into(),
1841 Some(language_registry.clone()),
1842 None,
1843 cx,
1844 )
1845 })),
1846 }
1847}
1848
1849#[cfg(test)]
1850mod tests {
1851 use super::*;
1852 use anyhow::anyhow;
1853 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1854 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1855 use indoc::indoc;
1856 use project::{FakeFs, Fs};
1857 use rand::Rng as _;
1858 use serde_json::json;
1859 use settings::SettingsStore;
1860 use smol::stream::StreamExt as _;
1861 use std::{
1862 any::Any,
1863 cell::RefCell,
1864 path::Path,
1865 rc::Rc,
1866 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1867 time::Duration,
1868 };
1869 use util::path;
1870
    /// Shared test setup: initializes logging (ignoring the error if it was
    /// already initialized), installs a test `SettingsStore` as a global, and
    /// initializes project settings and the `language` crate.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1880
    /// Checks how `push_user_content_block` groups chunks: consecutive user
    /// chunks are merged into one message (which takes on the id of the later
    /// chunk), while a user chunk that follows an assistant message starts a
    /// new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the text was appended and the id was adopted.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1972
    /// Checks that two consecutive `AgentThoughtChunk` updates are rendered as
    /// a single `<thinking>` block in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2035
    /// Checks that a file write by the agent is merged with an edit the user
    /// made to the same buffer between the agent's read and its write: both
    /// the user's inserted line and the agent's appended lines end up in the
    /// buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then edit the buffer as the user.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2114
    /// Checks that a tool call marked `Canceled` by `cancel` can still move to
    /// `Completed` when the agent later reports that status for it.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message, entry 1 the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late status update from the agent for the canceled tool call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2214
    /// Checks that `has_pending_edit_tool_calls` is false after a turn whose
    /// only tool call is an edit that arrived already `Completed`.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports a completed edit tool call carrying a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2268
    /// Exercises checkpoints end to end: a user message is rendered with the
    /// "(checkpoint)" marker only when files changed during its turn, and
    /// `rewind` both truncates the thread and restores the files on disk.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, the fake agent creates a new file on
        // every prompt; it always echoes the prompt back in upper case.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn: a file is created, so the checkpoint is shown.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn: another file is created, so this checkpoint is shown too.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2448
2449 #[gpui::test]
2450 async fn test_refusal(cx: &mut TestAppContext) {
2451 init_test(cx);
2452 let fs = FakeFs::new(cx.background_executor.clone());
2453 fs.insert_tree(path!("/"), json!({})).await;
2454 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2455
2456 let refuse_next = Arc::new(AtomicBool::new(false));
2457 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2458 let refuse_next = refuse_next.clone();
2459 move |request, thread, mut cx| {
2460 let refuse_next = refuse_next.clone();
2461 async move {
2462 if refuse_next.load(SeqCst) {
2463 return Ok(acp::PromptResponse {
2464 stop_reason: acp::StopReason::Refusal,
2465 });
2466 }
2467
2468 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2469 panic!("expected text content block");
2470 };
2471 thread.update(&mut cx, |thread, cx| {
2472 thread
2473 .handle_session_update(
2474 acp::SessionUpdate::AgentMessageChunk {
2475 content: content.text.to_uppercase().into(),
2476 },
2477 cx,
2478 )
2479 .unwrap();
2480 })?;
2481 Ok(acp::PromptResponse {
2482 stop_reason: acp::StopReason::EndTurn,
2483 })
2484 }
2485 .boxed_local()
2486 }
2487 }));
2488 let thread = cx
2489 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2490 .await
2491 .unwrap();
2492
2493 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2494 .await
2495 .unwrap();
2496 thread.read_with(cx, |thread, cx| {
2497 assert_eq!(
2498 thread.to_markdown(cx),
2499 indoc! {"
2500 ## User
2501
2502 hello
2503
2504 ## Assistant
2505
2506 HELLO
2507
2508 "}
2509 );
2510 });
2511
2512 // Simulate refusing the second message, ensuring the conversation gets
2513 // truncated to before sending it.
2514 refuse_next.store(true, SeqCst);
2515 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2516 .await
2517 .unwrap();
2518 thread.read_with(cx, |thread, cx| {
2519 assert_eq!(
2520 thread.to_markdown(cx),
2521 indoc! {"
2522 ## User
2523
2524 hello
2525
2526 ## Assistant
2527
2528 HELLO
2529
2530 "}
2531 );
2532 });
2533 }
2534
2535 async fn run_until_first_tool_call(
2536 thread: &Entity<AcpThread>,
2537 cx: &mut TestAppContext,
2538 ) -> usize {
2539 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2540
2541 let subscription = cx.update(|cx| {
2542 cx.subscribe(thread, move |thread, _, cx| {
2543 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2544 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2545 return tx.try_send(ix).unwrap();
2546 }
2547 }
2548 })
2549 });
2550
2551 select! {
2552 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2553 panic!("Timeout waiting for tool call")
2554 }
2555 ix = rx.next().fuse() => {
2556 drop(subscription);
2557 ix.unwrap()
2558 }
2559 }
2560 }
2561
2562 #[derive(Clone, Default)]
2563 struct FakeAgentConnection {
2564 auth_methods: Vec<acp::AuthMethod>,
2565 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2566 on_user_message: Option<
2567 Rc<
2568 dyn Fn(
2569 acp::PromptRequest,
2570 WeakEntity<AcpThread>,
2571 AsyncApp,
2572 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2573 + 'static,
2574 >,
2575 >,
2576 }
2577
2578 impl FakeAgentConnection {
2579 fn new() -> Self {
2580 Self {
2581 auth_methods: Vec::new(),
2582 on_user_message: None,
2583 sessions: Arc::default(),
2584 }
2585 }
2586
2587 #[expect(unused)]
2588 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2589 self.auth_methods = auth_methods;
2590 self
2591 }
2592
2593 fn on_user_message(
2594 mut self,
2595 handler: impl Fn(
2596 acp::PromptRequest,
2597 WeakEntity<AcpThread>,
2598 AsyncApp,
2599 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2600 + 'static,
2601 ) -> Self {
2602 self.on_user_message.replace(Rc::new(handler));
2603 self
2604 }
2605 }
2606
2607 impl AgentConnection for FakeAgentConnection {
2608 fn auth_methods(&self) -> &[acp::AuthMethod] {
2609 &self.auth_methods
2610 }
2611
2612 fn new_thread(
2613 self: Rc<Self>,
2614 project: Entity<Project>,
2615 _cwd: &Path,
2616 cx: &mut App,
2617 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2618 let session_id = acp::SessionId(
2619 rand::thread_rng()
2620 .sample_iter(&rand::distributions::Alphanumeric)
2621 .take(7)
2622 .map(char::from)
2623 .collect::<String>()
2624 .into(),
2625 );
2626 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2627 let thread = cx.new(|cx| {
2628 AcpThread::new(
2629 "Test",
2630 self.clone(),
2631 project,
2632 action_log,
2633 session_id.clone(),
2634 watch::Receiver::constant(acp::PromptCapabilities {
2635 image: true,
2636 audio: true,
2637 embedded_context: true,
2638 }),
2639 cx,
2640 )
2641 });
2642 self.sessions.lock().insert(session_id, thread.downgrade());
2643 Task::ready(Ok(thread))
2644 }
2645
2646 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2647 if self.auth_methods().iter().any(|m| m.id == method) {
2648 Task::ready(Ok(()))
2649 } else {
2650 Task::ready(Err(anyhow!("Invalid Auth Method")))
2651 }
2652 }
2653
2654 fn prompt(
2655 &self,
2656 _id: Option<UserMessageId>,
2657 params: acp::PromptRequest,
2658 cx: &mut App,
2659 ) -> Task<gpui::Result<acp::PromptResponse>> {
2660 let sessions = self.sessions.lock();
2661 let thread = sessions.get(¶ms.session_id).unwrap();
2662 if let Some(handler) = &self.on_user_message {
2663 let handler = handler.clone();
2664 let thread = thread.clone();
2665 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2666 } else {
2667 Task::ready(Ok(acp::PromptResponse {
2668 stop_reason: acp::StopReason::EndTurn,
2669 }))
2670 }
2671 }
2672
2673 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2674 let sessions = self.sessions.lock();
2675 let thread = sessions.get(session_id).unwrap().clone();
2676
2677 cx.spawn(async move |cx| {
2678 thread
2679 .update(cx, |thread, cx| thread.cancel(cx))
2680 .unwrap()
2681 .await
2682 })
2683 .detach();
2684 }
2685
2686 fn truncate(
2687 &self,
2688 session_id: &acp::SessionId,
2689 _cx: &App,
2690 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2691 Some(Rc::new(FakeAgentSessionEditor {
2692 _session_id: session_id.clone(),
2693 }))
2694 }
2695
2696 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2697 self
2698 }
2699 }
2700
/// Truncation handle returned by `FakeAgentConnection::truncate`.
struct FakeAgentSessionEditor {
    /// Session this handle was created for; stored but never read.
    _session_id: acp::SessionId,
}
2704
2705 impl AgentSessionTruncate for FakeAgentSessionEditor {
2706 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2707 Task::ready(Ok(()))
2708 }
2709 }
2710}