1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35use uuid::Uuid;
36
37pub fn new_prompt_id() -> acp::PromptId {
38 acp::PromptId(Uuid::new_v4().to_string().into())
39}
40
41#[derive(Debug)]
42pub struct UserMessage {
43 pub prompt_id: Option<acp::PromptId>,
44 pub content: ContentBlock,
45 pub chunks: Vec<acp::ContentBlock>,
46 pub checkpoint: Option<Checkpoint>,
47}
48
49#[derive(Debug)]
50pub struct Checkpoint {
51 git_checkpoint: GitStoreCheckpoint,
52 pub show: bool,
53}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
74#[derive(Debug, PartialEq)]
75pub struct AssistantMessage {
76 pub chunks: Vec<AssistantMessageChunk>,
77}
78
79impl AssistantMessage {
80 pub fn to_markdown(&self, cx: &App) -> String {
81 format!(
82 "## Assistant\n\n{}\n\n",
83 self.chunks
84 .iter()
85 .map(|chunk| chunk.to_markdown(cx))
86 .join("\n\n")
87 )
88 }
89}
90
91#[derive(Debug, PartialEq)]
92pub enum AssistantMessageChunk {
93 Message { block: ContentBlock },
94 Thought { block: ContentBlock },
95}
96
97impl AssistantMessageChunk {
98 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Message {
100 block: ContentBlock::new(chunk.into(), language_registry, cx),
101 }
102 }
103
104 fn to_markdown(&self, cx: &App) -> String {
105 match self {
106 Self::Message { block } => block.to_markdown(cx).to_string(),
107 Self::Thought { block } => {
108 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
109 }
110 }
111 }
112}
113
114#[derive(Debug)]
115pub enum AgentThreadEntry {
116 UserMessage(UserMessage),
117 AssistantMessage(AssistantMessage),
118 ToolCall(ToolCall),
119}
120
121impl AgentThreadEntry {
122 pub fn to_markdown(&self, cx: &App) -> String {
123 match self {
124 Self::UserMessage(message) => message.to_markdown(cx),
125 Self::AssistantMessage(message) => message.to_markdown(cx),
126 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
127 }
128 }
129
130 pub fn user_message(&self) -> Option<&UserMessage> {
131 if let AgentThreadEntry::UserMessage(message) = self {
132 Some(message)
133 } else {
134 None
135 }
136 }
137
138 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.diffs())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
147 if let AgentThreadEntry::ToolCall(call) = self {
148 itertools::Either::Left(call.terminals())
149 } else {
150 itertools::Either::Right(std::iter::empty())
151 }
152 }
153
154 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
155 if let AgentThreadEntry::ToolCall(ToolCall {
156 locations,
157 resolved_locations,
158 ..
159 }) = self
160 {
161 Some((
162 locations.get(ix)?.clone(),
163 resolved_locations.get(ix)?.clone()?,
164 ))
165 } else {
166 None
167 }
168 }
169}
170
/// A tool invocation made by the agent, as shown in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier; entries are looked up by comparing against this.
    pub id: acp::ToolCallId,
    /// Rendered title; only the first line of the ACP title is kept, with an
    /// ellipsis appended when the title had more lines.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Renderable content (text blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions resolved from `locations`, index-aligned with it;
    /// `None` where a location could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input as sent by the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output as sent by the agent.
    pub raw_output: Option<serde_json::Value>,
}
183
184impl ToolCall {
185 fn from_acp(
186 tool_call: acp::ToolCall,
187 status: ToolCallStatus,
188 language_registry: Arc<LanguageRegistry>,
189 cx: &mut App,
190 ) -> Self {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 Self {
197 id: tool_call.id,
198 label: cx
199 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
200 kind: tool_call.kind,
201 content: tool_call
202 .content
203 .into_iter()
204 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
205 .collect(),
206 locations: tool_call.locations,
207 resolved_locations: Vec::default(),
208 status,
209 raw_input: tool_call.raw_input,
210 raw_output: tool_call.raw_output,
211 }
212 }
213
214 fn update_fields(
215 &mut self,
216 fields: acp::ToolCallUpdateFields,
217 language_registry: Arc<LanguageRegistry>,
218 cx: &mut App,
219 ) {
220 let acp::ToolCallUpdateFields {
221 kind,
222 status,
223 title,
224 content,
225 locations,
226 raw_input,
227 raw_output,
228 } = fields;
229
230 if let Some(kind) = kind {
231 self.kind = kind;
232 }
233
234 if let Some(status) = status {
235 self.status = status.into();
236 }
237
238 if let Some(title) = title {
239 self.label.update(cx, |label, cx| {
240 if let Some((first_line, _)) = title.split_once("\n") {
241 label.replace(first_line.to_owned() + "…", cx)
242 } else {
243 label.replace(title, cx);
244 }
245 });
246 }
247
248 if let Some(content) = content {
249 let new_content_len = content.len();
250 let mut content = content.into_iter();
251
252 // Reuse existing content if we can
253 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
254 old.update_from_acp(new, language_registry.clone(), cx);
255 }
256 for new in content {
257 self.content.push(ToolCallContent::from_acp(
258 new,
259 language_registry.clone(),
260 cx,
261 ))
262 }
263 self.content.truncate(new_content_len);
264 }
265
266 if let Some(locations) = locations {
267 self.locations = locations;
268 }
269
270 if let Some(raw_input) = raw_input {
271 self.raw_input = Some(raw_input);
272 }
273
274 if let Some(raw_output) = raw_output {
275 if self.content.is_empty()
276 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
277 {
278 self.content
279 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
280 markdown,
281 }));
282 }
283 self.raw_output = Some(raw_output);
284 }
285 }
286
287 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
288 self.content.iter().filter_map(|content| match content {
289 ToolCallContent::Diff(diff) => Some(diff),
290 ToolCallContent::ContentBlock(_) => None,
291 ToolCallContent::Terminal(_) => None,
292 })
293 }
294
295 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
296 self.content.iter().filter_map(|content| match content {
297 ToolCallContent::Terminal(terminal) => Some(terminal),
298 ToolCallContent::ContentBlock(_) => None,
299 ToolCallContent::Diff(_) => None,
300 })
301 }
302
303 fn to_markdown(&self, cx: &App) -> String {
304 let mut markdown = format!(
305 "**Tool Call: {}**\nStatus: {}\n\n",
306 self.label.read(cx).source(),
307 self.status
308 );
309 for content in &self.content {
310 markdown.push_str(content.to_markdown(cx).as_str());
311 markdown.push_str("\n\n");
312 }
313 markdown
314 }
315
316 async fn resolve_location(
317 location: acp::ToolCallLocation,
318 project: WeakEntity<Project>,
319 cx: &mut AsyncApp,
320 ) -> Option<AgentLocation> {
321 let buffer = project
322 .update(cx, |project, cx| {
323 project
324 .project_path_for_absolute_path(&location.path, cx)
325 .map(|path| project.open_buffer(path, cx))
326 })
327 .ok()??;
328 let buffer = buffer.await.log_err()?;
329 let position = buffer
330 .update(cx, |buffer, _| {
331 if let Some(row) = location.line {
332 let snapshot = buffer.snapshot();
333 let column = snapshot.indent_size_for_line(row).len;
334 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
335 snapshot.anchor_before(point)
336 } else {
337 Anchor::MIN
338 }
339 })
340 .ok()?;
341
342 Some(AgentLocation {
343 buffer: buffer.downgrade(),
344 position,
345 })
346 }
347
348 fn resolve_locations(
349 &self,
350 project: Entity<Project>,
351 cx: &mut App,
352 ) -> Task<Vec<Option<AgentLocation>>> {
353 let locations = self.locations.clone();
354 project.update(cx, |_, cx| {
355 cx.spawn(async move |project, cx| {
356 let mut new_locations = Vec::new();
357 for location in locations {
358 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
359 }
360 new_locations
361 })
362 })
363 }
364}
365
366#[derive(Debug)]
367pub enum ToolCallStatus {
368 /// The tool call hasn't started running yet, but we start showing it to
369 /// the user.
370 Pending,
371 /// The tool call is waiting for confirmation from the user.
372 WaitingForConfirmation {
373 options: Vec<acp::PermissionOption>,
374 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
375 },
376 /// The tool call is currently running.
377 InProgress,
378 /// The tool call completed successfully.
379 Completed,
380 /// The tool call failed.
381 Failed,
382 /// The user rejected the tool call.
383 Rejected,
384 /// The user canceled generation so the tool call was canceled.
385 Canceled,
386}
387
388impl From<acp::ToolCallStatus> for ToolCallStatus {
389 fn from(status: acp::ToolCallStatus) -> Self {
390 match status {
391 acp::ToolCallStatus::Pending => Self::Pending,
392 acp::ToolCallStatus::InProgress => Self::InProgress,
393 acp::ToolCallStatus::Completed => Self::Completed,
394 acp::ToolCallStatus::Failed => Self::Failed,
395 }
396 }
397}
398
399impl Display for ToolCallStatus {
400 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401 write!(
402 f,
403 "{}",
404 match self {
405 ToolCallStatus::Pending => "Pending",
406 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
407 ToolCallStatus::InProgress => "In Progress",
408 ToolCallStatus::Completed => "Completed",
409 ToolCallStatus::Failed => "Failed",
410 ToolCallStatus::Rejected => "Rejected",
411 ToolCallStatus::Canceled => "Canceled",
412 }
413 )
414 }
415}
416
417#[derive(Debug, PartialEq, Clone)]
418pub enum ContentBlock {
419 Empty,
420 Markdown { markdown: Entity<Markdown> },
421 ResourceLink { resource_link: acp::ResourceLink },
422}
423
424impl ContentBlock {
425 pub fn new(
426 block: acp::ContentBlock,
427 language_registry: &Arc<LanguageRegistry>,
428 cx: &mut App,
429 ) -> Self {
430 let mut this = Self::Empty;
431 this.append(block, language_registry, cx);
432 this
433 }
434
435 pub fn new_combined(
436 blocks: impl IntoIterator<Item = acp::ContentBlock>,
437 language_registry: Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 for block in blocks {
442 this.append(block, &language_registry, cx);
443 }
444 this
445 }
446
447 pub fn append(
448 &mut self,
449 block: acp::ContentBlock,
450 language_registry: &Arc<LanguageRegistry>,
451 cx: &mut App,
452 ) {
453 if matches!(self, ContentBlock::Empty)
454 && let acp::ContentBlock::ResourceLink(resource_link) = block
455 {
456 *self = ContentBlock::ResourceLink { resource_link };
457 return;
458 }
459
460 let new_content = self.block_string_contents(block);
461
462 match self {
463 ContentBlock::Empty => {
464 *self = Self::create_markdown_block(new_content, language_registry, cx);
465 }
466 ContentBlock::Markdown { markdown } => {
467 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
468 }
469 ContentBlock::ResourceLink { resource_link } => {
470 let existing_content = Self::resource_link_md(&resource_link.uri);
471 let combined = format!("{}\n{}", existing_content, new_content);
472
473 *self = Self::create_markdown_block(combined, language_registry, cx);
474 }
475 }
476 }
477
478 fn create_markdown_block(
479 content: String,
480 language_registry: &Arc<LanguageRegistry>,
481 cx: &mut App,
482 ) -> ContentBlock {
483 ContentBlock::Markdown {
484 markdown: cx
485 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
486 }
487 }
488
489 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
490 match block {
491 acp::ContentBlock::Text(text_content) => text_content.text,
492 acp::ContentBlock::ResourceLink(resource_link) => {
493 Self::resource_link_md(&resource_link.uri)
494 }
495 acp::ContentBlock::Resource(acp::EmbeddedResource {
496 resource:
497 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
498 uri,
499 ..
500 }),
501 ..
502 }) => Self::resource_link_md(&uri),
503 acp::ContentBlock::Image(image) => Self::image_md(&image),
504 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
505 }
506 }
507
508 fn resource_link_md(uri: &str) -> String {
509 if let Some(uri) = MentionUri::parse(uri).log_err() {
510 uri.as_link().to_string()
511 } else {
512 uri.to_string()
513 }
514 }
515
516 fn image_md(_image: &acp::ImageContent) -> String {
517 "`Image`".into()
518 }
519
520 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
521 match self {
522 ContentBlock::Empty => "",
523 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
524 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
525 }
526 }
527
528 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
529 match self {
530 ContentBlock::Empty => None,
531 ContentBlock::Markdown { markdown } => Some(markdown),
532 ContentBlock::ResourceLink { .. } => None,
533 }
534 }
535
536 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
537 match self {
538 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
539 _ => None,
540 }
541 }
542}
543
544#[derive(Debug)]
545pub enum ToolCallContent {
546 ContentBlock(ContentBlock),
547 Diff(Entity<Diff>),
548 Terminal(Entity<Terminal>),
549}
550
551impl ToolCallContent {
552 pub fn from_acp(
553 content: acp::ToolCallContent,
554 language_registry: Arc<LanguageRegistry>,
555 cx: &mut App,
556 ) -> Self {
557 match content {
558 acp::ToolCallContent::Content { content } => {
559 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
560 }
561 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
562 Diff::finalized(
563 diff.path,
564 diff.old_text,
565 diff.new_text,
566 language_registry,
567 cx,
568 )
569 })),
570 }
571 }
572
573 pub fn update_from_acp(
574 &mut self,
575 new: acp::ToolCallContent,
576 language_registry: Arc<LanguageRegistry>,
577 cx: &mut App,
578 ) {
579 let needs_update = match (&self, &new) {
580 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
581 old_diff.read(cx).needs_update(
582 new_diff.old_text.as_deref().unwrap_or(""),
583 &new_diff.new_text,
584 cx,
585 )
586 }
587 _ => true,
588 };
589
590 if needs_update {
591 *self = Self::from_acp(new, language_registry, cx);
592 }
593 }
594
595 pub fn to_markdown(&self, cx: &App) -> String {
596 match self {
597 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
598 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
599 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
600 }
601 }
602}
603
604#[derive(Debug, PartialEq)]
605pub enum ToolCallUpdate {
606 UpdateFields(acp::ToolCallUpdate),
607 UpdateDiff(ToolCallUpdateDiff),
608 UpdateTerminal(ToolCallUpdateTerminal),
609}
610
611impl ToolCallUpdate {
612 fn id(&self) -> &acp::ToolCallId {
613 match self {
614 Self::UpdateFields(update) => &update.id,
615 Self::UpdateDiff(diff) => &diff.id,
616 Self::UpdateTerminal(terminal) => &terminal.id,
617 }
618 }
619}
620
621impl From<acp::ToolCallUpdate> for ToolCallUpdate {
622 fn from(update: acp::ToolCallUpdate) -> Self {
623 Self::UpdateFields(update)
624 }
625}
626
627impl From<ToolCallUpdateDiff> for ToolCallUpdate {
628 fn from(diff: ToolCallUpdateDiff) -> Self {
629 Self::UpdateDiff(diff)
630 }
631}
632
633#[derive(Debug, PartialEq)]
634pub struct ToolCallUpdateDiff {
635 pub id: acp::ToolCallId,
636 pub diff: Entity<Diff>,
637}
638
639impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
640 fn from(terminal: ToolCallUpdateTerminal) -> Self {
641 Self::UpdateTerminal(terminal)
642 }
643}
644
645#[derive(Debug, PartialEq)]
646pub struct ToolCallUpdateTerminal {
647 pub id: acp::ToolCallId,
648 pub terminal: Entity<Terminal>,
649}
650
651#[derive(Debug, Default)]
652pub struct Plan {
653 pub entries: Vec<PlanEntry>,
654}
655
656#[derive(Debug)]
657pub struct PlanStats<'a> {
658 pub in_progress_entry: Option<&'a PlanEntry>,
659 pub pending: u32,
660 pub completed: u32,
661}
662
663impl Plan {
664 pub fn is_empty(&self) -> bool {
665 self.entries.is_empty()
666 }
667
668 pub fn stats(&self) -> PlanStats<'_> {
669 let mut stats = PlanStats {
670 in_progress_entry: None,
671 pending: 0,
672 completed: 0,
673 };
674
675 for entry in &self.entries {
676 match &entry.status {
677 acp::PlanEntryStatus::Pending => {
678 stats.pending += 1;
679 }
680 acp::PlanEntryStatus::InProgress => {
681 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
682 }
683 acp::PlanEntryStatus::Completed => {
684 stats.completed += 1;
685 }
686 }
687 }
688
689 stats
690 }
691}
692
693#[derive(Debug)]
694pub struct PlanEntry {
695 pub content: Entity<Markdown>,
696 pub priority: acp::PlanEntryPriority,
697 pub status: acp::PlanEntryStatus,
698}
699
700impl PlanEntry {
701 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
702 Self {
703 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
704 priority: entry.priority,
705 status: entry.status,
706 }
707 }
708}
709
710#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
711pub struct TokenUsage {
712 pub max_tokens: u64,
713 pub used_tokens: u64,
714}
715
716impl TokenUsage {
717 pub fn ratio(&self) -> TokenUsageRatio {
718 #[cfg(debug_assertions)]
719 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
720 .unwrap_or("0.8".to_string())
721 .parse()
722 .unwrap();
723 #[cfg(not(debug_assertions))]
724 let warning_threshold: f32 = 0.8;
725
726 // When the maximum is unknown because there is no selected model,
727 // avoid showing the token limit warning.
728 if self.max_tokens == 0 {
729 TokenUsageRatio::Normal
730 } else if self.used_tokens >= self.max_tokens {
731 TokenUsageRatio::Exceeded
732 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
733 TokenUsageRatio::Warning
734 } else {
735 TokenUsageRatio::Normal
736 }
737 }
738}
739
740#[derive(Debug, Clone, PartialEq, Eq)]
741pub enum TokenUsageRatio {
742 Normal,
743 Warning,
744 Exceeded,
745}
746
/// Progress of an automatic retry; delivered to observers through
/// `AcpThreadEvent::Retry` (see `AcpThread::update_retry_status`).
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    // NOTE(review): presumably the 1-based number of the current attempt — confirm.
    pub attempt: usize,
    pub max_attempts: usize,
    // NOTE(review): presumably when the retry wait began — confirm against the producer.
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm.
    pub duration: Duration,
}
755
/// State of a single agent conversation (one ACP session) within a project.
pub struct AcpThread {
    /// Display title; changed through `set_title`.
    title: SharedString,
    /// Transcript entries, oldest first (`push_entry` appends).
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan (updated from `SessionUpdate::Plan`).
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not referenced in the visible part of this file; presumably
    // snapshots of buffers shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While `Some`, `status()` reports `Generating` or
    /// `WaitingForToolConfirmation`; `None` means `Idle`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Last value passed to `update_token_usage`.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the connection.
    prompt_capabilities: acp::PromptCapabilities,
    /// Background task that mirrors capability changes into `prompt_capabilities`;
    /// held here so it is not dropped (dropping a gpui `Task` presumably cancels it).
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
770
771#[derive(Debug)]
772pub enum AcpThreadEvent {
773 NewEntry,
774 TitleUpdated,
775 TokenUsageUpdated,
776 EntryUpdated(usize),
777 EntriesRemoved(Range<usize>),
778 ToolAuthorizationRequired,
779 Retry(RetryStatus),
780 Stopped,
781 Error,
782 LoadError(LoadError),
783 PromptCapabilitiesUpdated,
784}
785
786impl EventEmitter<AcpThreadEvent> for AcpThread {}
787
788#[derive(PartialEq, Eq, Debug)]
789pub enum ThreadStatus {
790 Idle,
791 WaitingForToolConfirmation,
792 Generating,
793}
794
795#[derive(Debug, Clone)]
796pub enum LoadError {
797 NotInstalled {
798 error_message: SharedString,
799 install_message: SharedString,
800 install_command: String,
801 },
802 Unsupported {
803 error_message: SharedString,
804 upgrade_message: SharedString,
805 upgrade_command: String,
806 },
807 Exited {
808 status: ExitStatus,
809 },
810 Other(SharedString),
811}
812
813impl Display for LoadError {
814 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
815 match self {
816 LoadError::NotInstalled { error_message, .. }
817 | LoadError::Unsupported { error_message, .. } => {
818 write!(f, "{error_message}")
819 }
820 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
821 LoadError::Other(msg) => write!(f, "{}", msg),
822 }
823 }
824}
825
826impl Error for LoadError {}
827
828impl AcpThread {
829 pub fn new(
830 title: impl Into<SharedString>,
831 connection: Rc<dyn AgentConnection>,
832 project: Entity<Project>,
833 action_log: Entity<ActionLog>,
834 session_id: acp::SessionId,
835 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
836 cx: &mut Context<Self>,
837 ) -> Self {
838 let prompt_capabilities = *prompt_capabilities_rx.borrow();
839 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
840 loop {
841 let caps = prompt_capabilities_rx.recv().await?;
842 this.update(cx, |this, cx| {
843 this.prompt_capabilities = caps;
844 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
845 })?;
846 }
847 });
848
849 Self {
850 action_log,
851 shared_buffers: Default::default(),
852 entries: Default::default(),
853 plan: Default::default(),
854 title: title.into(),
855 project,
856 send_task: None,
857 connection,
858 session_id,
859 token_usage: None,
860 prompt_capabilities,
861 _observe_prompt_capabilities: task,
862 }
863 }
864
865 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
866 self.prompt_capabilities
867 }
868
869 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
870 &self.connection
871 }
872
873 pub fn action_log(&self) -> &Entity<ActionLog> {
874 &self.action_log
875 }
876
877 pub fn project(&self) -> &Entity<Project> {
878 &self.project
879 }
880
881 pub fn title(&self) -> SharedString {
882 self.title.clone()
883 }
884
885 pub fn entries(&self) -> &[AgentThreadEntry] {
886 &self.entries
887 }
888
889 pub fn session_id(&self) -> &acp::SessionId {
890 &self.session_id
891 }
892
893 pub fn status(&self) -> ThreadStatus {
894 if self.send_task.is_some() {
895 if self.waiting_for_tool_confirmation() {
896 ThreadStatus::WaitingForToolConfirmation
897 } else {
898 ThreadStatus::Generating
899 }
900 } else {
901 ThreadStatus::Idle
902 }
903 }
904
905 pub fn token_usage(&self) -> Option<&TokenUsage> {
906 self.token_usage.as_ref()
907 }
908
909 pub fn has_pending_edit_tool_calls(&self) -> bool {
910 for entry in self.entries.iter().rev() {
911 match entry {
912 AgentThreadEntry::UserMessage(_) => return false,
913 AgentThreadEntry::ToolCall(
914 call @ ToolCall {
915 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
916 ..
917 },
918 ) if call.diffs().next().is_some() => {
919 return true;
920 }
921 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
922 }
923 }
924
925 false
926 }
927
928 pub fn used_tools_since_last_user_message(&self) -> bool {
929 for entry in self.entries.iter().rev() {
930 match entry {
931 AgentThreadEntry::UserMessage(..) => return false,
932 AgentThreadEntry::AssistantMessage(..) => continue,
933 AgentThreadEntry::ToolCall(..) => return true,
934 }
935 }
936
937 false
938 }
939
940 pub fn handle_session_update(
941 &mut self,
942 update: acp::SessionUpdate,
943 cx: &mut Context<Self>,
944 ) -> Result<(), acp::Error> {
945 match update {
946 acp::SessionUpdate::UserMessageChunk { content } => {
947 self.push_user_content_block(None, content, cx);
948 }
949 acp::SessionUpdate::AgentMessageChunk { content } => {
950 self.push_assistant_content_block(content, false, cx);
951 }
952 acp::SessionUpdate::AgentThoughtChunk { content } => {
953 self.push_assistant_content_block(content, true, cx);
954 }
955 acp::SessionUpdate::ToolCall(tool_call) => {
956 self.upsert_tool_call(tool_call, cx)?;
957 }
958 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
959 self.update_tool_call(tool_call_update, cx)?;
960 }
961 acp::SessionUpdate::Plan(plan) => {
962 self.update_plan(plan, cx);
963 }
964 }
965 Ok(())
966 }
967
968 pub fn push_user_content_block(
969 &mut self,
970 prompt_id: Option<acp::PromptId>,
971 chunk: acp::ContentBlock,
972 cx: &mut Context<Self>,
973 ) {
974 let language_registry = self.project.read(cx).languages().clone();
975 let entries_len = self.entries.len();
976
977 if let Some(last_entry) = self.entries.last_mut()
978 && let AgentThreadEntry::UserMessage(UserMessage {
979 prompt_id: id,
980 content,
981 chunks,
982 ..
983 }) = last_entry
984 {
985 *id = prompt_id.or(id.take());
986 content.append(chunk.clone(), &language_registry, cx);
987 chunks.push(chunk);
988 let idx = entries_len - 1;
989 cx.emit(AcpThreadEvent::EntryUpdated(idx));
990 } else {
991 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
992 self.push_entry(
993 AgentThreadEntry::UserMessage(UserMessage {
994 prompt_id,
995 content,
996 chunks: vec![chunk],
997 checkpoint: None,
998 }),
999 cx,
1000 );
1001 }
1002 }
1003
1004 pub fn push_assistant_content_block(
1005 &mut self,
1006 chunk: acp::ContentBlock,
1007 is_thought: bool,
1008 cx: &mut Context<Self>,
1009 ) {
1010 let language_registry = self.project.read(cx).languages().clone();
1011 let entries_len = self.entries.len();
1012 if let Some(last_entry) = self.entries.last_mut()
1013 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1014 {
1015 let idx = entries_len - 1;
1016 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1017 match (chunks.last_mut(), is_thought) {
1018 (Some(AssistantMessageChunk::Message { block }), false)
1019 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1020 block.append(chunk, &language_registry, cx)
1021 }
1022 _ => {
1023 let block = ContentBlock::new(chunk, &language_registry, cx);
1024 if is_thought {
1025 chunks.push(AssistantMessageChunk::Thought { block })
1026 } else {
1027 chunks.push(AssistantMessageChunk::Message { block })
1028 }
1029 }
1030 }
1031 } else {
1032 let block = ContentBlock::new(chunk, &language_registry, cx);
1033 let chunk = if is_thought {
1034 AssistantMessageChunk::Thought { block }
1035 } else {
1036 AssistantMessageChunk::Message { block }
1037 };
1038
1039 self.push_entry(
1040 AgentThreadEntry::AssistantMessage(AssistantMessage {
1041 chunks: vec![chunk],
1042 }),
1043 cx,
1044 );
1045 }
1046 }
1047
1048 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1049 self.entries.push(entry);
1050 cx.emit(AcpThreadEvent::NewEntry);
1051 }
1052
1053 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1054 self.connection.set_title(&self.session_id, cx).is_some()
1055 }
1056
1057 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1058 if title != self.title {
1059 self.title = title.clone();
1060 cx.emit(AcpThreadEvent::TitleUpdated);
1061 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1062 return set_title.run(title, cx);
1063 }
1064 }
1065 Task::ready(Ok(()))
1066 }
1067
1068 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1069 self.token_usage = usage;
1070 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1071 }
1072
1073 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1074 cx.emit(AcpThreadEvent::Retry(status));
1075 }
1076
1077 pub fn update_tool_call(
1078 &mut self,
1079 update: impl Into<ToolCallUpdate>,
1080 cx: &mut Context<Self>,
1081 ) -> Result<()> {
1082 let update = update.into();
1083 let languages = self.project.read(cx).languages().clone();
1084
1085 let (ix, current_call) = self
1086 .tool_call_mut(update.id())
1087 .context("Tool call not found")?;
1088 match update {
1089 ToolCallUpdate::UpdateFields(update) => {
1090 let location_updated = update.fields.locations.is_some();
1091 current_call.update_fields(update.fields, languages, cx);
1092 if location_updated {
1093 self.resolve_locations(update.id, cx);
1094 }
1095 }
1096 ToolCallUpdate::UpdateDiff(update) => {
1097 current_call.content.clear();
1098 current_call
1099 .content
1100 .push(ToolCallContent::Diff(update.diff));
1101 }
1102 ToolCallUpdate::UpdateTerminal(update) => {
1103 current_call.content.clear();
1104 current_call
1105 .content
1106 .push(ToolCallContent::Terminal(update.terminal));
1107 }
1108 }
1109
1110 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1111
1112 Ok(())
1113 }
1114
1115 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1116 pub fn upsert_tool_call(
1117 &mut self,
1118 tool_call: acp::ToolCall,
1119 cx: &mut Context<Self>,
1120 ) -> Result<(), acp::Error> {
1121 let status = tool_call.status.into();
1122 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1123 }
1124
1125 /// Fails if id does not match an existing entry.
1126 pub fn upsert_tool_call_inner(
1127 &mut self,
1128 tool_call_update: acp::ToolCallUpdate,
1129 status: ToolCallStatus,
1130 cx: &mut Context<Self>,
1131 ) -> Result<(), acp::Error> {
1132 let language_registry = self.project.read(cx).languages().clone();
1133 let id = tool_call_update.id.clone();
1134
1135 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1136 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1137 current_call.status = status;
1138
1139 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1140 } else {
1141 let call =
1142 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1143 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1144 };
1145
1146 self.resolve_locations(id, cx);
1147 Ok(())
1148 }
1149
1150 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1151 // The tool call we are looking for is typically the last one, or very close to the end.
1152 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1153 self.entries
1154 .iter_mut()
1155 .enumerate()
1156 .rev()
1157 .find_map(|(index, tool_call)| {
1158 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1159 && &tool_call.id == id
1160 {
1161 Some((index, tool_call))
1162 } else {
1163 None
1164 }
1165 })
1166 }
1167
1168 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1169 self.entries
1170 .iter()
1171 .enumerate()
1172 .rev()
1173 .find_map(|(index, tool_call)| {
1174 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1175 && &tool_call.id == id
1176 {
1177 Some((index, tool_call))
1178 } else {
1179 None
1180 }
1181 })
1182 }
1183
1184 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1185 let project = self.project.clone();
1186 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1187 return;
1188 };
1189 let task = tool_call.resolve_locations(project, cx);
1190 cx.spawn(async move |this, cx| {
1191 let resolved_locations = task.await;
1192 this.update(cx, |this, cx| {
1193 let project = this.project.clone();
1194 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1195 return;
1196 };
1197 if let Some(Some(location)) = resolved_locations.last() {
1198 project.update(cx, |project, cx| {
1199 if let Some(agent_location) = project.agent_location() {
1200 let should_ignore = agent_location.buffer == location.buffer
1201 && location
1202 .buffer
1203 .update(cx, |buffer, _| {
1204 let snapshot = buffer.snapshot();
1205 let old_position =
1206 agent_location.position.to_point(&snapshot);
1207 let new_position = location.position.to_point(&snapshot);
1208 // ignore this so that when we get updates from the edit tool
1209 // the position doesn't reset to the startof line
1210 old_position.row == new_position.row
1211 && old_position.column > new_position.column
1212 })
1213 .ok()
1214 .unwrap_or_default();
1215 if !should_ignore {
1216 project.set_agent_location(Some(location.clone()), cx);
1217 }
1218 }
1219 });
1220 }
1221 if tool_call.resolved_locations != resolved_locations {
1222 tool_call.resolved_locations = resolved_locations;
1223 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224 }
1225 })
1226 })
1227 .detach();
1228 }
1229
1230 pub fn request_tool_call_authorization(
1231 &mut self,
1232 tool_call: acp::ToolCallUpdate,
1233 options: Vec<acp::PermissionOption>,
1234 cx: &mut Context<Self>,
1235 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1236 let (tx, rx) = oneshot::channel();
1237
1238 let status = ToolCallStatus::WaitingForConfirmation {
1239 options,
1240 respond_tx: tx,
1241 };
1242
1243 self.upsert_tool_call_inner(tool_call, status, cx)?;
1244 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1245 Ok(rx)
1246 }
1247
1248 pub fn authorize_tool_call(
1249 &mut self,
1250 id: acp::ToolCallId,
1251 option_id: acp::PermissionOptionId,
1252 option_kind: acp::PermissionOptionKind,
1253 cx: &mut Context<Self>,
1254 ) {
1255 let Some((ix, call)) = self.tool_call_mut(&id) else {
1256 return;
1257 };
1258
1259 let new_status = match option_kind {
1260 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1261 ToolCallStatus::Rejected
1262 }
1263 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1264 ToolCallStatus::InProgress
1265 }
1266 };
1267
1268 let curr_status = mem::replace(&mut call.status, new_status);
1269
1270 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1271 respond_tx.send(option_id).log_err();
1272 } else if cfg!(debug_assertions) {
1273 panic!("tried to authorize an already authorized tool call");
1274 }
1275
1276 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1277 }
1278
1279 /// Returns true if the last turn is awaiting tool authorization
1280 pub fn waiting_for_tool_confirmation(&self) -> bool {
1281 for entry in self.entries.iter().rev() {
1282 match &entry {
1283 AgentThreadEntry::ToolCall(call) => match call.status {
1284 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1285 ToolCallStatus::Pending
1286 | ToolCallStatus::InProgress
1287 | ToolCallStatus::Completed
1288 | ToolCallStatus::Failed
1289 | ToolCallStatus::Rejected
1290 | ToolCallStatus::Canceled => continue,
1291 },
1292 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1293 // Reached the beginning of the turn
1294 return false;
1295 }
1296 }
1297 }
1298 false
1299 }
1300
1301 pub fn plan(&self) -> &Plan {
1302 &self.plan
1303 }
1304
1305 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1306 let new_entries_len = request.entries.len();
1307 let mut new_entries = request.entries.into_iter();
1308
1309 // Reuse existing markdown to prevent flickering
1310 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1311 let PlanEntry {
1312 content,
1313 priority,
1314 status,
1315 } = old;
1316 content.update(cx, |old, cx| {
1317 old.replace(new.content, cx);
1318 });
1319 *priority = new.priority;
1320 *status = new.status;
1321 }
1322 for new in new_entries {
1323 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1324 }
1325 self.plan.entries.truncate(new_entries_len);
1326
1327 cx.notify();
1328 }
1329
1330 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1331 self.plan
1332 .entries
1333 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1334 cx.notify();
1335 }
1336
1337 #[cfg(any(test, feature = "test-support"))]
1338 pub fn send_raw(
1339 &mut self,
1340 message: &str,
1341 cx: &mut Context<Self>,
1342 ) -> BoxFuture<'static, Result<()>> {
1343 self.send(
1344 new_prompt_id(),
1345 vec![acp::ContentBlock::Text(acp::TextContent {
1346 text: message.to_string(),
1347 annotations: None,
1348 })],
1349 cx,
1350 )
1351 }
1352
    /// Starts a new turn by sending `message` to the agent as the prompt `prompt_id`.
    ///
    /// Inside the turn, the message is first appended to the thread as a user message. Then a
    /// git checkpoint of the project is taken and attached to it; a checkpoint failure is
    /// logged and leaves the message without one. Finally the prompt request is sent over the
    /// connection. The turn runs through [`Self::run_turn`], which first cancels any turn
    /// that is still in flight.
    pub fn send(
        &mut self,
        prompt_id: acp::PromptId,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt_id: Some(prompt_id.clone()),
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        self.run_turn(cx, async move |this, cx| {
            // Record the user's message before anything else.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        prompt_id: Some(prompt_id),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the repository before the agent runs; `rewind` restores this checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                // NOTE(review): attaches to whichever user message is last at this point,
                // presumably the one pushed above.
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        // Hidden until `update_last_checkpoint` finds that the tree changed.
                        show: false,
                    });
                }
                this.connection.prompt(request, cx)
            })?
            .await
        })
    }
