1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use collections::HashSet;
7pub use connection::*;
8pub use diff::*;
9use language::language_settings::FormatOnSave;
10pub use mention::*;
11use project::lsp_store::{FormatTrigger, LspFormatTarget};
12use serde::{Deserialize, Serialize};
13pub use terminal::*;
14
15use action_log::ActionLog;
16use agent_client_protocol as acp;
17use anyhow::{Context as _, Result, anyhow};
18use editor::Bias;
19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
21use itertools::Itertools;
22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
23use markdown::Markdown;
24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
25use std::collections::HashMap;
26use std::error::Error;
27use std::fmt::{Formatter, Write};
28use std::ops::Range;
29use std::process::ExitStatus;
30use std::rc::Rc;
31use std::time::{Duration, Instant};
32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
33use ui::App;
34use util::ResultExt;
35
/// A message authored by the user, as shown in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message. May start as `None` and be filled in by a later
    /// chunk (`AcpThread::push_user_content_block` keeps the existing id unless a new
    /// one is supplied).
    pub id: Option<UserMessageId>,
    /// Rendered content accumulated from every chunk in `chunks`.
    pub content: ContentBlock,
    /// The raw ACP content blocks that make up this message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}

/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when true the markdown export labels the
    /// message heading as "## User (checkpoint)".
    pub show: bool,
}
49
50impl UserMessage {
51 fn to_markdown(&self, cx: &App) -> String {
52 let mut markdown = String::new();
53 if self
54 .checkpoint
55 .as_ref()
56 .is_some_and(|checkpoint| checkpoint.show)
57 {
58 writeln!(markdown, "## User (checkpoint)").unwrap();
59 } else {
60 writeln!(markdown, "## User").unwrap();
61 }
62 writeln!(markdown).unwrap();
63 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
64 writeln!(markdown).unwrap();
65 markdown
66 }
67}
68
69#[derive(Debug, PartialEq)]
70pub struct AssistantMessage {
71 pub chunks: Vec<AssistantMessageChunk>,
72}
73
74impl AssistantMessage {
75 pub fn to_markdown(&self, cx: &App) -> String {
76 format!(
77 "## Assistant\n\n{}\n\n",
78 self.chunks
79 .iter()
80 .map(|chunk| chunk.to_markdown(cx))
81 .join("\n\n")
82 )
83 }
84}
85
86#[derive(Debug, PartialEq)]
87pub enum AssistantMessageChunk {
88 Message { block: ContentBlock },
89 Thought { block: ContentBlock },
90}
91
92impl AssistantMessageChunk {
93 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
94 Self::Message {
95 block: ContentBlock::new(chunk.into(), language_registry, cx),
96 }
97 }
98
99 fn to_markdown(&self, cx: &App) -> String {
100 match self {
101 Self::Message { block } => block.to_markdown(cx).to_string(),
102 Self::Thought { block } => {
103 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
104 }
105 }
106 }
107}
108
109#[derive(Debug)]
110pub enum AgentThreadEntry {
111 UserMessage(UserMessage),
112 AssistantMessage(AssistantMessage),
113 ToolCall(ToolCall),
114}
115
116impl AgentThreadEntry {
117 pub fn to_markdown(&self, cx: &App) -> String {
118 match self {
119 Self::UserMessage(message) => message.to_markdown(cx),
120 Self::AssistantMessage(message) => message.to_markdown(cx),
121 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
122 }
123 }
124
125 pub fn user_message(&self) -> Option<&UserMessage> {
126 if let AgentThreadEntry::UserMessage(message) = self {
127 Some(message)
128 } else {
129 None
130 }
131 }
132
133 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
134 if let AgentThreadEntry::ToolCall(call) = self {
135 itertools::Either::Left(call.diffs())
136 } else {
137 itertools::Either::Right(std::iter::empty())
138 }
139 }
140
141 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
142 if let AgentThreadEntry::ToolCall(call) = self {
143 itertools::Either::Left(call.terminals())
144 } else {
145 itertools::Either::Right(std::iter::empty())
146 }
147 }
148
149 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 locations,
152 resolved_locations,
153 ..
154 }) = self
155 {
156 Some((
157 locations.get(ix)?.clone(),
158 resolved_locations.get(ix)?.clone()?,
159 ))
160 } else {
161 None
162 }
163 }
164}
165
/// A tool invocation reported by the agent, together with its rendered state.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// The tool call title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered output of the tool: content blocks, diffs, or terminals.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions resolved from `locations`, index-aligned with it (`None` where
    /// a location could not be resolved). Populated asynchronously by
    /// `AcpThread::resolve_locations`.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw tool input as sent by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw tool output as sent by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
