1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
/// A user turn in the thread, accumulated from one or more ACP content chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if any. A later chunk that carries an id
    /// replaces the stored one; a chunk without an id keeps it.
    pub id: Option<UserMessageId>,
    /// Renderable content built by appending every chunk in arrival order.
    pub content: ContentBlock,
    /// The raw ACP content blocks that make up this message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Optional git checkpoint associated with this message.
    pub checkpoint: Option<Checkpoint>,
}
43
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// Checkpoint handle obtained from the project's git store.
    git_checkpoint: GitStoreCheckpoint,
    /// When true, the message's markdown export is headed `## User (checkpoint)`.
    pub show: bool,
}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
/// An agent turn made of consecutive message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order. Consecutive chunks of the same kind are merged
    /// by `AcpThread::push_assistant_content_block`.
    pub chunks: Vec<AssistantMessageChunk>,
}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
/// One piece of an [`AssistantMessage`]: visible output or the agent's reasoning.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text shown to the user.
    Message { block: ContentBlock },
    /// Reasoning ("thought") text; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
/// A single entry in an [`AcpThread`]'s history.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// Output (and possibly reasoning) produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation requested by the agent.
    ToolCall(ToolCall),
}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
/// A tool invocation requested by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier used to route updates to this call.
    pub id: acp::ToolCallId,
    /// The call's title, rendered as markdown.
    pub label: Entity<Markdown>,
    /// The kind of tool being invoked.
    pub kind: acp::ToolKind,
    /// Content items (text blocks, diffs, terminals) shown for this call.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Locations reported by the agent for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Index-aligned with `locations` once resolution completes: one entry per
    /// location, `None` where a location could not be resolved to a buffer.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input reported by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output reported by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 Self {