1402
1403 pub fn can_resume(&self, cx: &App) -> bool {
1404 self.connection.resume(&self.session_id, cx).is_some()
1405 }
1406
1407 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1408 self.run_turn(cx, async move |this, cx| {
1409 this.update(cx, |this, cx| {
1410 this.connection
1411 .resume(&this.session_id, cx)
1412 .map(|resume| resume.run(cx))
1413 })?
1414 .context("resuming a session is not supported")?
1415 .await
1416 })
1417 }
1418
    /// Runs one turn of the conversation driven by `f`.
    ///
    /// Completed plan entries are cleared and any in-flight turn is canceled. `f` runs only
    /// after that cancellation finishes, inside the task stored as `send_task`. When `f`
    /// completes:
    /// - the last user message's checkpoint is refreshed;
    /// - the agent location is cleared;
    /// - [`AcpThreadEvent::Error`] is emitted if `f` failed, otherwise
    ///   [`AcpThreadEvent::Stopped`];
    /// - a `Refusal` stop reason also removes the last user message and every entry after it.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `f`'s result is forwarded through this channel so the returned future can observe it
        // without owning `send_task`.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start this turn until the previous one has finished winding down.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(_)` means `send_task` was dropped before `f` produced a result.
            let response = rx.await;

            // NOTE(review): an error here returns early, which skips clearing the agent
            // location, taking `send_task`, and emitting `Stopped`/`Error` — confirm intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1485