178
179impl ToolCall {
180 fn from_acp(
181 tool_call: acp::ToolCall,
182 status: ToolCallStatus,
183 language_registry: Arc<LanguageRegistry>,
184 cx: &mut App,
185 ) -> Self {
186 Self {
187 id: tool_call.id,
188 label: cx.new(|cx| {
189 Markdown::new(
190 tool_call.title.into(),
191 Some(language_registry.clone()),
192 None,
193 cx,
194 )
195 }),
196 kind: tool_call.kind,
197 content: tool_call
198 .content
199 .into_iter()
200 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
201 .collect(),
202 locations: tool_call.locations,
203 resolved_locations: Vec::default(),
204 status,
205 raw_input: tool_call.raw_input,
206 raw_output: tool_call.raw_output,
207 }
208 }
209
210 fn update_fields(
211 &mut self,
212 fields: acp::ToolCallUpdateFields,
213 language_registry: Arc<LanguageRegistry>,
214 cx: &mut App,
215 ) {
216 let acp::ToolCallUpdateFields {
217 kind,
218 status,
219 title,
220 content,
221 locations,
222 raw_input,
223 raw_output,
224 } = fields;
225
226 if let Some(kind) = kind {
227 self.kind = kind;
228 }
229
230 if let Some(status) = status {
231 self.status = status.into();
232 }
233
234 if let Some(title) = title {
235 self.label.update(cx, |label, cx| {
236 label.replace(title, cx);
237 });
238 }
239
240 if let Some(content) = content {
241 self.content = content
242 .into_iter()
243 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
244 .collect();
245 }
246
247 if let Some(locations) = locations {
248 self.locations = locations;
249 }
250
251 if let Some(raw_input) = raw_input {
252 self.raw_input = Some(raw_input);
253 }
254
255 if let Some(raw_output) = raw_output {
256 if self.content.is_empty()
257 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
258 {
259 self.content
260 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
261 markdown,
262 }));
263 }
264 self.raw_output = Some(raw_output);
265 }
266 }
267
268 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
269 self.content.iter().filter_map(|content| match content {
270 ToolCallContent::Diff(diff) => Some(diff),
271 ToolCallContent::ContentBlock(_) => None,
272 ToolCallContent::Terminal(_) => None,
273 })
274 }
275
276 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
277 self.content.iter().filter_map(|content| match content {
278 ToolCallContent::Terminal(terminal) => Some(terminal),
279 ToolCallContent::ContentBlock(_) => None,
280 ToolCallContent::Diff(_) => None,
281 })
282 }
283
284 fn to_markdown(&self, cx: &App) -> String {
285 let mut markdown = format!(
286 "**Tool Call: {}**\nStatus: {}\n\n",
287 self.label.read(cx).source(),
288 self.status
289 );
290 for content in &self.content {
291 markdown.push_str(content.to_markdown(cx).as_str());
292 markdown.push_str("\n\n");
293 }
294 markdown
295 }
296
297 async fn resolve_location(
298 location: acp::ToolCallLocation,
299 project: WeakEntity<Project>,
300 cx: &mut AsyncApp,
301 ) -> Option<AgentLocation> {
302 let buffer = project
303 .update(cx, |project, cx| {
304 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
305 Some(project.open_buffer(path, cx))
306 } else {
307 None
308 }
309 })
310 .ok()??;
311 let buffer = buffer.await.log_err()?;
312 let position = buffer
313 .update(cx, |buffer, _| {
314 if let Some(row) = location.line {
315 let snapshot = buffer.snapshot();
316 let column = snapshot.indent_size_for_line(row).len;
317 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
318 snapshot.anchor_before(point)
319 } else {
320 Anchor::MIN
321 }
322 })
323 .ok()?;
324
325 Some(AgentLocation {
326 buffer: buffer.downgrade(),
327 position,
328 })
329 }
330
331 fn resolve_locations(
332 &self,
333 project: Entity<Project>,
334 cx: &mut App,
335 ) -> Task<Vec<Option<AgentLocation>>> {
336 let locations = self.locations.clone();
337 project.update(cx, |_, cx| {
338 cx.spawn(async move |project, cx| {
339 let mut new_locations = Vec::new();
340 for location in locations {
341 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
342 }
343 new_locations
344 })
345 })
346 }
347}
348
349#[derive(Debug)]
350pub enum ToolCallStatus {
351 /// The tool call hasn't started running yet, but we start showing it to
352 /// the user.
353 Pending,
354 /// The tool call is waiting for confirmation from the user.
355 WaitingForConfirmation {
356 options: Vec<acp::PermissionOption>,
357 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
358 },
359 /// The tool call is currently running.
360 InProgress,
361 /// The tool call completed successfully.
362 Completed,
363 /// The tool call failed.
364 Failed,
365 /// The user rejected the tool call.
366 Rejected,
367 /// The user canceled generation so the tool call was canceled.
368 Canceled,
369}
370
371impl From<acp::ToolCallStatus> for ToolCallStatus {
372 fn from(status: acp::ToolCallStatus) -> Self {
373 match status {
374 acp::ToolCallStatus::Pending => Self::Pending,
375 acp::ToolCallStatus::InProgress => Self::InProgress,
376 acp::ToolCallStatus::Completed => Self::Completed,
377 acp::ToolCallStatus::Failed => Self::Failed,
378 }
379 }
380}
381
382impl Display for ToolCallStatus {
383 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
384 write!(
385 f,
386 "{}",
387 match self {
388 ToolCallStatus::Pending => "Pending",
389 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
390 ToolCallStatus::InProgress => "In Progress",
391 ToolCallStatus::Completed => "Completed",
392 ToolCallStatus::Failed => "Failed",
393 ToolCallStatus::Rejected => "Rejected",
394 ToolCallStatus::Canceled => "Canceled",
395 }
396 )
397 }
398}
399
400#[derive(Debug, PartialEq, Clone)]
401pub enum ContentBlock {
402 Empty,
403 Markdown { markdown: Entity<Markdown> },
404 ResourceLink { resource_link: acp::ResourceLink },
405}
406
407impl ContentBlock {
408 pub fn new(
409 block: acp::ContentBlock,
410 language_registry: &Arc<LanguageRegistry>,
411 cx: &mut App,
412 ) -> Self {
413 let mut this = Self::Empty;
414 this.append(block, language_registry, cx);
415 this
416 }
417
418 pub fn new_combined(
419 blocks: impl IntoIterator<Item = acp::ContentBlock>,
420 language_registry: Arc<LanguageRegistry>,
421 cx: &mut App,
422 ) -> Self {
423 let mut this = Self::Empty;
424 for block in blocks {
425 this.append(block, &language_registry, cx);
426 }
427 this
428 }
429
430 pub fn append(
431 &mut self,
432 block: acp::ContentBlock,
433 language_registry: &Arc<LanguageRegistry>,
434 cx: &mut App,
435 ) {
436 if matches!(self, ContentBlock::Empty)
437 && let acp::ContentBlock::ResourceLink(resource_link) = block
438 {
439 *self = ContentBlock::ResourceLink { resource_link };
440 return;
441 }
442
443 let new_content = self.block_string_contents(block);
444
445 match self {
446 ContentBlock::Empty => {
447 *self = Self::create_markdown_block(new_content, language_registry, cx);
448 }
449 ContentBlock::Markdown { markdown } => {
450 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
451 }
452 ContentBlock::ResourceLink { resource_link } => {
453 let existing_content = Self::resource_link_md(&resource_link.uri);
454 let combined = format!("{}\n{}", existing_content, new_content);
455
456 *self = Self::create_markdown_block(combined, language_registry, cx);
457 }
458 }
459 }
460
461 fn create_markdown_block(
462 content: String,
463 language_registry: &Arc<LanguageRegistry>,
464 cx: &mut App,
465 ) -> ContentBlock {
466 ContentBlock::Markdown {
467 markdown: cx
468 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
469 }
470 }
471
472 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
473 match block {
474 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
475 acp::ContentBlock::ResourceLink(resource_link) => {
476 Self::resource_link_md(&resource_link.uri)
477 }
478 acp::ContentBlock::Resource(acp::EmbeddedResource {
479 resource:
480 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
481 uri,
482 ..
483 }),
484 ..
485 }) => Self::resource_link_md(&uri),
486 acp::ContentBlock::Image(image) => Self::image_md(&image),
487 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
488 }
489 }
490
491 fn resource_link_md(uri: &str) -> String {
492 if let Some(uri) = MentionUri::parse(uri).log_err() {
493 uri.as_link().to_string()
494 } else {
495 uri.to_string()
496 }
497 }
498
499 fn image_md(_image: &acp::ImageContent) -> String {
500 "`Image`".into()
501 }
502
503 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
504 match self {
505 ContentBlock::Empty => "",
506 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
507 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
508 }
509 }
510
511 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
512 match self {
513 ContentBlock::Empty => None,
514 ContentBlock::Markdown { markdown } => Some(markdown),
515 ContentBlock::ResourceLink { .. } => None,
516 }
517 }
518
519 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
520 match self {
521 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
522 _ => None,
523 }
524 }
525}
526
527#[derive(Debug)]
528pub enum ToolCallContent {
529 ContentBlock(ContentBlock),
530 Diff(Entity<Diff>),
531 Terminal(Entity<Terminal>),
532}
533
534impl ToolCallContent {
535 pub fn from_acp(
536 content: acp::ToolCallContent,
537 language_registry: Arc<LanguageRegistry>,
538 cx: &mut App,
539 ) -> Self {
540 match content {
541 acp::ToolCallContent::Content { content } => {
542 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
543 }
544 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
545 Diff::finalized(
546 diff.path,
547 diff.old_text,
548 diff.new_text,
549 language_registry,
550 cx,
551 )
552 })),
553 }
554 }
555
556 pub fn to_markdown(&self, cx: &App) -> String {
557 match self {
558 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
559 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
560 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
561 }
562 }
563}
564
565#[derive(Debug, PartialEq)]
566pub enum ToolCallUpdate {
567 UpdateFields(acp::ToolCallUpdate),
568 UpdateDiff(ToolCallUpdateDiff),
569 UpdateTerminal(ToolCallUpdateTerminal),
570}
571
572impl ToolCallUpdate {
573 fn id(&self) -> &acp::ToolCallId {
574 match self {
575 Self::UpdateFields(update) => &update.id,
576 Self::UpdateDiff(diff) => &diff.id,
577 Self::UpdateTerminal(terminal) => &terminal.id,
578 }
579 }
580}
581
582impl From<acp::ToolCallUpdate> for ToolCallUpdate {
583 fn from(update: acp::ToolCallUpdate) -> Self {
584 Self::UpdateFields(update)
585 }
586}
587
588impl From<ToolCallUpdateDiff> for ToolCallUpdate {
589 fn from(diff: ToolCallUpdateDiff) -> Self {
590 Self::UpdateDiff(diff)
591 }
592}
593
594#[derive(Debug, PartialEq)]
595pub struct ToolCallUpdateDiff {
596 pub id: acp::ToolCallId,
597 pub diff: Entity<Diff>,
598}
599
600impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
601 fn from(terminal: ToolCallUpdateTerminal) -> Self {
602 Self::UpdateTerminal(terminal)
603 }
604}
605
606#[derive(Debug, PartialEq)]
607pub struct ToolCallUpdateTerminal {
608 pub id: acp::ToolCallId,
609 pub terminal: Entity<Terminal>,
610}
611
612#[derive(Debug, Default)]
613pub struct Plan {
614 pub entries: Vec<PlanEntry>,
615}
616
617#[derive(Debug)]
618pub struct PlanStats<'a> {
619 pub in_progress_entry: Option<&'a PlanEntry>,
620 pub pending: u32,
621 pub completed: u32,
622}
623
624impl Plan {
625 pub fn is_empty(&self) -> bool {
626 self.entries.is_empty()
627 }
628
629 pub fn stats(&self) -> PlanStats<'_> {
630 let mut stats = PlanStats {
631 in_progress_entry: None,
632 pending: 0,
633 completed: 0,
634 };
635
636 for entry in &self.entries {
637 match &entry.status {
638 acp::PlanEntryStatus::Pending => {
639 stats.pending += 1;
640 }
641 acp::PlanEntryStatus::InProgress => {
642 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
643 }
644 acp::PlanEntryStatus::Completed => {
645 stats.completed += 1;
646 }
647 }
648 }
649
650 stats
651 }
652}
653
654#[derive(Debug)]
655pub struct PlanEntry {
656 pub content: Entity<Markdown>,
657 pub priority: acp::PlanEntryPriority,
658 pub status: acp::PlanEntryStatus,
659}
660
661impl PlanEntry {
662 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
663 Self {
664 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
665 priority: entry.priority,
666 status: entry.status,
667 }
668 }
669}
670
/// Token consumption for a thread relative to the model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means the limit is unknown (see `TokenUsage::ratio`).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
676
677impl TokenUsage {
678 pub fn ratio(&self) -> TokenUsageRatio {
679 #[cfg(debug_assertions)]
680 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
681 .unwrap_or("0.8".to_string())
682 .parse()
683 .unwrap();
684 #[cfg(not(debug_assertions))]
685 let warning_threshold: f32 = 0.8;
686
687 // When the maximum is unknown because there is no selected model,
688 // avoid showing the token limit warning.
689 if self.max_tokens == 0 {
690 TokenUsageRatio::Normal
691 } else if self.used_tokens >= self.max_tokens {
692 TokenUsageRatio::Exceeded
693 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
694 TokenUsageRatio::Warning
695 } else {
696 TokenUsageRatio::Normal
697 }
698 }
699}
700
/// Coarse classification of token usage produced by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
707
/// Information about a retried request, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// The error that triggered the retry.
    pub last_error: SharedString,
    /// The current attempt number.
    pub attempt: usize,
    /// The maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm with the
    // code that constructs `RetryStatus`.
    pub duration: Duration,
}