187 id: tool_call.id,
188 label: cx.new(|cx| {
189 Markdown::new(
190 tool_call.title.into(),
191 Some(language_registry.clone()),
192 None,
193 cx,
194 )
195 }),
196 kind: tool_call.kind,
197 content: tool_call
198 .content
199 .into_iter()
200 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
201 .collect(),
202 locations: tool_call.locations,
203 resolved_locations: Vec::default(),
204 status,
205 raw_input: tool_call.raw_input,
206 raw_output: tool_call.raw_output,
207 }
208 }
209
210 fn update_fields(
211 &mut self,
212 fields: acp::ToolCallUpdateFields,
213 language_registry: Arc<LanguageRegistry>,
214 cx: &mut App,
215 ) {
216 let acp::ToolCallUpdateFields {
217 kind,
218 status,
219 title,
220 content,
221 locations,
222 raw_input,
223 raw_output,
224 } = fields;
225
226 if let Some(kind) = kind {
227 self.kind = kind;
228 }
229
230 if let Some(status) = status {
231 self.status = status.into();
232 }
233
234 if let Some(title) = title {
235 self.label.update(cx, |label, cx| {
236 label.replace(title, cx);
237 });
238 }
239
240 if let Some(content) = content {
241 self.content = content
242 .into_iter()
243 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
244 .collect();
245 }
246
247 if let Some(locations) = locations {
248 self.locations = locations;
249 }
250
251 if let Some(raw_input) = raw_input {
252 self.raw_input = Some(raw_input);
253 }
254
255 if let Some(raw_output) = raw_output {
256 if self.content.is_empty()
257 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
258 {
259 self.content
260 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
261 markdown,
262 }));
263 }
264 self.raw_output = Some(raw_output);
265 }
266 }
267
268 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
269 self.content.iter().filter_map(|content| match content {
270 ToolCallContent::Diff(diff) => Some(diff),
271 ToolCallContent::ContentBlock(_) => None,
272 ToolCallContent::Terminal(_) => None,
273 })
274 }
275
276 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
277 self.content.iter().filter_map(|content| match content {
278 ToolCallContent::Terminal(terminal) => Some(terminal),
279 ToolCallContent::ContentBlock(_) => None,
280 ToolCallContent::Diff(_) => None,
281 })
282 }
283
284 fn to_markdown(&self, cx: &App) -> String {
285 let mut markdown = format!(
286 "**Tool Call: {}**\nStatus: {}\n\n",
287 self.label.read(cx).source(),
288 self.status
289 );
290 for content in &self.content {
291 markdown.push_str(content.to_markdown(cx).as_str());
292 markdown.push_str("\n\n");
293 }
294 markdown
295 }
296
297 async fn resolve_location(
298 location: acp::ToolCallLocation,
299 project: WeakEntity<Project>,
300 cx: &mut AsyncApp,
301 ) -> Option<AgentLocation> {
302 let buffer = project
303 .update(cx, |project, cx| {
304 project
305 .project_path_for_absolute_path(&location.path, cx)
306 .map(|path| project.open_buffer(path, cx))
307 })
308 .ok()??;
309 let buffer = buffer.await.log_err()?;
310 let position = buffer
311 .update(cx, |buffer, _| {
312 if let Some(row) = location.line {
313 let snapshot = buffer.snapshot();
314 let column = snapshot.indent_size_for_line(row).len;
315 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
316 snapshot.anchor_before(point)
317 } else {
318 Anchor::MIN
319 }
320 })
321 .ok()?;
322
323 Some(AgentLocation {
324 buffer: buffer.downgrade(),
325 position,
326 })
327 }
328
329 fn resolve_locations(
330 &self,
331 project: Entity<Project>,
332 cx: &mut App,
333 ) -> Task<Vec<Option<AgentLocation>>> {
334 let locations = self.locations.clone();
335 project.update(cx, |_, cx| {
336 cx.spawn(async move |project, cx| {
337 let mut new_locations = Vec::new();
338 for location in locations {
339 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
340 }
341 new_locations
342 })
343 })
344 }
345}
346
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Receives the id of the option the user picks (sent by `authorize_tool_call`).
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
368
369impl From<acp::ToolCallStatus> for ToolCallStatus {
370 fn from(status: acp::ToolCallStatus) -> Self {
371 match status {
372 acp::ToolCallStatus::Pending => Self::Pending,
373 acp::ToolCallStatus::InProgress => Self::InProgress,
374 acp::ToolCallStatus::Completed => Self::Completed,
375 acp::ToolCallStatus::Failed => Self::Failed,
376 }
377 }
378}
379
380impl Display for ToolCallStatus {
381 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
382 write!(
383 f,
384 "{}",
385 match self {
386 ToolCallStatus::Pending => "Pending",
387 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
388 ToolCallStatus::InProgress => "In Progress",
389 ToolCallStatus::Completed => "Completed",
390 ToolCallStatus::Failed => "Failed",
391 ToolCallStatus::Rejected => "Rejected",
392 ToolCallStatus::Canceled => "Canceled",
393 }
394 )
395 }
396}
397
/// Renderable content assembled from one or more ACP content blocks.
///
/// A block starts as `Empty`. If the first appended block is a resource link,
/// it is kept structured as `ResourceLink`. Any other content, including
/// anything appended after a link, is converted to text and accumulated in a
/// `Markdown` entity.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Accumulated markdown text.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link that has not been combined with other content.
    ResourceLink { resource_link: acp::ResourceLink },
}
404
405impl ContentBlock {
406 pub fn new(
407 block: acp::ContentBlock,
408 language_registry: &Arc<LanguageRegistry>,
409 cx: &mut App,
410 ) -> Self {
411 let mut this = Self::Empty;
412 this.append(block, language_registry, cx);
413 this
414 }
415
416 pub fn new_combined(
417 blocks: impl IntoIterator<Item = acp::ContentBlock>,
418 language_registry: Arc<LanguageRegistry>,
419 cx: &mut App,
420 ) -> Self {
421 let mut this = Self::Empty;
422 for block in blocks {
423 this.append(block, &language_registry, cx);
424 }
425 this
426 }
427
428 pub fn append(
429 &mut self,
430 block: acp::ContentBlock,
431 language_registry: &Arc<LanguageRegistry>,
432 cx: &mut App,
433 ) {
434 if matches!(self, ContentBlock::Empty)
435 && let acp::ContentBlock::ResourceLink(resource_link) = block
436 {
437 *self = ContentBlock::ResourceLink { resource_link };
438 return;
439 }
440
441 let new_content = self.block_string_contents(block);
442
443 match self {
444 ContentBlock::Empty => {
445 *self = Self::create_markdown_block(new_content, language_registry, cx);
446 }
447 ContentBlock::Markdown { markdown } => {
448 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
449 }
450 ContentBlock::ResourceLink { resource_link } => {
451 let existing_content = Self::resource_link_md(&resource_link.uri);
452 let combined = format!("{}\n{}", existing_content, new_content);
453
454 *self = Self::create_markdown_block(combined, language_registry, cx);
455 }
456 }
457 }
458
459 fn create_markdown_block(
460 content: String,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) -> ContentBlock {
464 ContentBlock::Markdown {
465 markdown: cx
466 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
467 }
468 }
469
470 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
471 match block {
472 acp::ContentBlock::Text(text_content) => text_content.text,
473 acp::ContentBlock::ResourceLink(resource_link) => {
474 Self::resource_link_md(&resource_link.uri)
475 }
476 acp::ContentBlock::Resource(acp::EmbeddedResource {
477 resource:
478 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
479 uri,
480 ..
481 }),
482 ..
483 }) => Self::resource_link_md(&uri),
484 acp::ContentBlock::Image(image) => Self::image_md(&image),
485 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
486 }
487 }
488
489 fn resource_link_md(uri: &str) -> String {
490 if let Some(uri) = MentionUri::parse(uri).log_err() {
491 uri.as_link().to_string()
492 } else {
493 uri.to_string()
494 }
495 }
496
497 fn image_md(_image: &acp::ImageContent) -> String {
498 "`Image`".into()
499 }
500
501 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
502 match self {
503 ContentBlock::Empty => "",
504 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
505 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
506 }
507 }
508
509 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
510 match self {
511 ContentBlock::Empty => None,
512 ContentBlock::Markdown { markdown } => Some(markdown),
513 ContentBlock::ResourceLink { .. } => None,
514 }
515 }
516
517 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
518 match self {
519 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
520 _ => None,
521 }
522 }
523}
524
/// One item of a [`ToolCall`]'s content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, markdown, or a resource link.
    ContentBlock(ContentBlock),
    /// A file diff produced by the tool.
    Diff(Entity<Diff>),
    /// A terminal associated with the tool call.
    Terminal(Entity<Terminal>),
}
531
532impl ToolCallContent {
533 pub fn from_acp(
534 content: acp::ToolCallContent,
535 language_registry: Arc<LanguageRegistry>,
536 cx: &mut App,
537 ) -> Self {
538 match content {
539 acp::ToolCallContent::Content { content } => {
540 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
541 }
542 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
543 Diff::finalized(
544 diff.path,
545 diff.old_text,
546 diff.new_text,
547 language_registry,
548 cx,
549 )
550 })),
551 }
552 }
553
554 pub fn to_markdown(&self, cx: &App) -> String {
555 match self {
556 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
557 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
558 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
559 }
560 }
561}
562
/// An update addressed to an existing tool call by id (see `AcpThread::update_tool_call`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Partial field update in ACP form.
    UpdateFields(acp::ToolCallUpdate),
    /// Replace the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replace the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