1486 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1487 let Some(send_task) = self.send_task.take() else {
1488 return Task::ready(());
1489 };
1490
1491 for entry in self.entries.iter_mut() {
1492 if let AgentThreadEntry::ToolCall(call) = entry {
1493 let cancel = matches!(
1494 call.status,
1495 ToolCallStatus::Pending
1496 | ToolCallStatus::WaitingForConfirmation { .. }
1497 | ToolCallStatus::InProgress
1498 );
1499
1500 if cancel {
1501 call.status = ToolCallStatus::Canceled;
1502 }
1503 }
1504 }
1505
1506 self.connection.cancel(&self.session_id, cx);
1507
1508 // Wait for the send task to complete
1509 cx.foreground_executor().spawn(send_task)
1510 }
1511
1512 /// Rewinds this thread to before the entry at `index`, removing it and all
1513 /// subsequent entries while reverting any changes made from that point.
1514 pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context<Self>) -> Task<Result<()>> {
1515 let Some(rewind) = self.connection.rewind(&self.session_id, cx) else {
1516 return Task::ready(Err(anyhow!("not supported")));
1517 };
1518 let Some(message) = self.user_message(&id) else {
1519 return Task::ready(Err(anyhow!("message not found")));
1520 };
1521
1522 let checkpoint = message
1523 .checkpoint
1524 .as_ref()
1525 .map(|c| c.git_checkpoint.clone());
1526
1527 let git_store = self.project.read(cx).git_store().clone();
1528 cx.spawn(async move |this, cx| {
1529 if let Some(checkpoint) = checkpoint {
1530 git_store
1531 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1532 .await?;
1533 }
1534
1535 cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?;
1536 this.update(cx, |this, cx| {
1537 if let Some((ix, _)) = this.user_message_mut(&id) {
1538 let range = ix..this.entries.len();
1539 this.entries.truncate(ix);
1540 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1541 }
1542 })
1543 })
1544 }
1545
1546 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1547 let git_store = self.project.read(cx).git_store().clone();
1548
1549 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1550 if let Some(checkpoint) = message.checkpoint.as_ref() {
1551 checkpoint.git_checkpoint.clone()
1552 } else {
1553 return Task::ready(Ok(()));
1554 }
1555 } else {
1556 return Task::ready(Ok(()));
1557 };
1558
1559 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1560 cx.spawn(async move |this, cx| {
1561 let new_checkpoint = new_checkpoint
1562 .await
1563 .context("failed to get new checkpoint")
1564 .log_err();
1565 if let Some(new_checkpoint) = new_checkpoint {
1566 let equal = git_store
1567 .update(cx, |git, cx| {
1568 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1569 })?
1570 .await
1571 .unwrap_or(true);
1572 this.update(cx, |this, cx| {
1573 let (ix, message) = this.last_user_message().context("no user message")?;
1574 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1575 checkpoint.show = !equal;
1576 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1577 anyhow::Ok(())
1578 })??;
1579 }
1580
1581 Ok(())
1582 })
1583 }
1584
1585 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1586 self.entries
1587 .iter_mut()
1588 .enumerate()
1589 .rev()
1590 .find_map(|(ix, entry)| {
1591 if let AgentThreadEntry::UserMessage(message) = entry {
1592 Some((ix, message))
1593 } else {
1594 None
1595 }
1596 })
1597 }
1598
1599 fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> {
1600 self.entries.iter().find_map(|entry| {
1601 if let AgentThreadEntry::UserMessage(message) = entry {
1602 if message.prompt_id.as_ref() == Some(id) {
1603 Some(message)
1604 } else {
1605 None
1606 }
1607 } else {
1608 None
1609 }
1610 })
1611 }
1612
1613 fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> {
1614 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1615 if let AgentThreadEntry::UserMessage(message) = entry {
1616 if message.prompt_id.as_ref() == Some(id) {
1617 Some((ix, message))
1618 } else {
1619 None
1620 }
1621 } else {
1622 None
1623 }
1624 })
1625 }
1626
1627 pub fn read_text_file(
1628 &self,
1629 path: PathBuf,
1630 line: Option<u32>,
1631 limit: Option<u32>,
1632 reuse_shared_snapshot: bool,
1633 cx: &mut Context<Self>,
1634 ) -> Task<Result<String>> {
1635 let project = self.project.clone();
1636 let action_log = self.action_log.clone();
1637 cx.spawn(async move |this, cx| {
1638 let load = project.update(cx, |project, cx| {
1639 let path = project
1640 .project_path_for_absolute_path(&path, cx)
1641 .context("invalid path")?;
1642 anyhow::Ok(project.open_buffer(path, cx))
1643 });
1644 let buffer = load??.await?;
1645
1646 let snapshot = if reuse_shared_snapshot {
1647 this.read_with(cx, |this, _| {
1648 this.shared_buffers.get(&buffer.clone()).cloned()
1649 })
1650 .log_err()
1651 .flatten()
1652 } else {
1653 None
1654 };
1655
1656 let snapshot = if let Some(snapshot) = snapshot {
1657 snapshot
1658 } else {
1659 action_log.update(cx, |action_log, cx| {
1660 action_log.buffer_read(buffer.clone(), cx);
1661 })?;
1662 project.update(cx, |project, cx| {
1663 let position = buffer
1664 .read(cx)
1665 .snapshot()
1666 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1667 project.set_agent_location(
1668 Some(AgentLocation {
1669 buffer: buffer.downgrade(),
1670 position,
1671 }),
1672 cx,
1673 );
1674 })?;
1675
1676 buffer.update(cx, |buffer, _| buffer.snapshot())?
1677 };
1678
1679 this.update(cx, |this, _| {
1680 let text = snapshot.text();
1681 this.shared_buffers.insert(buffer.clone(), snapshot);
1682 if line.is_none() && limit.is_none() {
1683 return Ok(text);
1684 }
1685 let limit = limit.unwrap_or(u32::MAX) as usize;
1686 let Some(line) = line else {
1687 return Ok(text.lines().take(limit).collect::<String>());
1688 };
1689
1690 let count = text.lines().count();
1691 if count < line as usize {
1692 anyhow::bail!("There are only {} lines", count);
1693 }
1694 Ok(text
1695 .lines()
1696 .skip(line as usize + 1)
1697 .take(limit)
1698 .collect::<String>())
1699 })?
1700 })
1701 }
1702
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent.
    ///
    /// The new text is diffed against the snapshot previously shared with the agent, or
    /// against the buffer's current snapshot if none was shared. Only the differing ranges
    /// are edited, anchored in that snapshot, so changes made to the buffer since then are
    /// kept. After the edit:
    /// - the agent location is moved to the end of the last edit;
    /// - the action log records the read and the edit;
    /// - the buffer is formatted if its language settings enable format-on-save;
    /// - the buffer is saved.
    ///
    /// Fails if `path` cannot be mapped to a project path or the buffer cannot be opened or
    /// saved. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, not necessarily the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread and turn it into anchored edits.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Register the buffer with the action log before the edit is applied.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Formatting edits the buffer again, so report that to the action log too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1799