/// State of a single agent session: its entries, plan, and related project handles.
pub struct AcpThread {
    title: SharedString,
    /// Thread entries in display order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably snapshots of
    // buffers shared with the agent — confirm against its users.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight send task; `Some` means the thread is generating (see `status`).
    send_task: Option<Task<()>>,
    /// The agent connection backing this thread.
    connection: Rc<dyn AgentConnection>,
    /// The ACP session this thread belongs to.
    session_id: acp::SessionId,
    /// Latest token usage reported for the thread, if any.
    token_usage: Option<TokenUsage>,
}

/// Events emitted by an `AcpThread`.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// A new entry was appended to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The token usage changed.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// High-level activity state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and the current turn is awaiting tool authorization.
    WaitingForToolConfirmation,
    /// A send task is running.
    Generating,
}
752
753#[derive(Debug, Clone)]
754pub enum LoadError {
755 NotInstalled {
756 error_message: SharedString,
757 install_message: SharedString,
758 install_command: String,
759 },
760 Unsupported {
761 error_message: SharedString,
762 upgrade_message: SharedString,
763 upgrade_command: String,
764 },
765 Exited {
766 status: ExitStatus,
767 },
768 Other(SharedString),
769}
770
771impl Display for LoadError {
772 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
773 match self {
774 LoadError::NotInstalled { error_message, .. }
775 | LoadError::Unsupported { error_message, .. } => {
776 write!(f, "{error_message}")
777 }
778 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
779 LoadError::Other(msg) => write!(f, "{}", msg),
780 }
781 }
782}
783
784impl Error for LoadError {}
785
786impl AcpThread {
787 pub fn new(
788 title: impl Into<SharedString>,
789 connection: Rc<dyn AgentConnection>,
790 project: Entity<Project>,
791 action_log: Entity<ActionLog>,
792 session_id: acp::SessionId,
793 ) -> Self {
794 Self {
795 action_log,
796 shared_buffers: Default::default(),
797 entries: Default::default(),
798 plan: Default::default(),
799 title: title.into(),
800 project,
801 send_task: None,
802 connection,
803 session_id,
804 token_usage: None,
805 }
806 }
807
808 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
809 &self.connection
810 }
811
812 pub fn action_log(&self) -> &Entity<ActionLog> {
813 &self.action_log
814 }
815
816 pub fn project(&self) -> &Entity<Project> {
817 &self.project
818 }
819
820 pub fn title(&self) -> SharedString {
821 self.title.clone()
822 }
823
824 pub fn entries(&self) -> &[AgentThreadEntry] {
825 &self.entries
826 }
827
828 pub fn session_id(&self) -> &acp::SessionId {
829 &self.session_id
830 }
831
832 pub fn status(&self) -> ThreadStatus {
833 if self.send_task.is_some() {
834 if self.waiting_for_tool_confirmation() {
835 ThreadStatus::WaitingForToolConfirmation
836 } else {
837 ThreadStatus::Generating
838 }
839 } else {
840 ThreadStatus::Idle
841 }
842 }
843
844 pub fn token_usage(&self) -> Option<&TokenUsage> {
845 self.token_usage.as_ref()
846 }
847
848 pub fn has_pending_edit_tool_calls(&self) -> bool {
849 for entry in self.entries.iter().rev() {
850 match entry {
851 AgentThreadEntry::UserMessage(_) => return false,
852 AgentThreadEntry::ToolCall(
853 call @ ToolCall {
854 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
855 ..
856 },
857 ) if call.diffs().next().is_some() => {
858 return true;
859 }
860 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
861 }
862 }
863
864 false
865 }
866
867 pub fn used_tools_since_last_user_message(&self) -> bool {
868 for entry in self.entries.iter().rev() {
869 match entry {
870 AgentThreadEntry::UserMessage(..) => return false,
871 AgentThreadEntry::AssistantMessage(..) => continue,
872 AgentThreadEntry::ToolCall(..) => return true,
873 }
874 }
875
876 false
877 }
878
879 pub fn handle_session_update(
880 &mut self,
881 update: acp::SessionUpdate,
882 cx: &mut Context<Self>,
883 ) -> Result<(), acp::Error> {
884 match update {
885 acp::SessionUpdate::UserMessageChunk { content } => {
886 self.push_user_content_block(None, content, cx);
887 }
888 acp::SessionUpdate::AgentMessageChunk { content } => {
889 self.push_assistant_content_block(content, false, cx);
890 }
891 acp::SessionUpdate::AgentThoughtChunk { content } => {
892 self.push_assistant_content_block(content, true, cx);
893 }
894 acp::SessionUpdate::ToolCall(tool_call) => {
895 self.upsert_tool_call(tool_call, cx)?;
896 }
897 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
898 self.update_tool_call(tool_call_update, cx)?;
899 }
900 acp::SessionUpdate::Plan(plan) => {
901 self.update_plan(plan, cx);
902 }
903 }
904 Ok(())
905 }
906
907 pub fn push_user_content_block(
908 &mut self,
909 message_id: Option<UserMessageId>,
910 chunk: acp::ContentBlock,
911 cx: &mut Context<Self>,
912 ) {
913 let language_registry = self.project.read(cx).languages().clone();
914 let entries_len = self.entries.len();
915
916 if let Some(last_entry) = self.entries.last_mut()
917 && let AgentThreadEntry::UserMessage(UserMessage {
918 id,
919 content,
920 chunks,
921 ..
922 }) = last_entry
923 {
924 *id = message_id.or(id.take());
925 content.append(chunk.clone(), &language_registry, cx);
926 chunks.push(chunk);
927 let idx = entries_len - 1;
928 cx.emit(AcpThreadEvent::EntryUpdated(idx));
929 } else {
930 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
931 self.push_entry(
932 AgentThreadEntry::UserMessage(UserMessage {
933 id: message_id,
934 content,
935 chunks: vec![chunk],
936 checkpoint: None,
937 }),
938 cx,
939 );
940 }
941 }
942
943 pub fn push_assistant_content_block(
944 &mut self,
945 chunk: acp::ContentBlock,
946 is_thought: bool,
947 cx: &mut Context<Self>,
948 ) {
949 let language_registry = self.project.read(cx).languages().clone();
950 let entries_len = self.entries.len();
951 if let Some(last_entry) = self.entries.last_mut()
952 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
953 {
954 let idx = entries_len - 1;
955 cx.emit(AcpThreadEvent::EntryUpdated(idx));
956 match (chunks.last_mut(), is_thought) {
957 (Some(AssistantMessageChunk::Message { block }), false)
958 | (Some(AssistantMessageChunk::Thought { block }), true) => {
959 block.append(chunk, &language_registry, cx)
960 }
961 _ => {
962 let block = ContentBlock::new(chunk, &language_registry, cx);
963 if is_thought {
964 chunks.push(AssistantMessageChunk::Thought { block })
965 } else {
966 chunks.push(AssistantMessageChunk::Message { block })
967 }
968 }
969 }
970 } else {
971 let block = ContentBlock::new(chunk, &language_registry, cx);
972 let chunk = if is_thought {
973 AssistantMessageChunk::Thought { block }
974 } else {
975 AssistantMessageChunk::Message { block }
976 };
977
978 self.push_entry(
979 AgentThreadEntry::AssistantMessage(AssistantMessage {
980 chunks: vec![chunk],
981 }),
982 cx,
983 );
984 }
985 }
986
987 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
988 self.entries.push(entry);
989 cx.emit(AcpThreadEvent::NewEntry);
990 }
991
992 pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
993 self.title = title;
994 cx.emit(AcpThreadEvent::TitleUpdated);
995 Ok(())
996 }
997
998 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
999 self.token_usage = usage;
1000 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1001 }
1002
1003 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1004 cx.emit(AcpThreadEvent::Retry(status));
1005 }
1006
1007 pub fn update_tool_call(
1008 &mut self,
1009 update: impl Into<ToolCallUpdate>,
1010 cx: &mut Context<Self>,
1011 ) -> Result<()> {
1012 let update = update.into();
1013 let languages = self.project.read(cx).languages().clone();
1014
1015 let (ix, current_call) = self
1016 .tool_call_mut(update.id())
1017 .context("Tool call not found")?;
1018 match update {
1019 ToolCallUpdate::UpdateFields(update) => {
1020 let location_updated = update.fields.locations.is_some();
1021 current_call.update_fields(update.fields, languages, cx);
1022 if location_updated {
1023 self.resolve_locations(update.id.clone(), cx);
1024 }
1025 }
1026 ToolCallUpdate::UpdateDiff(update) => {
1027 current_call.content.clear();
1028 current_call
1029 .content
1030 .push(ToolCallContent::Diff(update.diff));
1031 }
1032 ToolCallUpdate::UpdateTerminal(update) => {
1033 current_call.content.clear();
1034 current_call
1035 .content
1036 .push(ToolCallContent::Terminal(update.terminal));
1037 }
1038 }
1039
1040 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1041
1042 Ok(())
1043 }
1044
1045 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1046 pub fn upsert_tool_call(
1047 &mut self,
1048 tool_call: acp::ToolCall,
1049 cx: &mut Context<Self>,
1050 ) -> Result<(), acp::Error> {
1051 let status = tool_call.status.into();
1052 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1053 }
1054
1055 /// Fails if id does not match an existing entry.
1056 pub fn upsert_tool_call_inner(
1057 &mut self,
1058 tool_call_update: acp::ToolCallUpdate,
1059 status: ToolCallStatus,
1060 cx: &mut Context<Self>,
1061 ) -> Result<(), acp::Error> {
1062 let language_registry = self.project.read(cx).languages().clone();
1063 let id = tool_call_update.id.clone();
1064
1065 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1066 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1067 current_call.status = status;
1068
1069 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1070 } else {
1071 let call =
1072 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1073 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1074 };
1075
1076 self.resolve_locations(id, cx);
1077 Ok(())
1078 }
1079
1080 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1081 // The tool call we are looking for is typically the last one, or very close to the end.
1082 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1083 self.entries
1084 .iter_mut()
1085 .enumerate()
1086 .rev()
1087 .find_map(|(index, tool_call)| {
1088 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1089 && &tool_call.id == id
1090 {
1091 Some((index, tool_call))
1092 } else {
1093 None
1094 }
1095 })
1096 }
1097
1098 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1099 self.entries
1100 .iter()
1101 .enumerate()
1102 .rev()
1103 .find_map(|(index, tool_call)| {
1104 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1105 && &tool_call.id == id
1106 {
1107 Some((index, tool_call))
1108 } else {
1109 None
1110 }
1111 })
1112 }
1113
/// Resolves the locations of tool call `id` to buffer positions in the background
/// and stores them in its `resolved_locations`.
///
/// Does nothing if the call does not exist, either now or when resolution finishes.
/// If the last location resolved, it may also become the project's agent location
/// (see the comments below). `EntryUpdated` is emitted only when the resolved
/// locations actually changed.
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
    let project = self.project.clone();
    let Some((_, tool_call)) = self.tool_call_mut(&id) else {
        return;
    };
    let task = tool_call.resolve_locations(project, cx);
    cx.spawn(async move |this, cx| {
        let resolved_locations = task.await;
        this.update(cx, |this, cx| {
            let project = this.project.clone();
            // Look the call up again: entries may have changed while resolving.
            let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                return;
            };
            if let Some(Some(location)) = resolved_locations.last() {
                project.update(cx, |project, cx| {
                    // NOTE(review): the agent location is only updated when the project
                    // already has one; if `agent_location()` is `None`, nothing is set
                    // here — confirm this is intended.
                    if let Some(agent_location) = project.agent_location() {
                        // `should_ignore` is false if the buffer entity was dropped
                        // (`update` errs, `unwrap_or_default` yields false).
                        let should_ignore = agent_location.buffer == location.buffer
                            && location
                                .buffer
                                .update(cx, |buffer, _| {
                                    let snapshot = buffer.snapshot();
                                    let old_position =
                                        agent_location.position.to_point(&snapshot);
                                    let new_position = location.position.to_point(&snapshot);
                                    // ignore this so that when we get updates from the edit tool
                                    // the position doesn't reset to the start of line
                                    old_position.row == new_position.row
                                        && old_position.column > new_position.column
                                })
                                .ok()
                                .unwrap_or_default();
                        if !should_ignore {
                            project.set_agent_location(Some(location.clone()), cx);
                        }
                    }
                });
            }
            // Avoid a spurious re-render when nothing changed.
            if tool_call.resolved_locations != resolved_locations {
                tool_call.resolved_locations = resolved_locations;
                cx.emit(AcpThreadEvent::EntryUpdated(ix));
            }
        })
    })
    .detach();
}
1159
1160 pub fn request_tool_call_authorization(
1161 &mut self,
1162 tool_call: acp::ToolCallUpdate,
1163 options: Vec<acp::PermissionOption>,
1164 cx: &mut Context<Self>,
1165 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1166 let (tx, rx) = oneshot::channel();
1167
1168 let status = ToolCallStatus::WaitingForConfirmation {
1169 options,
1170 respond_tx: tx,
1171 };
1172
1173 self.upsert_tool_call_inner(tool_call, status, cx)?;
1174 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1175 Ok(rx)
1176 }
1177
1178 pub fn authorize_tool_call(
1179 &mut self,
1180 id: acp::ToolCallId,
1181 option_id: acp::PermissionOptionId,
1182 option_kind: acp::PermissionOptionKind,
1183 cx: &mut Context<Self>,
1184 ) {
1185 let Some((ix, call)) = self.tool_call_mut(&id) else {
1186 return;
1187 };
1188
1189 let new_status = match option_kind {
1190 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1191 ToolCallStatus::Rejected
1192 }
1193 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1194 ToolCallStatus::InProgress
1195 }
1196 };
1197
1198 let curr_status = mem::replace(&mut call.status, new_status);
1199
1200 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1201 respond_tx.send(option_id).log_err();
1202 } else if cfg!(debug_assertions) {
1203 panic!("tried to authorize an already authorized tool call");
1204 }
1205
1206 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1207 }
1208
1209 /// Returns true if the last turn is awaiting tool authorization
1210 pub fn waiting_for_tool_confirmation(&self) -> bool {
1211 for entry in self.entries.iter().rev() {
1212 match &entry {
1213 AgentThreadEntry::ToolCall(call) => match call.status {
1214 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1215 ToolCallStatus::Pending
1216 | ToolCallStatus::InProgress
1217 | ToolCallStatus::Completed
1218 | ToolCallStatus::Failed
1219 | ToolCallStatus::Rejected
1220 | ToolCallStatus::Canceled => continue,
1221 },
1222 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1223 // Reached the beginning of the turn
1224 return false;
1225 }
1226 }
1227 }
1228 false
1229 }
1230
/// Returns the agent's current plan as last set by [`Self::update_plan`].
pub fn plan(&self) -> &Plan {
    &self.plan
}
1234
1235 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1236 let new_entries_len = request.entries.len();
1237 let mut new_entries = request.entries.into_iter();
1238
1239 // Reuse existing markdown to prevent flickering
1240 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1241 let PlanEntry {
1242 content,
1243 priority,
1244 status,
1245 } = old;
1246 content.update(cx, |old, cx| {
1247 old.replace(new.content, cx);
1248 });
1249 *priority = new.priority;
1250 *status = new.status;
1251 }
1252 for new in new_entries {
1253 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1254 }
1255 self.plan.entries.truncate(new_entries_len);
1256
1257 cx.notify();
1258 }
1259
1260 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1261 self.plan
1262 .entries
1263 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1264 cx.notify();
1265 }
1266
/// Test helper: sends `message` as a single plain-text content block via [`Self::send`].
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text = acp::TextContent {
        annotations: None,
        text: message.to_string(),
    };
    let blocks = vec![acp::ContentBlock::Text(text)];
    self.send(blocks, cx)
}
1281
1282 pub fn send(
1283 &mut self,
1284 message: Vec<acp::ContentBlock>,
1285 cx: &mut Context<Self>,
1286 ) -> BoxFuture<'static, Result<()>> {
1287 let block = ContentBlock::new_combined(
1288 message.clone(),
1289 self.project.read(cx).languages().clone(),
1290 cx,
1291 );
1292 let request = acp::PromptRequest {
1293 prompt: message.clone(),
1294 session_id: self.session_id.clone(),
1295 };
1296 let git_store = self.project.read(cx).git_store().clone();
1297
1298 let message_id = if self
1299 .connection
1300 .session_editor(&self.session_id, cx)
1301 .is_some()
1302 {
1303 Some(UserMessageId::new())
1304 } else {
1305 None
1306 };
1307
1308 self.run_turn(cx, async move |this, cx| {
1309 this.update(cx, |this, cx| {
1310 this.push_entry(
1311 AgentThreadEntry::UserMessage(UserMessage {
1312 id: message_id.clone(),
1313 content: block,
1314 chunks: message,
1315 checkpoint: None,
1316 }),
1317 cx,
1318 );
1319 })
1320 .ok();
1321
1322 let old_checkpoint = git_store
1323 .update(cx, |git, cx| git.checkpoint(cx))?
1324 .await
1325 .context("failed to get old checkpoint")
1326 .log_err();
1327 this.update(cx, |this, cx| {
1328 if let Some((_ix, message)) = this.last_user_message() {
1329 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1330 git_checkpoint,
1331 show: false,
1332 });
1333 }
1334 this.connection.prompt(message_id, request, cx)
1335 })?
1336 .await
1337 })
1338 }
1339
1340 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1341 self.run_turn(cx, async move |this, cx| {
1342 this.update(cx, |this, cx| {
1343 this.connection
1344 .resume(&this.session_id, cx)
1345 .map(|resume| resume.run(cx))
1346 })?
1347 .context("resuming a session is not supported")?
1348 .await
1349 })
1350 }
1351
/// Runs one agent turn produced by `f` and returns a future that resolves when it ends.
///
/// Setup happens in two steps:
/// - Completed plan entries are cleared.
/// - Any in-flight turn is canceled via [`Self::cancel`]; the new turn awaits that
///   cancellation before running `f`.
///
/// `f` runs inside the new `send_task`, and its result is forwarded over a oneshot channel
/// to the returned future. That future then:
/// - refreshes the latest user message's checkpoint,
/// - clears the agent location,
/// - emits [`AcpThreadEvent::Error`] if `f` failed, or [`AcpThreadEvent::Stopped`] otherwise.
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
    f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
    self.clear_completed_plan_entries(cx);

    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);

    self.send_task = Some(cx.spawn(async move |this, cx| {
        // Don't start this turn until the previous one has finished winding down.
        cancel_task.await;
        tx.send(f(this, cx).await).ok();
    }));

    cx.spawn(async move |this, cx| {
        // `rx` yields `Err(Canceled)` if `send_task` is dropped before it sends a result.
        let response = rx.await;

        // NOTE(review): if this fails, the early return skips clearing the agent location
        // and emitting `Stopped`/`Error` below — confirm that is intended.
        this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
            .await?;

        this.update(cx, |this, cx| {
            this.project
                .update(cx, |project, cx| project.set_agent_location(None, cx));
            match response {
                Ok(Err(e)) => {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                // Either a response from the agent or a dropped channel (`Err(Canceled)`).
                result => {
                    let canceled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Canceled
                        }))
                    );

                    // We only take the task if the current prompt wasn't canceled.
                    //
                    // This prompt may have been canceled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be canceled.
                    if !canceled {
                        this.send_task.take();
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1407
1408 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1409 let Some(send_task) = self.send_task.take() else {
1410 return Task::ready(());
1411 };
1412
1413 for entry in self.entries.iter_mut() {
1414 if let AgentThreadEntry::ToolCall(call) = entry {
1415 let cancel = matches!(
1416 call.status,
1417 ToolCallStatus::Pending
1418 | ToolCallStatus::WaitingForConfirmation { .. }
1419 | ToolCallStatus::InProgress
1420 );
1421
1422 if cancel {
1423 call.status = ToolCallStatus::Canceled;
1424 }
1425 }
1426 }
1427
1428 self.connection.cancel(&self.session_id, cx);
1429
1430 // Wait for the send task to complete
1431 cx.foreground_executor().spawn(send_task)
1432 }
1433
/// Rewinds this thread to just before the user message identified by `id`, removing that
/// message and every entry after it.
///
/// The steps run in order:
/// 1. If the message carries a git checkpoint, the project's git state is restored to it.
/// 2. The connection's session editor truncates its history at `id`.
/// 3. The local entries are truncated and [`AcpThreadEvent::EntriesRemoved`] is emitted.
///
/// Fails with "not supported" when the connection has no session editor for this session,
/// and with "message not found" when no user message has this `id`.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
    let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
        return Task::ready(Err(anyhow!("not supported")));
    };
    let Some(message) = self.user_message(&id) else {
        return Task::ready(Err(anyhow!("message not found")));
    };

    let checkpoint = message
        .checkpoint
        .as_ref()
        .map(|c| c.git_checkpoint.clone());

    let git_store = self.project.read(cx).git_store().clone();
    cx.spawn(async move |this, cx| {
        if let Some(checkpoint) = checkpoint {
            git_store
                .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                .await?;
        }

        cx.update(|cx| session_editor.truncate(id.clone(), cx))?
            .await?;
        this.update(cx, |this, cx| {
            // Look the message up again: entries may have changed while awaiting above.
            if let Some((ix, _)) = this.user_message_mut(&id) {
                let range = ix..this.entries.len();
                this.entries.truncate(ix);
                cx.emit(AcpThreadEvent::EntriesRemoved(range));
            }
        })
    })
}
1468
1469 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1470 let git_store = self.project.read(cx).git_store().clone();
1471
1472 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1473 if let Some(checkpoint) = message.checkpoint.as_ref() {
1474 checkpoint.git_checkpoint.clone()
1475 } else {
1476 return Task::ready(Ok(()));
1477 }
1478 } else {
1479 return Task::ready(Ok(()));
1480 };
1481
1482 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1483 cx.spawn(async move |this, cx| {
1484 let new_checkpoint = new_checkpoint
1485 .await
1486 .context("failed to get new checkpoint")
1487 .log_err();
1488 if let Some(new_checkpoint) = new_checkpoint {
1489 let equal = git_store
1490 .update(cx, |git, cx| {
1491 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1492 })?
1493 .await
1494 .unwrap_or(true);
1495 this.update(cx, |this, cx| {
1496 let (ix, message) = this.last_user_message().context("no user message")?;
1497 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1498 checkpoint.show = !equal;
1499 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1500 anyhow::Ok(())
1501 })??;
1502 }
1503
1504 Ok(())
1505 })
1506 }
1507
1508 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1509 self.entries
1510 .iter_mut()
1511 .enumerate()
1512 .rev()
1513 .find_map(|(ix, entry)| {
1514 if let AgentThreadEntry::UserMessage(message) = entry {
1515 Some((ix, message))
1516 } else {
1517 None
1518 }
1519 })
1520 }
1521
1522 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1523 self.entries.iter().find_map(|entry| {
1524 if let AgentThreadEntry::UserMessage(message) = entry {
1525 if message.id.as_ref() == Some(id) {
1526 Some(message)
1527 } else {
1528 None
1529 }
1530 } else {
1531 None
1532 }
1533 })
1534 }
1535
1536 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1537 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1538 if let AgentThreadEntry::UserMessage(message) = entry {
1539 if message.id.as_ref() == Some(id) {
1540 Some((ix, message))
1541 } else {
1542 None
1543 }
1544 } else {
1545 None
1546 }
1547 })
1548 }
1549
1550 pub fn read_text_file(
1551 &self,
1552 path: PathBuf,
1553 line: Option<u32>,
1554 limit: Option<u32>,
1555 reuse_shared_snapshot: bool,
1556 cx: &mut Context<Self>,
1557 ) -> Task<Result<String>> {
1558 let project = self.project.clone();
1559 let action_log = self.action_log.clone();
1560 cx.spawn(async move |this, cx| {
1561 let load = project.update(cx, |project, cx| {
1562 let path = project
1563 .project_path_for_absolute_path(&path, cx)
1564 .context("invalid path")?;
1565 anyhow::Ok(project.open_buffer(path, cx))
1566 });
1567 let buffer = load??.await?;
1568
1569 let snapshot = if reuse_shared_snapshot {
1570 this.read_with(cx, |this, _| {
1571 this.shared_buffers.get(&buffer.clone()).cloned()
1572 })
1573 .log_err()
1574 .flatten()
1575 } else {
1576 None
1577 };
1578
1579 let snapshot = if let Some(snapshot) = snapshot {
1580 snapshot
1581 } else {
1582 action_log.update(cx, |action_log, cx| {
1583 action_log.buffer_read(buffer.clone(), cx);
1584 })?;
1585 project.update(cx, |project, cx| {
1586 let position = buffer
1587 .read(cx)
1588 .snapshot()
1589 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1590 project.set_agent_location(
1591 Some(AgentLocation {
1592 buffer: buffer.downgrade(),
1593 position,
1594 }),
1595 cx,
1596 );
1597 })?;
1598
1599 buffer.update(cx, |buffer, _| buffer.snapshot())?
1600 };
1601
1602 this.update(cx, |this, _| {
1603 let text = snapshot.text();
1604 this.shared_buffers.insert(buffer.clone(), snapshot);
1605 if line.is_none() && limit.is_none() {
1606 return Ok(text);
1607 }
1608 let limit = limit.unwrap_or(u32::MAX) as usize;
1609 let Some(line) = line else {
1610 return Ok(text.lines().take(limit).collect::<String>());
1611 };
1612
1613 let count = text.lines().count();
1614 if count < line as usize {
1615 anyhow::bail!("There are only {} lines", count);
1616 }
1617 Ok(text
1618 .lines()
1619 .skip(line as usize + 1)
1620 .take(limit)
1621 .collect::<String>())
1622 })?
1623 })
1624 }
1625
/// Replaces the contents of the file at `path` with `content` on behalf of the agent, then
/// saves it.
///
/// The diff base is the snapshot previously shared with the agent (`shared_buffers`, filled
/// by `read_text_file`), falling back to the buffer's current snapshot. The diff is computed
/// on the background executor and expressed as anchors into that base. As a result,
/// concurrent edits to the buffer since that snapshot are not overwritten (see
/// `test_edits_concurrently_to_user`).
///
/// After the diff is computed:
/// 1. The agent location moves to the end of the last edit, or `Anchor::MIN` when nothing
///    changed.
/// 2. The edit is applied and recorded in the action log.
/// 3. If the buffer's language settings enable format-on-save, it is formatted. Formatting
///    errors are logged and ignored.
/// 4. The buffer is saved.
///
/// Errors: "invalid path" when `path` is not inside the project, plus any error from opening
/// or saving the buffer.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Prefer the snapshot the agent last read so the diff matches its view of the file.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::MIN),
                }),
                cx,
            );
        })?;

        let format_on_save = cx.update(|cx| {
            // NOTE(review): presumably records the pre-edit contents as the action log's
            // baseline for this buffer — confirm against `ActionLog::buffer_read`.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        })?;

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            })?;
            // Best effort: a formatting failure shouldn't prevent saving the agent's edit.
            format_task.await.log_err();

            // Record the formatter's changes as part of the agent's edit too.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            })?;
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
1722
1723 pub fn to_markdown(&self, cx: &App) -> String {
1724 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1725 }
1726
1727 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1728 cx.emit(AcpThreadEvent::LoadError(error));
1729 }
1730}
1731
1732fn markdown_for_raw_output(
1733 raw_output: &serde_json::Value,
1734 language_registry: &Arc<LanguageRegistry>,
1735 cx: &mut App,
1736) -> Option<Entity<Markdown>> {
1737 match raw_output {
1738 serde_json::Value::Null => None,
1739 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1740 Markdown::new(
1741 value.to_string().into(),
1742 Some(language_registry.clone()),
1743 None,
1744 cx,
1745 )
1746 })),
1747 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1748 Markdown::new(
1749 value.to_string().into(),
1750 Some(language_registry.clone()),
1751 None,
1752 cx,
1753 )
1754 })),
1755 serde_json::Value::String(value) => Some(cx.new(|cx| {
1756 Markdown::new(
1757 value.clone().into(),
1758 Some(language_registry.clone()),
1759 None,
1760 cx,
1761 )
1762 })),
1763 value => Some(cx.new(|cx| {
1764 Markdown::new(
1765 format!("```json\n{}\n```", value).into(),
1766 Some(language_registry.clone()),
1767 None,
1768 cx,
1769 )
1770 })),
1771 }
1772}
1773
1774#[cfg(test)]
1775mod tests {
1776 use super::*;
1777 use anyhow::anyhow;
1778 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1779 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1780 use indoc::indoc;
1781 use project::{FakeFs, Fs};
1782 use rand::Rng as _;
1783 use serde_json::json;
1784 use settings::SettingsStore;
1785 use smol::stream::StreamExt as _;
1786 use std::{
1787 any::Any,
1788 cell::RefCell,
1789 path::Path,
1790 rc::Rc,
1791 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1792 time::Duration,
1793 };
1794 use util::path;
1795
1796 fn init_test(cx: &mut TestAppContext) {
1797 env_logger::try_init().ok();
1798 cx.update(|cx| {
1799 let settings_store = SettingsStore::test(cx);
1800 cx.set_global(settings_store);
1801 Project::init_settings(cx);
1802 language::init(cx);
1803 });
1804 }
1805
1806 #[gpui::test]
1807 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1808 init_test(cx);
1809
1810 let fs = FakeFs::new(cx.executor());
1811 let project = Project::test(fs, [], cx).await;
1812 let connection = Rc::new(FakeAgentConnection::new());
1813 let thread = cx
1814 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1815 .await
1816 .unwrap();
1817
1818 // Test creating a new user message
1819 thread.update(cx, |thread, cx| {
1820 thread.push_user_content_block(
1821 None,
1822 acp::ContentBlock::Text(acp::TextContent {
1823 annotations: None,
1824 text: "Hello, ".to_string(),
1825 }),
1826 cx,
1827 );
1828 });
1829
1830 thread.update(cx, |thread, cx| {
1831 assert_eq!(thread.entries.len(), 1);
1832 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1833 assert_eq!(user_msg.id, None);
1834 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1835 } else {
1836 panic!("Expected UserMessage");
1837 }
1838 });
1839
1840 // Test appending to existing user message
1841 let message_1_id = UserMessageId::new();
1842 thread.update(cx, |thread, cx| {
1843 thread.push_user_content_block(
1844 Some(message_1_id.clone()),
1845 acp::ContentBlock::Text(acp::TextContent {
1846 annotations: None,
1847 text: "world!".to_string(),
1848 }),
1849 cx,
1850 );
1851 });
1852
1853 thread.update(cx, |thread, cx| {
1854 assert_eq!(thread.entries.len(), 1);
1855 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1856 assert_eq!(user_msg.id, Some(message_1_id));
1857 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1858 } else {
1859 panic!("Expected UserMessage");
1860 }
1861 });
1862
1863 // Test creating new user message after assistant message
1864 thread.update(cx, |thread, cx| {
1865 thread.push_assistant_content_block(
1866 acp::ContentBlock::Text(acp::TextContent {
1867 annotations: None,
1868 text: "Assistant response".to_string(),
1869 }),
1870 false,
1871 cx,
1872 );
1873 });
1874
1875 let message_2_id = UserMessageId::new();
1876 thread.update(cx, |thread, cx| {
1877 thread.push_user_content_block(
1878 Some(message_2_id.clone()),
1879 acp::ContentBlock::Text(acp::TextContent {
1880 annotations: None,
1881 text: "New user message".to_string(),
1882 }),
1883 cx,
1884 );
1885 });
1886
1887 thread.update(cx, |thread, cx| {
1888 assert_eq!(thread.entries.len(), 3);
1889 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1890 assert_eq!(user_msg.id, Some(message_2_id));
1891 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1892 } else {
1893 panic!("Expected UserMessage at index 2");
1894 }
1895 });
1896 }
1897
1898 #[gpui::test]
1899 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1900 init_test(cx);
1901
1902 let fs = FakeFs::new(cx.executor());
1903 let project = Project::test(fs, [], cx).await;
1904 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1905 |_, thread, mut cx| {
1906 async move {
1907 thread.update(&mut cx, |thread, cx| {
1908 thread
1909 .handle_session_update(
1910 acp::SessionUpdate::AgentThoughtChunk {
1911 content: "Thinking ".into(),
1912 },
1913 cx,
1914 )
1915 .unwrap();
1916 thread
1917 .handle_session_update(
1918 acp::SessionUpdate::AgentThoughtChunk {
1919 content: "hard!".into(),
1920 },
1921 cx,
1922 )
1923 .unwrap();
1924 })?;
1925 Ok(acp::PromptResponse {
1926 stop_reason: acp::StopReason::EndTurn,
1927 })
1928 }
1929 .boxed_local()
1930 },
1931 ));
1932
1933 let thread = cx
1934 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1935 .await
1936 .unwrap();
1937
1938 thread
1939 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1940 .await
1941 .unwrap();
1942
1943 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1944 assert_eq!(
1945 output,
1946 indoc! {r#"
1947 ## User
1948
1949 Hello from Zed!
1950
1951 ## Assistant
1952
1953 <thinking>
1954 Thinking hard!
1955 </thinking>
1956
1957 "#}
1958 );
1959 }
1960
1961 #[gpui::test]
1962 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1963 init_test(cx);
1964
1965 let fs = FakeFs::new(cx.executor());
1966 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1967 .await;
1968 let project = Project::test(fs.clone(), [], cx).await;
1969 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1970 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1971 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1972 move |_, thread, mut cx| {
1973 let read_file_tx = read_file_tx.clone();
1974 async move {
1975 let content = thread
1976 .update(&mut cx, |thread, cx| {
1977 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1978 })
1979 .unwrap()
1980 .await
1981 .unwrap();
1982 assert_eq!(content, "one\ntwo\nthree\n");
1983 read_file_tx.take().unwrap().send(()).unwrap();
1984 thread
1985 .update(&mut cx, |thread, cx| {
1986 thread.write_text_file(
1987 path!("/tmp/foo").into(),
1988 "one\ntwo\nthree\nfour\nfive\n".to_string(),
1989 cx,
1990 )
1991 })
1992 .unwrap()
1993 .await
1994 .unwrap();
1995 Ok(acp::PromptResponse {
1996 stop_reason: acp::StopReason::EndTurn,
1997 })
1998 }
1999 .boxed_local()
2000 },
2001 ));
2002
2003 let (worktree, pathbuf) = project
2004 .update(cx, |project, cx| {
2005 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2006 })
2007 .await
2008 .unwrap();
2009 let buffer = project
2010 .update(cx, |project, cx| {
2011 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2012 })
2013 .await
2014 .unwrap();
2015
2016 let thread = cx
2017 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2018 .await
2019 .unwrap();
2020
2021 let request = thread.update(cx, |thread, cx| {
2022 thread.send_raw("Extend the count in /tmp/foo", cx)
2023 });
2024 read_file_rx.await.ok();
2025 buffer.update(cx, |buffer, cx| {
2026 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2027 });
2028 cx.run_until_parked();
2029 assert_eq!(
2030 buffer.read_with(cx, |buffer, _| buffer.text()),
2031 "zero\none\ntwo\nthree\nfour\nfive\n"
2032 );
2033 assert_eq!(
2034 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2035 "zero\none\ntwo\nthree\nfour\nfive\n"
2036 );
2037 request.await.unwrap();
2038 }
2039
2040 #[gpui::test]
2041 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2042 init_test(cx);
2043
2044 let fs = FakeFs::new(cx.executor());
2045 let project = Project::test(fs, [], cx).await;
2046 let id = acp::ToolCallId("test".into());
2047
2048 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2049 let id = id.clone();
2050 move |_, thread, mut cx| {
2051 let id = id.clone();
2052 async move {
2053 thread
2054 .update(&mut cx, |thread, cx| {
2055 thread.handle_session_update(
2056 acp::SessionUpdate::ToolCall(acp::ToolCall {
2057 id: id.clone(),
2058 title: "Label".into(),
2059 kind: acp::ToolKind::Fetch,
2060 status: acp::ToolCallStatus::InProgress,
2061 content: vec![],
2062 locations: vec![],
2063 raw_input: None,
2064 raw_output: None,
2065 }),
2066 cx,
2067 )
2068 })
2069 .unwrap()
2070 .unwrap();
2071 Ok(acp::PromptResponse {
2072 stop_reason: acp::StopReason::EndTurn,
2073 })
2074 }
2075 .boxed_local()
2076 }
2077 }));
2078
2079 let thread = cx
2080 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2081 .await
2082 .unwrap();
2083
2084 let request = thread.update(cx, |thread, cx| {
2085 thread.send_raw("Fetch https://example.com", cx)
2086 });
2087
2088 run_until_first_tool_call(&thread, cx).await;
2089
2090 thread.read_with(cx, |thread, _| {
2091 assert!(matches!(
2092 thread.entries[1],
2093 AgentThreadEntry::ToolCall(ToolCall {
2094 status: ToolCallStatus::InProgress,
2095 ..
2096 })
2097 ));
2098 });
2099
2100 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2101
2102 thread.read_with(cx, |thread, _| {
2103 assert!(matches!(
2104 &thread.entries[1],
2105 AgentThreadEntry::ToolCall(ToolCall {
2106 status: ToolCallStatus::Canceled,
2107 ..
2108 })
2109 ));
2110 });
2111
2112 thread
2113 .update(cx, |thread, cx| {
2114 thread.handle_session_update(
2115 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2116 id,
2117 fields: acp::ToolCallUpdateFields {
2118 status: Some(acp::ToolCallStatus::Completed),
2119 ..Default::default()
2120 },
2121 }),
2122 cx,
2123 )
2124 })
2125 .unwrap();
2126
2127 request.await.unwrap();
2128
2129 thread.read_with(cx, |thread, _| {
2130 assert!(matches!(
2131 thread.entries[1],
2132 AgentThreadEntry::ToolCall(ToolCall {
2133 status: ToolCallStatus::Completed,
2134 ..
2135 })
2136 ));
2137 });
2138 }
2139
2140 #[gpui::test]
2141 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2142 init_test(cx);
2143 let fs = FakeFs::new(cx.background_executor.clone());
2144 fs.insert_tree(path!("/test"), json!({})).await;
2145 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2146
2147 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2148 move |_, thread, mut cx| {
2149 async move {
2150 thread
2151 .update(&mut cx, |thread, cx| {
2152 thread.handle_session_update(
2153 acp::SessionUpdate::ToolCall(acp::ToolCall {
2154 id: acp::ToolCallId("test".into()),
2155 title: "Label".into(),
2156 kind: acp::ToolKind::Edit,
2157 status: acp::ToolCallStatus::Completed,
2158 content: vec![acp::ToolCallContent::Diff {
2159 diff: acp::Diff {
2160 path: "/test/test.txt".into(),
2161 old_text: None,
2162 new_text: "foo".into(),
2163 },
2164 }],
2165 locations: vec![],
2166 raw_input: None,
2167 raw_output: None,
2168 }),
2169 cx,
2170 )
2171 })
2172 .unwrap()
2173 .unwrap();
2174 Ok(acp::PromptResponse {
2175 stop_reason: acp::StopReason::EndTurn,
2176 })
2177 }
2178 .boxed_local()
2179 }
2180 }));
2181
2182 let thread = cx
2183 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2184 .await
2185 .unwrap();
2186
2187 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2188 .await
2189 .unwrap();
2190
2191 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2192 }
2193
2194 #[gpui::test(iterations = 10)]
2195 async fn test_checkpoints(cx: &mut TestAppContext) {
2196 init_test(cx);
2197 let fs = FakeFs::new(cx.background_executor.clone());
2198 fs.insert_tree(
2199 path!("/test"),
2200 json!({
2201 ".git": {}
2202 }),
2203 )
2204 .await;
2205 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2206
2207 let simulate_changes = Arc::new(AtomicBool::new(true));
2208 let next_filename = Arc::new(AtomicUsize::new(0));
2209 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2210 let simulate_changes = simulate_changes.clone();
2211 let next_filename = next_filename.clone();
2212 let fs = fs.clone();
2213 move |request, thread, mut cx| {
2214 let fs = fs.clone();
2215 let simulate_changes = simulate_changes.clone();
2216 let next_filename = next_filename.clone();
2217 async move {
2218 if simulate_changes.load(SeqCst) {
2219 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2220 fs.write(Path::new(&filename), b"").await?;
2221 }
2222
2223 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2224 panic!("expected text content block");
2225 };
2226 thread.update(&mut cx, |thread, cx| {
2227 thread
2228 .handle_session_update(
2229 acp::SessionUpdate::AgentMessageChunk {
2230 content: content.text.to_uppercase().into(),
2231 },
2232 cx,
2233 )
2234 .unwrap();
2235 })?;
2236 Ok(acp::PromptResponse {
2237 stop_reason: acp::StopReason::EndTurn,
2238 })
2239 }
2240 .boxed_local()
2241 }
2242 }));
2243 let thread = cx
2244 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2245 .await
2246 .unwrap();
2247
2248 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2249 .await
2250 .unwrap();
2251 thread.read_with(cx, |thread, cx| {
2252 assert_eq!(
2253 thread.to_markdown(cx),
2254 indoc! {"
2255 ## User (checkpoint)
2256
2257 Lorem
2258
2259 ## Assistant
2260
2261 LOREM
2262
2263 "}
2264 );
2265 });
2266 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2267
2268 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2269 .await
2270 .unwrap();
2271 thread.read_with(cx, |thread, cx| {
2272 assert_eq!(
2273 thread.to_markdown(cx),
2274 indoc! {"
2275 ## User (checkpoint)
2276
2277 Lorem
2278
2279 ## Assistant
2280
2281 LOREM
2282
2283 ## User (checkpoint)
2284
2285 ipsum
2286
2287 ## Assistant
2288
2289 IPSUM
2290
2291 "}
2292 );
2293 });
2294 assert_eq!(
2295 fs.files(),
2296 vec![
2297 Path::new(path!("/test/file-0")),
2298 Path::new(path!("/test/file-1"))
2299 ]
2300 );
2301
2302 // Checkpoint isn't stored when there are no changes.
2303 simulate_changes.store(false, SeqCst);
2304 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2305 .await
2306 .unwrap();
2307 thread.read_with(cx, |thread, cx| {
2308 assert_eq!(
2309 thread.to_markdown(cx),
2310 indoc! {"
2311 ## User (checkpoint)
2312
2313 Lorem
2314
2315 ## Assistant
2316
2317 LOREM
2318
2319 ## User (checkpoint)
2320
2321 ipsum
2322
2323 ## Assistant
2324
2325 IPSUM
2326
2327 ## User
2328
2329 dolor
2330
2331 ## Assistant
2332
2333 DOLOR
2334
2335 "}
2336 );
2337 });
2338 assert_eq!(
2339 fs.files(),
2340 vec![
2341 Path::new(path!("/test/file-0")),
2342 Path::new(path!("/test/file-1"))
2343 ]
2344 );
2345
2346 // Rewinding the conversation truncates the history and restores the checkpoint.
2347 thread
2348 .update(cx, |thread, cx| {
2349 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2350 panic!("unexpected entries {:?}", thread.entries)
2351 };
2352 thread.rewind(message.id.clone().unwrap(), cx)
2353 })
2354 .await
2355 .unwrap();
2356 thread.read_with(cx, |thread, cx| {
2357 assert_eq!(
2358 thread.to_markdown(cx),
2359 indoc! {"
2360 ## User (checkpoint)
2361
2362 Lorem
2363
2364 ## Assistant
2365
2366 LOREM
2367
2368 "}
2369 );
2370 });
2371 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2372 }
2373
2374 async fn run_until_first_tool_call(
2375 thread: &Entity<AcpThread>,
2376 cx: &mut TestAppContext,
2377 ) -> usize {
2378 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2379
2380 let subscription = cx.update(|cx| {
2381 cx.subscribe(thread, move |thread, _, cx| {
2382 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2383 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2384 return tx.try_send(ix).unwrap();
2385 }
2386 }
2387 })
2388 });
2389
2390 select! {
2391 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2392 panic!("Timeout waiting for tool call")
2393 }
2394 ix = rx.next().fuse() => {
2395 drop(subscription);
2396 ix.unwrap()
2397 }
2398 }
2399 }
2400
2401 #[derive(Clone, Default)]
2402 struct FakeAgentConnection {
2403 auth_methods: Vec<acp::AuthMethod>,
2404 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2405 on_user_message: Option<
2406 Rc<
2407 dyn Fn(
2408 acp::PromptRequest,
2409 WeakEntity<AcpThread>,
2410 AsyncApp,
2411 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2412 + 'static,
2413 >,
2414 >,
2415 }
2416
2417 impl FakeAgentConnection {
2418 fn new() -> Self {
2419 Self {
2420 auth_methods: Vec::new(),
2421 on_user_message: None,
2422 sessions: Arc::default(),
2423 }
2424 }
2425
2426 #[expect(unused)]
2427 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2428 self.auth_methods = auth_methods;
2429 self
2430 }
2431
2432 fn on_user_message(
2433 mut self,
2434 handler: impl Fn(
2435 acp::PromptRequest,
2436 WeakEntity<AcpThread>,
2437 AsyncApp,
2438 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2439 + 'static,
2440 ) -> Self {
2441 self.on_user_message.replace(Rc::new(handler));
2442 self
2443 }
2444 }
2445
2446 impl AgentConnection for FakeAgentConnection {
2447 fn auth_methods(&self) -> &[acp::AuthMethod] {
2448 &self.auth_methods
2449 }
2450
2451 fn new_thread(
2452 self: Rc<Self>,
2453 project: Entity<Project>,
2454 _cwd: &Path,
2455 cx: &mut App,
2456 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2457 let session_id = acp::SessionId(
2458 rand::thread_rng()
2459 .sample_iter(&rand::distributions::Alphanumeric)
2460 .take(7)
2461 .map(char::from)
2462 .collect::<String>()
2463 .into(),
2464 );
2465 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2466 let thread = cx.new(|_cx| {
2467 AcpThread::new(
2468 "Test",
2469 self.clone(),
2470 project,
2471 action_log,
2472 session_id.clone(),
2473 )
2474 });
2475 self.sessions.lock().insert(session_id, thread.downgrade());
2476 Task::ready(Ok(thread))
2477 }
2478
2479 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2480 if self.auth_methods().iter().any(|m| m.id == method) {
2481 Task::ready(Ok(()))
2482 } else {
2483 Task::ready(Err(anyhow!("Invalid Auth Method")))
2484 }
2485 }
2486
2487 fn prompt(
2488 &self,
2489 _id: Option<UserMessageId>,
2490 params: acp::PromptRequest,
2491 cx: &mut App,
2492 ) -> Task<gpui::Result<acp::PromptResponse>> {
2493 let sessions = self.sessions.lock();
2494 let thread = sessions.get(¶ms.session_id).unwrap();
2495 if let Some(handler) = &self.on_user_message {
2496 let handler = handler.clone();
2497 let thread = thread.clone();
2498 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2499 } else {
2500 Task::ready(Ok(acp::PromptResponse {
2501 stop_reason: acp::StopReason::EndTurn,
2502 }))
2503 }
2504 }
2505
2506 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2507 let sessions = self.sessions.lock();
2508 let thread = sessions.get(session_id).unwrap().clone();
2509
2510 cx.spawn(async move |cx| {
2511 thread
2512 .update(cx, |thread, cx| thread.cancel(cx))
2513 .unwrap()
2514 .await
2515 })
2516 .detach();
2517 }
2518
2519 fn session_editor(
2520 &self,
2521 session_id: &acp::SessionId,
2522 _cx: &mut App,
2523 ) -> Option<Rc<dyn AgentSessionEditor>> {
2524 Some(Rc::new(FakeAgentSessionEditor {
2525 _session_id: session_id.clone(),
2526 }))
2527 }
2528
2529 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2530 self
2531 }
2532 }
2533
2534 struct FakeAgentSessionEditor {
2535 _session_id: acp::SessionId,
2536 }
2537
2538 impl AgentSessionEditor for FakeAgentSessionEditor {
2539 fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2540 Task::ready(Ok(()))
2541 }
2542 }
2543}