569
570impl ToolCallUpdate {
571 fn id(&self) -> &acp::ToolCallId {
572 match self {
573 Self::UpdateFields(update) => &update.id,
574 Self::UpdateDiff(diff) => &diff.id,
575 Self::UpdateTerminal(terminal) => &terminal.id,
576 }
577 }
578}
579
580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
581 fn from(update: acp::ToolCallUpdate) -> Self {
582 Self::UpdateFields(update)
583 }
584}
585
586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
587 fn from(diff: ToolCallUpdateDiff) -> Self {
588 Self::UpdateDiff(diff)
589 }
590}
591
/// Replaces the content of tool call `id` with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content item.
    pub diff: Entity<Diff>,
}
597
598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
599 fn from(terminal: ToolCallUpdateTerminal) -> Self {
600 Self::UpdateTerminal(terminal)
601 }
602}
603
/// Replaces the content of tool call `id` with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content item.
    pub terminal: Entity<Terminal>,
}
609
/// The agent's current plan; its entries are replaced by `AcpThread::update_plan`.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in the order the agent reported them.
    pub entries: Vec<PlanEntry>,
}
614
/// Summary of a [`Plan`], computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
621
622impl Plan {
623 pub fn is_empty(&self) -> bool {
624 self.entries.is_empty()
625 }
626
627 pub fn stats(&self) -> PlanStats<'_> {
628 let mut stats = PlanStats {
629 in_progress_entry: None,
630 pending: 0,
631 completed: 0,
632 };
633
634 for entry in &self.entries {
635 match &entry.status {
636 acp::PlanEntryStatus::Pending => {
637 stats.pending += 1;
638 }
639 acp::PlanEntryStatus::InProgress => {
640 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
641 }
642 acp::PlanEntryStatus::Completed => {
643 stats.completed += 1;
644 }
645 }
646 }
647
648 stats
649 }
650}
651
/// One step of the agent's [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description, rendered as markdown (created without a language registry).
    pub content: Entity<Markdown>,
    /// Priority reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Progress status reported by the agent.
    pub status: acp::PlanEntryStatus,
}
658
659impl PlanEntry {
660 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
661 Self {
662 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
663 priority: entry.priority,
664 status: entry.status,
665 }
666 }
667}
668
/// Token consumption of the thread relative to the model's context limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated by `ratio` as "unknown".
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
674
675impl TokenUsage {
676 pub fn ratio(&self) -> TokenUsageRatio {
677 #[cfg(debug_assertions)]
678 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
679 .unwrap_or("0.8".to_string())
680 .parse()
681 .unwrap();
682 #[cfg(not(debug_assertions))]
683 let warning_threshold: f32 = 0.8;
684
685 // When the maximum is unknown because there is no selected model,
686 // avoid showing the token limit warning.
687 if self.max_tokens == 0 {
688 TokenUsageRatio::Normal
689 } else if self.used_tokens >= self.max_tokens {
690 TokenUsageRatio::Exceeded
691 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
692 TokenUsageRatio::Warning
693 } else {
694 TokenUsageRatio::Normal
695 }
696 }
697}
698
/// Coarse classification of [`TokenUsage`] returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
705
/// Details about a retried request, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number. NOTE(review): whether this is 0- or 1-based is set by the caller — confirm.
    pub attempt: usize,
    /// Maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// When the retry was started.
    pub started_at: Instant,
    /// NOTE(review): presumably the delay before the next attempt — TODO confirm with the caller.
    pub duration: Duration,
}
714
/// A conversation with an agent over the Agent Client Protocol (ACP).
///
/// Holds the ordered thread entries (user messages, assistant messages and tool
/// calls), the agent's current plan, and token usage. Emits [`AcpThreadEvent`]s
/// as they change.
pub struct AcpThread {
    /// Display title; changed via `update_title`.
    title: SharedString,
    /// Thread entries in chronological order.
    entries: Vec<AgentThreadEntry>,
    /// Latest plan reported by the agent.
    plan: Plan,
    /// Project used to resolve tool call paths, open buffers and set the agent location.
    project: Entity<Project>,
    /// NOTE(review): only exposed via `action_log()` in the visible code; presumably records agent edits.
    action_log: Entity<ActionLog>,
    /// NOTE(review): not used in the visible code; presumably buffer snapshots shared with the agent — TODO confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// ACP session identifier for this thread.
    session_id: acp::SessionId,
    /// Most recently reported token usage, if any.
    token_usage: Option<TokenUsage>,
}
727
/// Events emitted by [`AcpThread`] to its observers.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The token usage changed.
    TokenUsageUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    /// Entries in this index range were removed. NOTE(review): emitted outside the visible code.
    EntriesRemoved(Range<usize>),
    /// A tool call is now waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    /// NOTE(review): emitted outside the visible code; presumably when generation stops.
    Stopped,
    /// NOTE(review): emitted outside the visible code; presumably when generation fails.
    Error,
    /// The agent failed to load.
    LoadError(LoadError),
}
741
// Lets `AcpThread` emit `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
743
/// Coarse thread state, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and the latest turn has a tool call awaiting permission.
    WaitingForToolConfirmation,
    /// A send task is running and nothing is awaiting permission.
    Generating,
}
750
/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent is not installed; `error_message` is what `Display` prints.
    /// NOTE(review): `install_*` presumably drive an install prompt in the UI — confirm.
    NotInstalled {
        error_message: SharedString,
        install_message: SharedString,
        install_command: String,
    },
    /// The installed agent is unsupported; `error_message` is what `Display` prints.
    /// NOTE(review): `upgrade_*` presumably drive an upgrade prompt in the UI — confirm.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The agent server process exited with this status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by the message.
    Other(SharedString),
}
768
769impl Display for LoadError {
770 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
771 match self {
772 LoadError::NotInstalled { error_message, .. }
773 | LoadError::Unsupported { error_message, .. } => {
774 write!(f, "{error_message}")
775 }
776 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
777 LoadError::Other(msg) => write!(f, "{}", msg),
778 }
779 }
780}
781
// Lets `LoadError` be used as a standard error (message comes from `Display`).
impl Error for LoadError {}
783
784impl AcpThread {
785 pub fn new(
786 title: impl Into<SharedString>,
787 connection: Rc<dyn AgentConnection>,
788 project: Entity<Project>,
789 action_log: Entity<ActionLog>,
790 session_id: acp::SessionId,
791 ) -> Self {
792 Self {
793 action_log,
794 shared_buffers: Default::default(),
795 entries: Default::default(),
796 plan: Default::default(),
797 title: title.into(),
798 project,
799 send_task: None,
800 connection,
801 session_id,
802 token_usage: None,
803 }
804 }
805
806 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
807 &self.connection
808 }
809
810 pub fn action_log(&self) -> &Entity<ActionLog> {
811 &self.action_log
812 }
813
814 pub fn project(&self) -> &Entity<Project> {
815 &self.project
816 }
817
818 pub fn title(&self) -> SharedString {
819 self.title.clone()
820 }
821
822 pub fn entries(&self) -> &[AgentThreadEntry] {
823 &self.entries
824 }
825
826 pub fn session_id(&self) -> &acp::SessionId {
827 &self.session_id
828 }
829
830 pub fn status(&self) -> ThreadStatus {
831 if self.send_task.is_some() {
832 if self.waiting_for_tool_confirmation() {
833 ThreadStatus::WaitingForToolConfirmation
834 } else {
835 ThreadStatus::Generating
836 }
837 } else {
838 ThreadStatus::Idle
839 }
840 }
841
842 pub fn token_usage(&self) -> Option<&TokenUsage> {
843 self.token_usage.as_ref()
844 }
845
846 pub fn has_pending_edit_tool_calls(&self) -> bool {
847 for entry in self.entries.iter().rev() {
848 match entry {
849 AgentThreadEntry::UserMessage(_) => return false,
850 AgentThreadEntry::ToolCall(
851 call @ ToolCall {
852 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
853 ..
854 },
855 ) if call.diffs().next().is_some() => {
856 return true;
857 }
858 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
859 }
860 }
861
862 false
863 }
864
865 pub fn used_tools_since_last_user_message(&self) -> bool {
866 for entry in self.entries.iter().rev() {
867 match entry {
868 AgentThreadEntry::UserMessage(..) => return false,
869 AgentThreadEntry::AssistantMessage(..) => continue,
870 AgentThreadEntry::ToolCall(..) => return true,
871 }
872 }
873
874 false
875 }
876
877 pub fn handle_session_update(
878 &mut self,
879 update: acp::SessionUpdate,
880 cx: &mut Context<Self>,
881 ) -> Result<(), acp::Error> {
882 match update {
883 acp::SessionUpdate::UserMessageChunk { content } => {
884 self.push_user_content_block(None, content, cx);
885 }
886 acp::SessionUpdate::AgentMessageChunk { content } => {
887 self.push_assistant_content_block(content, false, cx);
888 }
889 acp::SessionUpdate::AgentThoughtChunk { content } => {
890 self.push_assistant_content_block(content, true, cx);
891 }
892 acp::SessionUpdate::ToolCall(tool_call) => {
893 self.upsert_tool_call(tool_call, cx)?;
894 }
895 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
896 self.update_tool_call(tool_call_update, cx)?;
897 }
898 acp::SessionUpdate::Plan(plan) => {
899 self.update_plan(plan, cx);
900 }
901 }
902 Ok(())
903 }
904
905 pub fn push_user_content_block(
906 &mut self,
907 message_id: Option<UserMessageId>,
908 chunk: acp::ContentBlock,
909 cx: &mut Context<Self>,
910 ) {
911 let language_registry = self.project.read(cx).languages().clone();
912 let entries_len = self.entries.len();
913
914 if let Some(last_entry) = self.entries.last_mut()
915 && let AgentThreadEntry::UserMessage(UserMessage {
916 id,
917 content,
918 chunks,
919 ..
920 }) = last_entry
921 {
922 *id = message_id.or(id.take());
923 content.append(chunk.clone(), &language_registry, cx);
924 chunks.push(chunk);
925 let idx = entries_len - 1;
926 cx.emit(AcpThreadEvent::EntryUpdated(idx));
927 } else {
928 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
929 self.push_entry(
930 AgentThreadEntry::UserMessage(UserMessage {
931 id: message_id,
932 content,
933 chunks: vec![chunk],
934 checkpoint: None,
935 }),
936 cx,
937 );
938 }
939 }
940
941 pub fn push_assistant_content_block(
942 &mut self,
943 chunk: acp::ContentBlock,
944 is_thought: bool,
945 cx: &mut Context<Self>,
946 ) {
947 let language_registry = self.project.read(cx).languages().clone();
948 let entries_len = self.entries.len();
949 if let Some(last_entry) = self.entries.last_mut()
950 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
951 {
952 let idx = entries_len - 1;
953 cx.emit(AcpThreadEvent::EntryUpdated(idx));
954 match (chunks.last_mut(), is_thought) {
955 (Some(AssistantMessageChunk::Message { block }), false)
956 | (Some(AssistantMessageChunk::Thought { block }), true) => {
957 block.append(chunk, &language_registry, cx)
958 }
959 _ => {
960 let block = ContentBlock::new(chunk, &language_registry, cx);
961 if is_thought {
962 chunks.push(AssistantMessageChunk::Thought { block })
963 } else {
964 chunks.push(AssistantMessageChunk::Message { block })
965 }
966 }
967 }
968 } else {
969 let block = ContentBlock::new(chunk, &language_registry, cx);
970 let chunk = if is_thought {
971 AssistantMessageChunk::Thought { block }
972 } else {
973 AssistantMessageChunk::Message { block }
974 };
975
976 self.push_entry(
977 AgentThreadEntry::AssistantMessage(AssistantMessage {
978 chunks: vec![chunk],
979 }),
980 cx,
981 );
982 }
983 }
984
985 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
986 self.entries.push(entry);
987 cx.emit(AcpThreadEvent::NewEntry);
988 }
989
990 pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
991 self.title = title;
992 cx.emit(AcpThreadEvent::TitleUpdated);
993 Ok(())
994 }
995
996 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
997 self.token_usage = usage;
998 cx.emit(AcpThreadEvent::TokenUsageUpdated);
999 }
1000
1001 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1002 cx.emit(AcpThreadEvent::Retry(status));
1003 }
1004
1005 pub fn update_tool_call(
1006 &mut self,
1007 update: impl Into<ToolCallUpdate>,
1008 cx: &mut Context<Self>,
1009 ) -> Result<()> {
1010 let update = update.into();
1011 let languages = self.project.read(cx).languages().clone();
1012
1013 let (ix, current_call) = self
1014 .tool_call_mut(update.id())
1015 .context("Tool call not found")?;
1016 match update {
1017 ToolCallUpdate::UpdateFields(update) => {
1018 let location_updated = update.fields.locations.is_some();
1019 current_call.update_fields(update.fields, languages, cx);
1020 if location_updated {
1021 self.resolve_locations(update.id, cx);
1022 }
1023 }
1024 ToolCallUpdate::UpdateDiff(update) => {
1025 current_call.content.clear();
1026 current_call
1027 .content
1028 .push(ToolCallContent::Diff(update.diff));
1029 }
1030 ToolCallUpdate::UpdateTerminal(update) => {
1031 current_call.content.clear();
1032 current_call
1033 .content
1034 .push(ToolCallContent::Terminal(update.terminal));
1035 }
1036 }
1037
1038 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1039
1040 Ok(())
1041 }
1042
1043 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1044 pub fn upsert_tool_call(
1045 &mut self,
1046 tool_call: acp::ToolCall,
1047 cx: &mut Context<Self>,
1048 ) -> Result<(), acp::Error> {
1049 let status = tool_call.status.into();
1050 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1051 }
1052
1053 /// Fails if id does not match an existing entry.
1054 pub fn upsert_tool_call_inner(
1055 &mut self,
1056 tool_call_update: acp::ToolCallUpdate,
1057 status: ToolCallStatus,
1058 cx: &mut Context<Self>,
1059 ) -> Result<(), acp::Error> {
1060 let language_registry = self.project.read(cx).languages().clone();
1061 let id = tool_call_update.id.clone();
1062
1063 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1064 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1065 current_call.status = status;
1066
1067 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1068 } else {
1069 let call =
1070 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1071 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1072 };
1073
1074 self.resolve_locations(id, cx);
1075 Ok(())
1076 }
1077
1078 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1079 // The tool call we are looking for is typically the last one, or very close to the end.
1080 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1081 self.entries
1082 .iter_mut()
1083 .enumerate()
1084 .rev()
1085 .find_map(|(index, tool_call)| {
1086 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1087 && &tool_call.id == id
1088 {
1089 Some((index, tool_call))
1090 } else {
1091 None
1092 }
1093 })
1094 }
1095
1096 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1097 self.entries
1098 .iter()
1099 .enumerate()
1100 .rev()
1101 .find_map(|(index, tool_call)| {
1102 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1103 && &tool_call.id == id
1104 {
1105 Some((index, tool_call))
1106 } else {
1107 None
1108 }
1109 })
1110 }
1111
1112 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1113 let project = self.project.clone();
1114 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1115 return;
1116 };
1117 let task = tool_call.resolve_locations(project, cx);
1118 cx.spawn(async move |this, cx| {
1119 let resolved_locations = task.await;
1120 this.update(cx, |this, cx| {
1121 let project = this.project.clone();
1122 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1123 return;
1124 };
1125 if let Some(Some(location)) = resolved_locations.last() {
1126 project.update(cx, |project, cx| {
1127 if let Some(agent_location) = project.agent_location() {
1128 let should_ignore = agent_location.buffer == location.buffer
1129 && location
1130 .buffer
1131 .update(cx, |buffer, _| {
1132 let snapshot = buffer.snapshot();
1133 let old_position =
1134 agent_location.position.to_point(&snapshot);
1135 let new_position = location.position.to_point(&snapshot);
1136 // ignore this so that when we get updates from the edit tool
1137 // the position doesn't reset to the startof line
1138 old_position.row == new_position.row
1139 && old_position.column > new_position.column
1140 })
1141 .ok()
1142 .unwrap_or_default();
1143 if !should_ignore {
1144 project.set_agent_location(Some(location.clone()), cx);
1145 }
1146 }
1147 });
1148 }
1149 if tool_call.resolved_locations != resolved_locations {
1150 tool_call.resolved_locations = resolved_locations;
1151 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1152 }
1153 })
1154 })
1155 .detach();
1156 }
1157
1158 pub fn request_tool_call_authorization(
1159 &mut self,
1160 tool_call: acp::ToolCallUpdate,
1161 options: Vec<acp::PermissionOption>,
1162 cx: &mut Context<Self>,
1163 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1164 let (tx, rx) = oneshot::channel();
1165
1166 let status = ToolCallStatus::WaitingForConfirmation {
1167 options,
1168 respond_tx: tx,
1169 };
1170
1171 self.upsert_tool_call_inner(tool_call, status, cx)?;
1172 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1173 Ok(rx)
1174 }
1175
1176 pub fn authorize_tool_call(
1177 &mut self,
1178 id: acp::ToolCallId,
1179 option_id: acp::PermissionOptionId,
1180 option_kind: acp::PermissionOptionKind,
1181 cx: &mut Context<Self>,
1182 ) {
1183 let Some((ix, call)) = self.tool_call_mut(&id) else {
1184 return;
1185 };
1186
1187 let new_status = match option_kind {
1188 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1189 ToolCallStatus::Rejected
1190 }
1191 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1192 ToolCallStatus::InProgress
1193 }
1194 };
1195
1196 let curr_status = mem::replace(&mut call.status, new_status);
1197
1198 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1199 respond_tx.send(option_id).log_err();
1200 } else if cfg!(debug_assertions) {
1201 panic!("tried to authorize an already authorized tool call");
1202 }
1203
1204 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1205 }
1206
1207 /// Returns true if the last turn is awaiting tool authorization
1208 pub fn waiting_for_tool_confirmation(&self) -> bool {
1209 for entry in self.entries.iter().rev() {
1210 match &entry {
1211 AgentThreadEntry::ToolCall(call) => match call.status {
1212 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1213 ToolCallStatus::Pending
1214 | ToolCallStatus::InProgress
1215 | ToolCallStatus::Completed
1216 | ToolCallStatus::Failed
1217 | ToolCallStatus::Rejected
1218 | ToolCallStatus::Canceled => continue,
1219 },
1220 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1221 // Reached the beginning of the turn
1222 return false;
1223 }
1224 }
1225 }
1226 false
1227 }
1228
    /// Returns the plan most recently reported by the agent.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1232