1800 pub fn to_markdown(&self, cx: &App) -> String {
1801 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1802 }
1803
1804 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1805 cx.emit(AcpThreadEvent::LoadError(error));
1806 }
1807}
1808
1809fn markdown_for_raw_output(
1810 raw_output: &serde_json::Value,
1811 language_registry: &Arc<LanguageRegistry>,
1812 cx: &mut App,
1813) -> Option<Entity<Markdown>> {
1814 match raw_output {
1815 serde_json::Value::Null => None,
1816 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1817 Markdown::new(
1818 value.to_string().into(),
1819 Some(language_registry.clone()),
1820 None,
1821 cx,
1822 )
1823 })),
1824 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1825 Markdown::new(
1826 value.to_string().into(),
1827 Some(language_registry.clone()),
1828 None,
1829 cx,
1830 )
1831 })),
1832 serde_json::Value::String(value) => Some(cx.new(|cx| {
1833 Markdown::new(
1834 value.clone().into(),
1835 Some(language_registry.clone()),
1836 None,
1837 cx,
1838 )
1839 })),
1840 value => Some(cx.new(|cx| {
1841 Markdown::new(
1842 format!("```json\n{}\n```", value).into(),
1843 Some(language_registry.clone()),
1844 None,
1845 cx,
1846 )
1847 })),
1848 }
1849}
1850
1851#[cfg(test)]
1852mod tests {
1853 use super::*;
1854 use anyhow::anyhow;
1855 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1856 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1857 use indoc::indoc;
1858 use project::{FakeFs, Fs};
1859 use rand::Rng as _;
1860 use serde_json::json;
1861 use settings::SettingsStore;
1862 use smol::stream::StreamExt as _;
1863 use std::{
1864 any::Any,
1865 cell::RefCell,
1866 path::Path,
1867 rc::Rc,
1868 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1869 time::Duration,
1870 };
1871 use util::path;
1872
1873 fn init_test(cx: &mut TestAppContext) {
1874 env_logger::try_init().ok();
1875 cx.update(|cx| {
1876 let settings_store = SettingsStore::test(cx);
1877 cx.set_global(settings_store);
1878 Project::init_settings(cx);
1879 language::init(cx);
1880 });
1881 }
1882
    /// Consecutive user content blocks are merged into a single user message, which adopts
    /// the prompt id supplied with the later block. A user block that follows an assistant
    /// message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.prompt_id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = new_prompt_id();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.prompt_id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = new_prompt_id();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Expect: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.prompt_id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1974
    /// Consecutive agent thought chunks are concatenated into a single `<thinking>` section
    /// of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2037
    /// An agent write based on an earlier read must not clobber an edit the user made to the
    /// buffer in between. Both changes end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test once the agent has read the file, so the user edit lands between
        // the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2116
    /// A tool call that was canceled locally can still be moved to `Completed` by a later
    /// update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the local cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2216
    /// A tool call that arrives already `Completed` must not leave the thread reporting
    /// pending edit tool calls, even when it carries a diff.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(new_prompt_id(), vec!["Hi".into()], cx)
            })
        })
        .await
        .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2274
    /// Covers checkpoint handling across several turns:
    /// - a turn that changes the working tree marks its user message with a visible
    ///   checkpoint;
    /// - a turn that changes nothing gets no visible checkpoint;
    /// - rewinding to a message restores its checkpoint on disk and truncates the thread.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt creates a new file. The agent echoes the
        // prompt in upper case.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(new_prompt_id(), vec!["Lorem".into()], cx)
            })
        })
        .await
        .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(new_prompt_id(), vec!["ipsum".into()], cx)
            })
        })
        .await
        .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(new_prompt_id(), vec!["dolor".into()], cx)
            })
        })
        .await
        .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.prompt_id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2466
2467 #[gpui::test]
2468 async fn test_refusal(cx: &mut TestAppContext) {
2469 init_test(cx);
2470 let fs = FakeFs::new(cx.background_executor.clone());
2471 fs.insert_tree(path!("/"), json!({})).await;
2472 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2473
2474 let refuse_next = Arc::new(AtomicBool::new(false));
2475 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2476 let refuse_next = refuse_next.clone();
2477 move |request, thread, mut cx| {
2478 let refuse_next = refuse_next.clone();
2479 async move {
2480 if refuse_next.load(SeqCst) {
2481 return Ok(acp::PromptResponse {
2482 stop_reason: acp::StopReason::Refusal,
2483 });
2484 }
2485
2486 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2487 panic!("expected text content block");
2488 };
2489 thread.update(&mut cx, |thread, cx| {
2490 thread
2491 .handle_session_update(
2492 acp::SessionUpdate::AgentMessageChunk {
2493 content: content.text.to_uppercase().into(),
2494 },
2495 cx,
2496 )
2497 .unwrap();
2498 })?;
2499 Ok(acp::PromptResponse {
2500 stop_reason: acp::StopReason::EndTurn,
2501 })
2502 }
2503 .boxed_local()
2504 }
2505 }));
2506 let thread = cx
2507 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2508 .await
2509 .unwrap();
2510
2511 cx.update(|cx| {
2512 thread.update(cx, |thread, cx| {
2513 thread.send(new_prompt_id(), vec!["hello".into()], cx)
2514 })
2515 })
2516 .await
2517 .unwrap();
2518 thread.read_with(cx, |thread, cx| {
2519 assert_eq!(
2520 thread.to_markdown(cx),
2521 indoc! {"
2522 ## User
2523
2524 hello
2525
2526 ## Assistant
2527
2528 HELLO
2529
2530 "}
2531 );
2532 });
2533
2534 // Simulate refusing the second message, ensuring the conversation gets
2535 // truncated to before sending it.
2536 refuse_next.store(true, SeqCst);
2537 cx.update(|cx| {
2538 thread.update(cx, |thread, cx| {
2539 thread.send(new_prompt_id(), vec!["world".into()], cx)
2540 })
2541 })
2542 .await
2543 .unwrap();
2544 thread.read_with(cx, |thread, cx| {
2545 assert_eq!(
2546 thread.to_markdown(cx),
2547 indoc! {"
2548 ## User
2549
2550 hello
2551
2552 ## Assistant
2553
2554 HELLO
2555
2556 "}
2557 );
2558 });
2559 }
2560
2561 async fn run_until_first_tool_call(
2562 thread: &Entity<AcpThread>,
2563 cx: &mut TestAppContext,
2564 ) -> usize {
2565 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2566
2567 let subscription = cx.update(|cx| {
2568 cx.subscribe(thread, move |thread, _, cx| {
2569 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2570 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2571 return tx.try_send(ix).unwrap();
2572 }
2573 }
2574 })
2575 });
2576
2577 select! {
2578 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2579 panic!("Timeout waiting for tool call")
2580 }
2581 ix = rx.next().fuse() => {
2582 drop(subscription);
2583 ix.unwrap()
2584 }
2585 }
2586 }
2587
2588 #[derive(Clone, Default)]
2589 struct FakeAgentConnection {
2590 auth_methods: Vec<acp::AuthMethod>,
2591 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2592 on_user_message: Option<
2593 Rc<
2594 dyn Fn(
2595 acp::PromptRequest,
2596 WeakEntity<AcpThread>,
2597 AsyncApp,
2598 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2599 + 'static,
2600 >,
2601 >,
2602 }
2603
2604 impl FakeAgentConnection {
2605 fn new() -> Self {
2606 Self {
2607 auth_methods: Vec::new(),
2608 on_user_message: None,
2609 sessions: Arc::default(),
2610 }
2611 }
2612
2613 #[expect(unused)]
2614 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2615 self.auth_methods = auth_methods;
2616 self
2617 }
2618
2619 fn on_user_message(
2620 mut self,
2621 handler: impl Fn(
2622 acp::PromptRequest,
2623 WeakEntity<AcpThread>,
2624 AsyncApp,
2625 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2626 + 'static,
2627 ) -> Self {
2628 self.on_user_message.replace(Rc::new(handler));
2629 self
2630 }
2631 }
2632
2633 impl AgentConnection for FakeAgentConnection {
2634 fn auth_methods(&self) -> &[acp::AuthMethod] {
2635 &self.auth_methods
2636 }
2637
2638 fn new_thread(
2639 self: Rc<Self>,
2640 project: Entity<Project>,
2641 _cwd: &Path,
2642 cx: &mut App,
2643 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2644 let session_id = acp::SessionId(
2645 rand::thread_rng()
2646 .sample_iter(&rand::distributions::Alphanumeric)
2647 .take(7)
2648 .map(char::from)
2649 .collect::<String>()
2650 .into(),
2651 );
2652 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2653 let thread = cx.new(|cx| {
2654 AcpThread::new(
2655 "Test",
2656 self.clone(),
2657 project,
2658 action_log,
2659 session_id.clone(),
2660 watch::Receiver::constant(acp::PromptCapabilities {
2661 image: true,
2662 audio: true,
2663 embedded_context: true,
2664 }),
2665 cx,
2666 )
2667 });
2668 self.sessions.lock().insert(session_id, thread.downgrade());
2669 Task::ready(Ok(thread))
2670 }
2671
2672 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2673 if self.auth_methods().iter().any(|m| m.id == method) {
2674 Task::ready(Ok(()))
2675 } else {
2676 Task::ready(Err(anyhow!("Invalid Auth Method")))
2677 }
2678 }
2679
2680 fn prompt(
2681 &self,
2682 params: acp::PromptRequest,
2683 cx: &mut App,
2684 ) -> Task<gpui::Result<acp::PromptResponse>> {
2685 let sessions = self.sessions.lock();
2686 let thread = sessions.get(¶ms.session_id).unwrap();
2687 if let Some(handler) = &self.on_user_message {
2688 let handler = handler.clone();
2689 let thread = thread.clone();
2690 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2691 } else {
2692 Task::ready(Ok(acp::PromptResponse {
2693 stop_reason: acp::StopReason::EndTurn,
2694 }))
2695 }
2696 }
2697
2698 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2699 let sessions = self.sessions.lock();
2700 let thread = sessions.get(session_id).unwrap().clone();
2701
2702 cx.spawn(async move |cx| {
2703 thread
2704 .update(cx, |thread, cx| thread.cancel(cx))
2705 .unwrap()
2706 .await
2707 })
2708 .detach();
2709 }
2710
2711 fn rewind(
2712 &self,
2713 session_id: &acp::SessionId,
2714 _cx: &App,
2715 ) -> Option<Rc<dyn AgentSessionRewind>> {
2716 Some(Rc::new(FakeAgentSessionEditor {
2717 _session_id: session_id.clone(),
2718 }))
2719 }
2720
2721 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2722 self
2723 }
2724 }
2725
2726 struct FakeAgentSessionEditor {
2727 _session_id: acp::SessionId,
2728 }
2729
2730 impl AgentSessionRewind for FakeAgentSessionEditor {
2731 fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task<Result<()>> {
2732 Task::ready(Ok(()))
2733 }
2734 }
2735}