1233 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1234 let new_entries_len = request.entries.len();
1235 let mut new_entries = request.entries.into_iter();
1236
1237 // Reuse existing markdown to prevent flickering
1238 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1239 let PlanEntry {
1240 content,
1241 priority,
1242 status,
1243 } = old;
1244 content.update(cx, |old, cx| {
1245 old.replace(new.content, cx);
1246 });
1247 *priority = new.priority;
1248 *status = new.status;
1249 }
1250 for new in new_entries {
1251 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1252 }
1253 self.plan.entries.truncate(new_entries_len);
1254
1255 cx.notify();
1256 }
1257
1258 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1259 self.plan
1260 .entries
1261 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1262 cx.notify();
1263 }
1264
1265 #[cfg(any(test, feature = "test-support"))]
1266 pub fn send_raw(
1267 &mut self,
1268 message: &str,
1269 cx: &mut Context<Self>,
1270 ) -> BoxFuture<'static, Result<()>> {
1271 self.send(
1272 vec![acp::ContentBlock::Text(acp::TextContent {
1273 text: message.to_string(),
1274 annotations: None,
1275 })],
1276 cx,
1277 )
1278 }
1279
1280 pub fn send(
1281 &mut self,
1282 message: Vec<acp::ContentBlock>,
1283 cx: &mut Context<Self>,
1284 ) -> BoxFuture<'static, Result<()>> {
1285 let block = ContentBlock::new_combined(
1286 message.clone(),
1287 self.project.read(cx).languages().clone(),
1288 cx,
1289 );
1290 let request = acp::PromptRequest {
1291 prompt: message.clone(),
1292 session_id: self.session_id.clone(),
1293 };
1294 let git_store = self.project.read(cx).git_store().clone();
1295
1296 let message_id = if self
1297 .connection
1298 .session_editor(&self.session_id, cx)
1299 .is_some()
1300 {
1301 Some(UserMessageId::new())
1302 } else {
1303 None
1304 };
1305
1306 self.run_turn(cx, async move |this, cx| {
1307 this.update(cx, |this, cx| {
1308 this.push_entry(
1309 AgentThreadEntry::UserMessage(UserMessage {
1310 id: message_id.clone(),
1311 content: block,
1312 chunks: message,
1313 checkpoint: None,
1314 }),
1315 cx,
1316 );
1317 })
1318 .ok();
1319
1320 let old_checkpoint = git_store
1321 .update(cx, |git, cx| git.checkpoint(cx))?
1322 .await
1323 .context("failed to get old checkpoint")
1324 .log_err();
1325 this.update(cx, |this, cx| {
1326 if let Some((_ix, message)) = this.last_user_message() {
1327 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1328 git_checkpoint,
1329 show: false,
1330 });
1331 }
1332 this.connection.prompt(message_id, request, cx)
1333 })?
1334 .await
1335 })
1336 }
1337
1338 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1339 self.run_turn(cx, async move |this, cx| {
1340 this.update(cx, |this, cx| {
1341 this.connection
1342 .resume(&this.session_id, cx)
1343 .map(|resume| resume.run(cx))
1344 })?
1345 .context("resuming a session is not supported")?
1346 .await
1347 })
1348 }
1349
    /// Runs one agent turn driven by `f` and returns a future that resolves when
    /// the turn has finished and the thread state has been updated.
    ///
    /// Before `f` runs, completed plan entries are cleared and any in-flight turn
    /// is canceled; the new `send_task` waits for that cancellation to finish.
    /// Once `f` produces a response, the last user message's checkpoint is
    /// refreshed and the agent location is cleared. Then:
    /// - `Ok(Err(e))` from `f`: `send_task` is dropped, `Error` is emitted and
    ///   `e` is returned.
    /// - otherwise: `send_task` is dropped unless the turn was canceled, a
    ///   `Refusal` removes the last user message and everything after it, and
    ///   `Stopped` is emitted.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        // Cancels the previous turn (if any); awaited before `f` starts.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(_)` here means the sender was dropped without sending; that case
            // falls into the non-error arm below.
            let response = rx.await;

            // NOTE(review): an error here returns early, skipping the agent
            // location reset and the `Stopped`/`Error` events below.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1416
1417 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1418 let Some(send_task) = self.send_task.take() else {
1419 return Task::ready(());
1420 };
1421
1422 for entry in self.entries.iter_mut() {
1423 if let AgentThreadEntry::ToolCall(call) = entry {
1424 let cancel = matches!(
1425 call.status,
1426 ToolCallStatus::Pending
1427 | ToolCallStatus::WaitingForConfirmation { .. }
1428 | ToolCallStatus::InProgress
1429 );
1430
1431 if cancel {
1432 call.status = ToolCallStatus::Canceled;
1433 }
1434 }
1435 }
1436
1437 self.connection.cancel(&self.session_id, cx);
1438
1439 // Wait for the send task to complete
1440 cx.foreground_executor().spawn(send_task)
1441 }
1442
1443 /// Rewinds this thread to before the entry at `index`, removing it and all
1444 /// subsequent entries while reverting any changes made from that point.
1445 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1446 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1447 return Task::ready(Err(anyhow!("not supported")));
1448 };
1449 let Some(message) = self.user_message(&id) else {
1450 return Task::ready(Err(anyhow!("message not found")));
1451 };
1452
1453 let checkpoint = message
1454 .checkpoint
1455 .as_ref()
1456 .map(|c| c.git_checkpoint.clone());
1457
1458 let git_store = self.project.read(cx).git_store().clone();
1459 cx.spawn(async move |this, cx| {
1460 if let Some(checkpoint) = checkpoint {
1461 git_store
1462 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1463 .await?;
1464 }
1465
1466 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1467 .await?;
1468 this.update(cx, |this, cx| {
1469 if let Some((ix, _)) = this.user_message_mut(&id) {
1470 let range = ix..this.entries.len();
1471 this.entries.truncate(ix);
1472 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1473 }
1474 })
1475 })
1476 }
1477
1478 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1479 let git_store = self.project.read(cx).git_store().clone();
1480
1481 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1482 if let Some(checkpoint) = message.checkpoint.as_ref() {
1483 checkpoint.git_checkpoint.clone()
1484 } else {
1485 return Task::ready(Ok(()));
1486 }
1487 } else {
1488 return Task::ready(Ok(()));
1489 };
1490
1491 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1492 cx.spawn(async move |this, cx| {
1493 let new_checkpoint = new_checkpoint
1494 .await
1495 .context("failed to get new checkpoint")
1496 .log_err();
1497 if let Some(new_checkpoint) = new_checkpoint {
1498 let equal = git_store
1499 .update(cx, |git, cx| {
1500 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1501 })?
1502 .await
1503 .unwrap_or(true);
1504 this.update(cx, |this, cx| {
1505 let (ix, message) = this.last_user_message().context("no user message")?;
1506 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1507 checkpoint.show = !equal;
1508 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1509 anyhow::Ok(())
1510 })??;
1511 }
1512
1513 Ok(())
1514 })
1515 }
1516
1517 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1518 self.entries
1519 .iter_mut()
1520 .enumerate()
1521 .rev()
1522 .find_map(|(ix, entry)| {
1523 if let AgentThreadEntry::UserMessage(message) = entry {
1524 Some((ix, message))
1525 } else {
1526 None
1527 }
1528 })
1529 }
1530
1531 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1532 self.entries.iter().find_map(|entry| {
1533 if let AgentThreadEntry::UserMessage(message) = entry {
1534 if message.id.as_ref() == Some(id) {
1535 Some(message)
1536 } else {
1537 None
1538 }
1539 } else {
1540 None
1541 }
1542 })
1543 }
1544
1545 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1546 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1547 if let AgentThreadEntry::UserMessage(message) = entry {
1548 if message.id.as_ref() == Some(id) {
1549 Some((ix, message))
1550 } else {
1551 None
1552 }
1553 } else {
1554 None
1555 }
1556 })
1557 }
1558
1559 pub fn read_text_file(
1560 &self,
1561 path: PathBuf,
1562 line: Option<u32>,
1563 limit: Option<u32>,
1564 reuse_shared_snapshot: bool,
1565 cx: &mut Context<Self>,
1566 ) -> Task<Result<String>> {
1567 let project = self.project.clone();
1568 let action_log = self.action_log.clone();
1569 cx.spawn(async move |this, cx| {
1570 let load = project.update(cx, |project, cx| {
1571 let path = project
1572 .project_path_for_absolute_path(&path, cx)
1573 .context("invalid path")?;
1574 anyhow::Ok(project.open_buffer(path, cx))
1575 });
1576 let buffer = load??.await?;
1577
1578 let snapshot = if reuse_shared_snapshot {
1579 this.read_with(cx, |this, _| {
1580 this.shared_buffers.get(&buffer.clone()).cloned()
1581 })
1582 .log_err()
1583 .flatten()
1584 } else {
1585 None
1586 };
1587
1588 let snapshot = if let Some(snapshot) = snapshot {
1589 snapshot
1590 } else {
1591 action_log.update(cx, |action_log, cx| {
1592 action_log.buffer_read(buffer.clone(), cx);
1593 })?;
1594 project.update(cx, |project, cx| {
1595 let position = buffer
1596 .read(cx)
1597 .snapshot()
1598 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1599 project.set_agent_location(
1600 Some(AgentLocation {
1601 buffer: buffer.downgrade(),
1602 position,
1603 }),
1604 cx,
1605 );
1606 })?;
1607
1608 buffer.update(cx, |buffer, _| buffer.snapshot())?
1609 };
1610
1611 this.update(cx, |this, _| {
1612 let text = snapshot.text();
1613 this.shared_buffers.insert(buffer.clone(), snapshot);
1614 if line.is_none() && limit.is_none() {
1615 return Ok(text);
1616 }
1617 let limit = limit.unwrap_or(u32::MAX) as usize;
1618 let Some(line) = line else {
1619 return Ok(text.lines().take(limit).collect::<String>());
1620 };
1621
1622 let count = text.lines().count();
1623 if count < line as usize {
1624 anyhow::bail!("There are only {} lines", count);
1625 }
1626 Ok(text
1627 .lines()
1628 .skip(line as usize + 1)
1629 .take(limit)
1630 .collect::<String>())
1631 })?
1632 })
1633 }
1634
    /// Writes `content` to the file at the absolute `path` on behalf of the agent,
    /// then saves it.
    ///
    /// The buffer is not overwritten wholesale. Instead, `content` is diffed (on the
    /// background executor) against the snapshot stored in `shared_buffers` for
    /// this buffer, falling back to the buffer's current snapshot. The resulting
    /// edits are anchored to that snapshot, so text the user changed in the buffer
    /// since that snapshot is kept.
    ///
    /// The agent location is moved to the end of the last edit (or `Anchor::MIN`
    /// when there are no edits). The read and the edit are recorded in the action
    /// log. The buffer is formatted when its `format_on_save` setting is not `Off`,
    /// and is finally saved.
    ///
    /// # Errors
    ///
    /// Fails if `path` does not belong to the project, the buffer cannot be
    /// opened, an entity has been dropped, or saving fails. Formatting errors are
    /// only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the snapshot previously shared with the agent, if any,
            // so concurrent user edits are not treated as text to replace.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged, not returned.
                format_task.await.log_err();

                // Record the formatter's changes in the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1731
1732 pub fn to_markdown(&self, cx: &App) -> String {
1733 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1734 }
1735
1736 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1737 cx.emit(AcpThreadEvent::LoadError(error));
1738 }
1739}
1740
1741fn markdown_for_raw_output(
1742 raw_output: &serde_json::Value,
1743 language_registry: &Arc<LanguageRegistry>,
1744 cx: &mut App,
1745) -> Option<Entity<Markdown>> {
1746 match raw_output {
1747 serde_json::Value::Null => None,
1748 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1749 Markdown::new(
1750 value.to_string().into(),
1751 Some(language_registry.clone()),
1752 None,
1753 cx,
1754 )
1755 })),
1756 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1757 Markdown::new(
1758 value.to_string().into(),
1759 Some(language_registry.clone()),
1760 None,
1761 cx,
1762 )
1763 })),
1764 serde_json::Value::String(value) => Some(cx.new(|cx| {
1765 Markdown::new(
1766 value.clone().into(),
1767 Some(language_registry.clone()),
1768 None,
1769 cx,
1770 )
1771 })),
1772 value => Some(cx.new(|cx| {
1773 Markdown::new(
1774 format!("```json\n{}\n```", value).into(),
1775 Some(language_registry.clone()),
1776 None,
1777 cx,
1778 )
1779 })),
1780 }
1781}
1782
1783#[cfg(test)]
1784mod tests {
1785 use super::*;
1786 use anyhow::anyhow;
1787 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1788 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1789 use indoc::indoc;
1790 use project::{FakeFs, Fs};
1791 use rand::Rng as _;
1792 use serde_json::json;
1793 use settings::SettingsStore;
1794 use smol::stream::StreamExt as _;
1795 use std::{
1796 any::Any,
1797 cell::RefCell,
1798 path::Path,
1799 rc::Rc,
1800 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1801 time::Duration,
1802 };
1803 use util::path;
1804
1805 fn init_test(cx: &mut TestAppContext) {
1806 env_logger::try_init().ok();
1807 cx.update(|cx| {
1808 let settings_store = SettingsStore::test(cx);
1809 cx.set_global(settings_store);
1810 Project::init_settings(cx);
1811 language::init(cx);
1812 });
1813 }
1814
    /// Checks how `push_user_content_block` groups chunks:
    /// - a chunk that follows another user chunk is appended to the same user
    ///   message, which also takes on the id supplied with the later chunk;
    /// - a chunk that follows an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            // user message, assistant message, new user message
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1906
    /// Two consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` section of one assistant message in the Markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1969
    /// The agent reads `/tmp/foo`, the user then inserts a line at the top of the
    /// open buffer, and the agent writes back an extended version of the text it
    /// read. Both the user's insertion and the agent's appended lines must survive,
    /// in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent's read has finished, so the user
        // edit lands between the read and the write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulate the user editing the buffer after the agent has read it.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2048
    /// A tool call that is still in progress when the turn is canceled is marked
    /// `Canceled`. A later update from the agent that reports it as `Completed`
    /// still moves it to `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2148
    /// An edit tool call that the agent reports as already `Completed` (with a
    /// diff) does not count as a pending edit once the turn is over.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2202
    /// User messages are rendered with "(checkpoint)" only when their turn changed
    /// files in the repository. Rewinding to a message truncates the history from
    /// that message onward and restores the file system to the message's
    /// checkpoint.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each turn creates a new file
        // (`file-0`, `file-1`, ...) so that the turn changes the repository.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in upper case as the assistant reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2382
    /// When the agent answers a prompt with `StopReason::Refusal`, the refused user
    /// message is removed, leaving the conversation as it was before that send.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While `refuse_next` is set, the fake agent refuses instead of echoing.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message, ensuring the conversation gets
        // truncated to before sending it.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2468
2469 async fn run_until_first_tool_call(
2470 thread: &Entity<AcpThread>,
2471 cx: &mut TestAppContext,
2472 ) -> usize {
2473 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2474
2475 let subscription = cx.update(|cx| {
2476 cx.subscribe(thread, move |thread, _, cx| {
2477 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2478 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2479 return tx.try_send(ix).unwrap();
2480 }
2481 }
2482 })
2483 });
2484
2485 select! {
2486 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2487 panic!("Timeout waiting for tool call")
2488 }
2489 ix = rx.next().fuse() => {
2490 drop(subscription);
2491 ix.unwrap()
2492 }
2493 }
2494 }
2495
2496 #[derive(Clone, Default)]
2497 struct FakeAgentConnection {
2498 auth_methods: Vec<acp::AuthMethod>,
2499 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2500 on_user_message: Option<
2501 Rc<
2502 dyn Fn(
2503 acp::PromptRequest,
2504 WeakEntity<AcpThread>,
2505 AsyncApp,
2506 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2507 + 'static,
2508 >,
2509 >,
2510 }
2511
2512 impl FakeAgentConnection {
2513 fn new() -> Self {
2514 Self {
2515 auth_methods: Vec::new(),
2516 on_user_message: None,
2517 sessions: Arc::default(),
2518 }
2519 }
2520
2521 #[expect(unused)]
2522 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2523 self.auth_methods = auth_methods;
2524 self
2525 }
2526
2527 fn on_user_message(
2528 mut self,
2529 handler: impl Fn(
2530 acp::PromptRequest,
2531 WeakEntity<AcpThread>,
2532 AsyncApp,
2533 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2534 + 'static,
2535 ) -> Self {
2536 self.on_user_message.replace(Rc::new(handler));
2537 self
2538 }
2539 }
2540
2541 impl AgentConnection for FakeAgentConnection {
2542 fn auth_methods(&self) -> &[acp::AuthMethod] {
2543 &self.auth_methods
2544 }
2545
2546 fn new_thread(
2547 self: Rc<Self>,
2548 project: Entity<Project>,
2549 _cwd: &Path,
2550 cx: &mut App,
2551 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2552 let session_id = acp::SessionId(
2553 rand::thread_rng()
2554 .sample_iter(&rand::distributions::Alphanumeric)
2555 .take(7)
2556 .map(char::from)
2557 .collect::<String>()
2558 .into(),
2559 );
2560 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2561 let thread = cx.new(|_cx| {
2562 AcpThread::new(
2563 "Test",
2564 self.clone(),
2565 project,
2566 action_log,
2567 session_id.clone(),
2568 )
2569 });
2570 self.sessions.lock().insert(session_id, thread.downgrade());
2571 Task::ready(Ok(thread))
2572 }
2573
2574 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2575 if self.auth_methods().iter().any(|m| m.id == method) {
2576 Task::ready(Ok(()))
2577 } else {
2578 Task::ready(Err(anyhow!("Invalid Auth Method")))
2579 }
2580 }
2581
2582 fn prompt(
2583 &self,
2584 _id: Option<UserMessageId>,
2585 params: acp::PromptRequest,
2586 cx: &mut App,
2587 ) -> Task<gpui::Result<acp::PromptResponse>> {
2588 let sessions = self.sessions.lock();
2589 let thread = sessions.get(¶ms.session_id).unwrap();
2590 if let Some(handler) = &self.on_user_message {
2591 let handler = handler.clone();
2592 let thread = thread.clone();
2593 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2594 } else {
2595 Task::ready(Ok(acp::PromptResponse {
2596 stop_reason: acp::StopReason::EndTurn,
2597 }))
2598 }
2599 }
2600
2601 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
2602 acp::PromptCapabilities {
2603 image: true,
2604 audio: true,
2605 embedded_context: true,
2606 }
2607 }
2608
2609 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2610 let sessions = self.sessions.lock();
2611 let thread = sessions.get(session_id).unwrap().clone();
2612
2613 cx.spawn(async move |cx| {
2614 thread
2615 .update(cx, |thread, cx| thread.cancel(cx))
2616 .unwrap()
2617 .await
2618 })
2619 .detach();
2620 }
2621
2622 fn session_editor(
2623 &self,
2624 session_id: &acp::SessionId,
2625 _cx: &mut App,
2626 ) -> Option<Rc<dyn AgentSessionEditor>> {
2627 Some(Rc::new(FakeAgentSessionEditor {
2628 _session_id: session_id.clone(),
2629 }))
2630 }
2631
2632 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2633 self
2634 }
2635 }
2636
2637 struct FakeAgentSessionEditor {
2638 _session_id: acp::SessionId,
2639 }
2640
2641 impl AgentSessionEditor for FakeAgentSessionEditor {
2642 fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2643 Task::ready(Ok(()))
2644 }
2645 }
2646}