1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use ::terminal::TerminalBuilder;
7use ::terminal::terminal_settings::{AlternateScroll, CursorShape};
8use agent_settings::AgentSettings;
9use collections::HashSet;
10pub use connection::*;
11pub use diff::*;
12use language::language_settings::FormatOnSave;
13pub use mention::*;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use serde::{Deserialize, Serialize};
16use settings::Settings as _;
17use task::{Shell, ShellBuilder};
18pub use terminal::*;
19
20use action_log::ActionLog;
21use agent_client_protocol::{self as acp};
22use anyhow::{Context as _, Result, anyhow};
23use editor::Bias;
24use futures::{FutureExt, channel::oneshot, future::BoxFuture};
25use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
26use itertools::Itertools;
27use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
28use markdown::Markdown;
29use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
30use std::collections::HashMap;
31use std::error::Error;
32use std::fmt::{Formatter, Write};
33use std::ops::Range;
34use std::process::ExitStatus;
35use std::rc::Rc;
36use std::time::{Duration, Instant};
37use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
38use ui::App;
39use util::{ResultExt, get_default_system_shell};
40use uuid::Uuid;
41
42#[derive(Debug)]
43pub struct UserMessage {
44 pub id: Option<UserMessageId>,
45 pub content: ContentBlock,
46 pub chunks: Vec<acp::ContentBlock>,
47 pub checkpoint: Option<Checkpoint>,
48}
49
50#[derive(Debug)]
51pub struct Checkpoint {
52 git_checkpoint: GitStoreCheckpoint,
53 pub show: bool,
54}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
75#[derive(Debug, PartialEq)]
76pub struct AssistantMessage {
77 pub chunks: Vec<AssistantMessageChunk>,
78}
79
80impl AssistantMessage {
81 pub fn to_markdown(&self, cx: &App) -> String {
82 format!(
83 "## Assistant\n\n{}\n\n",
84 self.chunks
85 .iter()
86 .map(|chunk| chunk.to_markdown(cx))
87 .join("\n\n")
88 )
89 }
90}
91
92#[derive(Debug, PartialEq)]
93pub enum AssistantMessageChunk {
94 Message { block: ContentBlock },
95 Thought { block: ContentBlock },
96}
97
98impl AssistantMessageChunk {
99 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
100 Self::Message {
101 block: ContentBlock::new(chunk.into(), language_registry, cx),
102 }
103 }
104
105 fn to_markdown(&self, cx: &App) -> String {
106 match self {
107 Self::Message { block } => block.to_markdown(cx).to_string(),
108 Self::Thought { block } => {
109 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
110 }
111 }
112 }
113}
114
115#[derive(Debug)]
116pub enum AgentThreadEntry {
117 UserMessage(UserMessage),
118 AssistantMessage(AssistantMessage),
119 ToolCall(ToolCall),
120}
121
122impl AgentThreadEntry {
123 pub fn to_markdown(&self, cx: &App) -> String {
124 match self {
125 Self::UserMessage(message) => message.to_markdown(cx),
126 Self::AssistantMessage(message) => message.to_markdown(cx),
127 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
128 }
129 }
130
131 pub fn user_message(&self) -> Option<&UserMessage> {
132 if let AgentThreadEntry::UserMessage(message) = self {
133 Some(message)
134 } else {
135 None
136 }
137 }
138
139 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
140 if let AgentThreadEntry::ToolCall(call) = self {
141 itertools::Either::Left(call.diffs())
142 } else {
143 itertools::Either::Right(std::iter::empty())
144 }
145 }
146
147 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
148 if let AgentThreadEntry::ToolCall(call) = self {
149 itertools::Either::Left(call.terminals())
150 } else {
151 itertools::Either::Right(std::iter::empty())
152 }
153 }
154
155 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
156 if let AgentThreadEntry::ToolCall(ToolCall {
157 locations,
158 resolved_locations,
159 ..
160 }) = self
161 {
162 Some((
163 locations.get(ix)?.clone(),
164 resolved_locations.get(ix)?.clone()?,
165 ))
166 } else {
167 None
168 }
169 }
170}
171
172#[derive(Debug)]
173pub struct ToolCall {
174 pub id: acp::ToolCallId,
175 pub label: Entity<Markdown>,
176 pub kind: acp::ToolKind,
177 pub content: Vec<ToolCallContent>,
178 pub status: ToolCallStatus,
179 pub locations: Vec<acp::ToolCallLocation>,
180 pub resolved_locations: Vec<Option<AgentLocation>>,
181 pub raw_input: Option<serde_json::Value>,
182 pub raw_output: Option<serde_json::Value>,
183}
184
185impl ToolCall {
186 fn from_acp(
187 tool_call: acp::ToolCall,
188 status: ToolCallStatus,
189 language_registry: Arc<LanguageRegistry>,
190 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
191 cx: &mut App,
192 ) -> Result<Self> {
193 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
194 first_line.to_owned() + "…"
195 } else {
196 tool_call.title
197 };
198 let mut content = Vec::with_capacity(tool_call.content.len());
199 for item in tool_call.content {
200 content.push(ToolCallContent::from_acp(
201 item,
202 language_registry.clone(),
203 terminals,
204 cx,
205 )?);
206 }
207
208 let result = Self {
209 id: tool_call.id,
210 label: cx
211 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
212 kind: tool_call.kind,
213 content,
214 locations: tool_call.locations,
215 resolved_locations: Vec::default(),
216 status,
217 raw_input: tool_call.raw_input,
218 raw_output: tool_call.raw_output,
219 };
220 Ok(result)
221 }
222
223 fn update_fields(
224 &mut self,
225 fields: acp::ToolCallUpdateFields,
226 language_registry: Arc<LanguageRegistry>,
227 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
228 cx: &mut App,
229 ) -> Result<()> {
230 let acp::ToolCallUpdateFields {
231 kind,
232 status,
233 title,
234 content,
235 locations,
236 raw_input,
237 raw_output,
238 } = fields;
239
240 if let Some(kind) = kind {
241 self.kind = kind;
242 }
243
244 if let Some(status) = status {
245 self.status = status.into();
246 }
247
248 if let Some(title) = title {
249 self.label.update(cx, |label, cx| {
250 if let Some((first_line, _)) = title.split_once("\n") {
251 label.replace(first_line.to_owned() + "…", cx)
252 } else {
253 label.replace(title, cx);
254 }
255 });
256 }
257
258 if let Some(content) = content {
259 let new_content_len = content.len();
260 let mut content = content.into_iter();
261
262 // Reuse existing content if we can
263 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
264 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
265 }
266 for new in content {
267 self.content.push(ToolCallContent::from_acp(
268 new,
269 language_registry.clone(),
270 terminals,
271 cx,
272 )?)
273 }
274 self.content.truncate(new_content_len);
275 }
276
277 if let Some(locations) = locations {
278 self.locations = locations;
279 }
280
281 if let Some(raw_input) = raw_input {
282 self.raw_input = Some(raw_input);
283 }
284
285 if let Some(raw_output) = raw_output {
286 if self.content.is_empty()
287 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
288 {
289 self.content
290 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
291 markdown,
292 }));
293 }
294 self.raw_output = Some(raw_output);
295 }
296 Ok(())
297 }
298
299 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
300 self.content.iter().filter_map(|content| match content {
301 ToolCallContent::Diff(diff) => Some(diff),
302 ToolCallContent::ContentBlock(_) => None,
303 ToolCallContent::Terminal(_) => None,
304 })
305 }
306
307 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
308 self.content.iter().filter_map(|content| match content {
309 ToolCallContent::Terminal(terminal) => Some(terminal),
310 ToolCallContent::ContentBlock(_) => None,
311 ToolCallContent::Diff(_) => None,
312 })
313 }
314
315 fn to_markdown(&self, cx: &App) -> String {
316 let mut markdown = format!(
317 "**Tool Call: {}**\nStatus: {}\n\n",
318 self.label.read(cx).source(),
319 self.status
320 );
321 for content in &self.content {
322 markdown.push_str(content.to_markdown(cx).as_str());
323 markdown.push_str("\n\n");
324 }
325 markdown
326 }
327
328 async fn resolve_location(
329 location: acp::ToolCallLocation,
330 project: WeakEntity<Project>,
331 cx: &mut AsyncApp,
332 ) -> Option<AgentLocation> {
333 let buffer = project
334 .update(cx, |project, cx| {
335 project
336 .project_path_for_absolute_path(&location.path, cx)
337 .map(|path| project.open_buffer(path, cx))
338 })
339 .ok()??;
340 let buffer = buffer.await.log_err()?;
341 let position = buffer
342 .update(cx, |buffer, _| {
343 if let Some(row) = location.line {
344 let snapshot = buffer.snapshot();
345 let column = snapshot.indent_size_for_line(row).len;
346 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
347 snapshot.anchor_before(point)
348 } else {
349 Anchor::MIN
350 }
351 })
352 .ok()?;
353
354 Some(AgentLocation {
355 buffer: buffer.downgrade(),
356 position,
357 })
358 }
359
360 fn resolve_locations(
361 &self,
362 project: Entity<Project>,
363 cx: &mut App,
364 ) -> Task<Vec<Option<AgentLocation>>> {
365 let locations = self.locations.clone();
366 project.update(cx, |_, cx| {
367 cx.spawn(async move |project, cx| {
368 let mut new_locations = Vec::new();
369 for location in locations {
370 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
371 }
372 new_locations
373 })
374 })
375 }
376}
377
378#[derive(Debug)]
379pub enum ToolCallStatus {
380 /// The tool call hasn't started running yet, but we start showing it to
381 /// the user.
382 Pending,
383 /// The tool call is waiting for confirmation from the user.
384 WaitingForConfirmation {
385 options: Vec<acp::PermissionOption>,
386 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
387 },
388 /// The tool call is currently running.
389 InProgress,
390 /// The tool call completed successfully.
391 Completed,
392 /// The tool call failed.
393 Failed,
394 /// The user rejected the tool call.
395 Rejected,
396 /// The user canceled generation so the tool call was canceled.
397 Canceled,
398}
399
400impl From<acp::ToolCallStatus> for ToolCallStatus {
401 fn from(status: acp::ToolCallStatus) -> Self {
402 match status {
403 acp::ToolCallStatus::Pending => Self::Pending,
404 acp::ToolCallStatus::InProgress => Self::InProgress,
405 acp::ToolCallStatus::Completed => Self::Completed,
406 acp::ToolCallStatus::Failed => Self::Failed,
407 }
408 }
409}
410
411impl Display for ToolCallStatus {
412 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
413 write!(
414 f,
415 "{}",
416 match self {
417 ToolCallStatus::Pending => "Pending",
418 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
419 ToolCallStatus::InProgress => "In Progress",
420 ToolCallStatus::Completed => "Completed",
421 ToolCallStatus::Failed => "Failed",
422 ToolCallStatus::Rejected => "Rejected",
423 ToolCallStatus::Canceled => "Canceled",
424 }
425 )
426 }
427}
428
429#[derive(Debug, PartialEq, Clone)]
430pub enum ContentBlock {
431 Empty,
432 Markdown { markdown: Entity<Markdown> },
433 ResourceLink { resource_link: acp::ResourceLink },
434}
435
436impl ContentBlock {
437 pub fn new(
438 block: acp::ContentBlock,
439 language_registry: &Arc<LanguageRegistry>,
440 cx: &mut App,
441 ) -> Self {
442 let mut this = Self::Empty;
443 this.append(block, language_registry, cx);
444 this
445 }
446
447 pub fn new_combined(
448 blocks: impl IntoIterator<Item = acp::ContentBlock>,
449 language_registry: Arc<LanguageRegistry>,
450 cx: &mut App,
451 ) -> Self {
452 let mut this = Self::Empty;
453 for block in blocks {
454 this.append(block, &language_registry, cx);
455 }
456 this
457 }
458
459 pub fn append(
460 &mut self,
461 block: acp::ContentBlock,
462 language_registry: &Arc<LanguageRegistry>,
463 cx: &mut App,
464 ) {
465 if matches!(self, ContentBlock::Empty)
466 && let acp::ContentBlock::ResourceLink(resource_link) = block
467 {
468 *self = ContentBlock::ResourceLink { resource_link };
469 return;
470 }
471
472 let new_content = self.block_string_contents(block);
473
474 match self {
475 ContentBlock::Empty => {
476 *self = Self::create_markdown_block(new_content, language_registry, cx);
477 }
478 ContentBlock::Markdown { markdown } => {
479 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
480 }
481 ContentBlock::ResourceLink { resource_link } => {
482 let existing_content = Self::resource_link_md(&resource_link.uri);
483 let combined = format!("{}\n{}", existing_content, new_content);
484
485 *self = Self::create_markdown_block(combined, language_registry, cx);
486 }
487 }
488 }
489
490 fn create_markdown_block(
491 content: String,
492 language_registry: &Arc<LanguageRegistry>,
493 cx: &mut App,
494 ) -> ContentBlock {
495 ContentBlock::Markdown {
496 markdown: cx
497 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
498 }
499 }
500
501 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
502 match block {
503 acp::ContentBlock::Text(text_content) => text_content.text,
504 acp::ContentBlock::ResourceLink(resource_link) => {
505 Self::resource_link_md(&resource_link.uri)
506 }
507 acp::ContentBlock::Resource(acp::EmbeddedResource {
508 resource:
509 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
510 uri,
511 ..
512 }),
513 ..
514 }) => Self::resource_link_md(&uri),
515 acp::ContentBlock::Image(image) => Self::image_md(&image),
516 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
517 }
518 }
519
520 fn resource_link_md(uri: &str) -> String {
521 if let Some(uri) = MentionUri::parse(uri).log_err() {
522 uri.as_link().to_string()
523 } else {
524 uri.to_string()
525 }
526 }
527
528 fn image_md(_image: &acp::ImageContent) -> String {
529 "`Image`".into()
530 }
531
532 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
533 match self {
534 ContentBlock::Empty => "",
535 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
536 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
537 }
538 }
539
540 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
541 match self {
542 ContentBlock::Empty => None,
543 ContentBlock::Markdown { markdown } => Some(markdown),
544 ContentBlock::ResourceLink { .. } => None,
545 }
546 }
547
548 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
549 match self {
550 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
551 _ => None,
552 }
553 }
554}
555
556#[derive(Debug)]
557pub enum ToolCallContent {
558 ContentBlock(ContentBlock),
559 Diff(Entity<Diff>),
560 Terminal(Entity<Terminal>),
561}
562
563impl ToolCallContent {
564 pub fn from_acp(
565 content: acp::ToolCallContent,
566 language_registry: Arc<LanguageRegistry>,
567 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
568 cx: &mut App,
569 ) -> Result<Self> {
570 match content {
571 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
572 content,
573 &language_registry,
574 cx,
575 ))),
576 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
577 Diff::finalized(
578 diff.path,
579 diff.old_text,
580 diff.new_text,
581 language_registry,
582 cx,
583 )
584 }))),
585 acp::ToolCallContent::Terminal { terminal_id } => terminals
586 .get(&terminal_id)
587 .cloned()
588 .map(Self::Terminal)
589 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
590 }
591 }
592
593 pub fn update_from_acp(
594 &mut self,
595 new: acp::ToolCallContent,
596 language_registry: Arc<LanguageRegistry>,
597 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
598 cx: &mut App,
599 ) -> Result<()> {
600 let needs_update = match (&self, &new) {
601 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
602 old_diff.read(cx).needs_update(
603 new_diff.old_text.as_deref().unwrap_or(""),
604 &new_diff.new_text,
605 cx,
606 )
607 }
608 _ => true,
609 };
610
611 if needs_update {
612 *self = Self::from_acp(new, language_registry, terminals, cx)?;
613 }
614 Ok(())
615 }
616
617 pub fn to_markdown(&self, cx: &App) -> String {
618 match self {
619 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
620 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
621 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
622 }
623 }
624}
625
626#[derive(Debug, PartialEq)]
627pub enum ToolCallUpdate {
628 UpdateFields(acp::ToolCallUpdate),
629 UpdateDiff(ToolCallUpdateDiff),
630 UpdateTerminal(ToolCallUpdateTerminal),
631}
632
633impl ToolCallUpdate {
634 fn id(&self) -> &acp::ToolCallId {
635 match self {
636 Self::UpdateFields(update) => &update.id,
637 Self::UpdateDiff(diff) => &diff.id,
638 Self::UpdateTerminal(terminal) => &terminal.id,
639 }
640 }
641}
642
643impl From<acp::ToolCallUpdate> for ToolCallUpdate {
644 fn from(update: acp::ToolCallUpdate) -> Self {
645 Self::UpdateFields(update)
646 }
647}
648
649impl From<ToolCallUpdateDiff> for ToolCallUpdate {
650 fn from(diff: ToolCallUpdateDiff) -> Self {
651 Self::UpdateDiff(diff)
652 }
653}
654
655#[derive(Debug, PartialEq)]
656pub struct ToolCallUpdateDiff {
657 pub id: acp::ToolCallId,
658 pub diff: Entity<Diff>,
659}
660
661impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
662 fn from(terminal: ToolCallUpdateTerminal) -> Self {
663 Self::UpdateTerminal(terminal)
664 }
665}
666
667#[derive(Debug, PartialEq)]
668pub struct ToolCallUpdateTerminal {
669 pub id: acp::ToolCallId,
670 pub terminal: Entity<Terminal>,
671}
672
673#[derive(Debug, Default)]
674pub struct Plan {
675 pub entries: Vec<PlanEntry>,
676}
677
678#[derive(Debug)]
679pub struct PlanStats<'a> {
680 pub in_progress_entry: Option<&'a PlanEntry>,
681 pub pending: u32,
682 pub completed: u32,
683}
684
685impl Plan {
686 pub fn is_empty(&self) -> bool {
687 self.entries.is_empty()
688 }
689
690 pub fn stats(&self) -> PlanStats<'_> {
691 let mut stats = PlanStats {
692 in_progress_entry: None,
693 pending: 0,
694 completed: 0,
695 };
696
697 for entry in &self.entries {
698 match &entry.status {
699 acp::PlanEntryStatus::Pending => {
700 stats.pending += 1;
701 }
702 acp::PlanEntryStatus::InProgress => {
703 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
704 }
705 acp::PlanEntryStatus::Completed => {
706 stats.completed += 1;
707 }
708 }
709 }
710
711 stats
712 }
713}
714
715#[derive(Debug)]
716pub struct PlanEntry {
717 pub content: Entity<Markdown>,
718 pub priority: acp::PlanEntryPriority,
719 pub status: acp::PlanEntryStatus,
720}
721
722impl PlanEntry {
723 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
724 Self {
725 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
726 priority: entry.priority,
727 status: entry.status,
728 }
729 }
730}
731
732#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
733pub struct TokenUsage {
734 pub max_tokens: u64,
735 pub used_tokens: u64,
736}
737
738impl TokenUsage {
739 pub fn ratio(&self) -> TokenUsageRatio {
740 #[cfg(debug_assertions)]
741 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
742 .unwrap_or("0.8".to_string())
743 .parse()
744 .unwrap();
745 #[cfg(not(debug_assertions))]
746 let warning_threshold: f32 = 0.8;
747
748 // When the maximum is unknown because there is no selected model,
749 // avoid showing the token limit warning.
750 if self.max_tokens == 0 {
751 TokenUsageRatio::Normal
752 } else if self.used_tokens >= self.max_tokens {
753 TokenUsageRatio::Exceeded
754 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
755 TokenUsageRatio::Warning
756 } else {
757 TokenUsageRatio::Normal
758 }
759 }
760}
761
762#[derive(Debug, Clone, PartialEq, Eq)]
763pub enum TokenUsageRatio {
764 Normal,
765 Warning,
766 Exceeded,
767}
768
769#[derive(Debug, Clone)]
770pub struct RetryStatus {
771 pub last_error: SharedString,
772 pub attempt: usize,
773 pub max_attempts: usize,
774 pub started_at: Instant,
775 pub duration: Duration,
776}
777
778pub struct AcpThread {
779 title: SharedString,
780 entries: Vec<AgentThreadEntry>,
781 plan: Plan,
782 project: Entity<Project>,
783 action_log: Entity<ActionLog>,
784 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
785 send_task: Option<Task<()>>,
786 connection: Rc<dyn AgentConnection>,
787 session_id: acp::SessionId,
788 token_usage: Option<TokenUsage>,
789 prompt_capabilities: acp::PromptCapabilities,
790 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
791 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
792}
793
794#[derive(Debug)]
795pub enum AcpThreadEvent {
796 NewEntry,
797 TitleUpdated,
798 TokenUsageUpdated,
799 EntryUpdated(usize),
800 EntriesRemoved(Range<usize>),
801 ToolAuthorizationRequired,
802 Retry(RetryStatus),
803 Stopped,
804 Error,
805 LoadError(LoadError),
806 PromptCapabilitiesUpdated,
807 Refusal,
808 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
809 ModeUpdated(acp::SessionModeId),
810}
811
812impl EventEmitter<AcpThreadEvent> for AcpThread {}
813
814#[derive(PartialEq, Eq, Debug)]
815pub enum ThreadStatus {
816 Idle,
817 Generating,
818}
819
820#[derive(Debug, Clone)]
821pub enum LoadError {
822 Unsupported {
823 command: SharedString,
824 current_version: SharedString,
825 minimum_version: SharedString,
826 },
827 FailedToInstall(SharedString),
828 Exited {
829 status: ExitStatus,
830 },
831 Other(SharedString),
832}
833
834impl Display for LoadError {
835 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
836 match self {
837 LoadError::Unsupported {
838 command: path,
839 current_version,
840 minimum_version,
841 } => {
842 write!(
843 f,
844 "version {current_version} from {path} is not supported (need at least {minimum_version})"
845 )
846 }
847 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
848 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
849 LoadError::Other(msg) => write!(f, "{msg}"),
850 }
851 }
852}
853
854impl Error for LoadError {}
855
856impl AcpThread {
857 pub fn new(
858 title: impl Into<SharedString>,
859 connection: Rc<dyn AgentConnection>,
860 project: Entity<Project>,
861 action_log: Entity<ActionLog>,
862 session_id: acp::SessionId,
863 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
864 cx: &mut Context<Self>,
865 ) -> Self {
866 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
867 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
868 loop {
869 let caps = prompt_capabilities_rx.recv().await?;
870 this.update(cx, |this, cx| {
871 this.prompt_capabilities = caps;
872 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
873 })?;
874 }
875 });
876
877 Self {
878 action_log,
879 shared_buffers: Default::default(),
880 entries: Default::default(),
881 plan: Default::default(),
882 title: title.into(),
883 project,
884 send_task: None,
885 connection,
886 session_id,
887 token_usage: None,
888 prompt_capabilities,
889 _observe_prompt_capabilities: task,
890 terminals: HashMap::default(),
891 }
892 }
893
894 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
895 self.prompt_capabilities.clone()
896 }
897
898 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
899 &self.connection
900 }
901
902 pub fn action_log(&self) -> &Entity<ActionLog> {
903 &self.action_log
904 }
905
906 pub fn project(&self) -> &Entity<Project> {
907 &self.project
908 }
909
910 pub fn title(&self) -> SharedString {
911 self.title.clone()
912 }
913
914 pub fn entries(&self) -> &[AgentThreadEntry] {
915 &self.entries
916 }
917
918 pub fn session_id(&self) -> &acp::SessionId {
919 &self.session_id
920 }
921
922 pub fn status(&self) -> ThreadStatus {
923 if self.send_task.is_some() {
924 ThreadStatus::Generating
925 } else {
926 ThreadStatus::Idle
927 }
928 }
929
930 pub fn token_usage(&self) -> Option<&TokenUsage> {
931 self.token_usage.as_ref()
932 }
933
934 pub fn has_pending_edit_tool_calls(&self) -> bool {
935 for entry in self.entries.iter().rev() {
936 match entry {
937 AgentThreadEntry::UserMessage(_) => return false,
938 AgentThreadEntry::ToolCall(
939 call @ ToolCall {
940 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
941 ..
942 },
943 ) if call.diffs().next().is_some() => {
944 return true;
945 }
946 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
947 }
948 }
949
950 false
951 }
952
953 pub fn used_tools_since_last_user_message(&self) -> bool {
954 for entry in self.entries.iter().rev() {
955 match entry {
956 AgentThreadEntry::UserMessage(..) => return false,
957 AgentThreadEntry::AssistantMessage(..) => continue,
958 AgentThreadEntry::ToolCall(..) => return true,
959 }
960 }
961
962 false
963 }
964
965 pub fn handle_session_update(
966 &mut self,
967 update: acp::SessionUpdate,
968 cx: &mut Context<Self>,
969 ) -> Result<(), acp::Error> {
970 match update {
971 acp::SessionUpdate::UserMessageChunk { content } => {
972 self.push_user_content_block(None, content, cx);
973 }
974 acp::SessionUpdate::AgentMessageChunk { content } => {
975 self.push_assistant_content_block(content, false, cx);
976 }
977 acp::SessionUpdate::AgentThoughtChunk { content } => {
978 self.push_assistant_content_block(content, true, cx);
979 }
980 acp::SessionUpdate::ToolCall(tool_call) => {
981 self.upsert_tool_call(tool_call, cx)?;
982 }
983 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
984 self.update_tool_call(tool_call_update, cx)?;
985 }
986 acp::SessionUpdate::Plan(plan) => {
987 self.update_plan(plan, cx);
988 }
989 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
990 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
991 }
992 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
993 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
994 }
995 }
996 Ok(())
997 }
998
999 pub fn push_user_content_block(
1000 &mut self,
1001 message_id: Option<UserMessageId>,
1002 chunk: acp::ContentBlock,
1003 cx: &mut Context<Self>,
1004 ) {
1005 let language_registry = self.project.read(cx).languages().clone();
1006 let entries_len = self.entries.len();
1007
1008 if let Some(last_entry) = self.entries.last_mut()
1009 && let AgentThreadEntry::UserMessage(UserMessage {
1010 id,
1011 content,
1012 chunks,
1013 ..
1014 }) = last_entry
1015 {
1016 *id = message_id.or(id.take());
1017 content.append(chunk.clone(), &language_registry, cx);
1018 chunks.push(chunk);
1019 let idx = entries_len - 1;
1020 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1021 } else {
1022 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1023 self.push_entry(
1024 AgentThreadEntry::UserMessage(UserMessage {
1025 id: message_id,
1026 content,
1027 chunks: vec![chunk],
1028 checkpoint: None,
1029 }),
1030 cx,
1031 );
1032 }
1033 }
1034
1035 pub fn push_assistant_content_block(
1036 &mut self,
1037 chunk: acp::ContentBlock,
1038 is_thought: bool,
1039 cx: &mut Context<Self>,
1040 ) {
1041 let language_registry = self.project.read(cx).languages().clone();
1042 let entries_len = self.entries.len();
1043 if let Some(last_entry) = self.entries.last_mut()
1044 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1045 {
1046 let idx = entries_len - 1;
1047 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1048 match (chunks.last_mut(), is_thought) {
1049 (Some(AssistantMessageChunk::Message { block }), false)
1050 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1051 block.append(chunk, &language_registry, cx)
1052 }
1053 _ => {
1054 let block = ContentBlock::new(chunk, &language_registry, cx);
1055 if is_thought {
1056 chunks.push(AssistantMessageChunk::Thought { block })
1057 } else {
1058 chunks.push(AssistantMessageChunk::Message { block })
1059 }
1060 }
1061 }
1062 } else {
1063 let block = ContentBlock::new(chunk, &language_registry, cx);
1064 let chunk = if is_thought {
1065 AssistantMessageChunk::Thought { block }
1066 } else {
1067 AssistantMessageChunk::Message { block }
1068 };
1069
1070 self.push_entry(
1071 AgentThreadEntry::AssistantMessage(AssistantMessage {
1072 chunks: vec![chunk],
1073 }),
1074 cx,
1075 );
1076 }
1077 }
1078
1079 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1080 self.entries.push(entry);
1081 cx.emit(AcpThreadEvent::NewEntry);
1082 }
1083
1084 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1085 self.connection.set_title(&self.session_id, cx).is_some()
1086 }
1087
1088 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1089 if title != self.title {
1090 self.title = title.clone();
1091 cx.emit(AcpThreadEvent::TitleUpdated);
1092 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1093 return set_title.run(title, cx);
1094 }
1095 }
1096 Task::ready(Ok(()))
1097 }
1098
1099 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1100 self.token_usage = usage;
1101 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1102 }
1103
1104 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1105 cx.emit(AcpThreadEvent::Retry(status));
1106 }
1107
    /// Applies `update` to the tool call entry whose id matches `update.id()`.
    ///
    /// If no entry has that id, a placeholder tool call (status `Failed`,
    /// text "Tool call not found") is pushed as a new entry and `Ok(())` is
    /// returned without applying the update. Otherwise the update is applied
    /// and [`AcpThreadEvent::EntryUpdated`] is emitted for the entry.
    ///
    /// # Errors
    /// Propagates errors from `ToolCall::update_fields`; in that case no
    /// `EntryUpdated` event is emitted.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        acp::ContentBlock::Text(acp::TextContent {
                            text: "Tool call not found".to_string(),
                            annotations: None,
                            meta: None,
                        }),
                        &languages,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only matches `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Check if there's terminal output in the meta field.
                // Expected shape: `{"terminal_output": {"terminal_id": <string>, "data": <string>}}`;
                // anything else yields `None`.
                let terminal_output_result = update
                    .meta
                    .as_ref()
                    .and_then(|meta| meta.get("terminal_output"))
                    .and_then(|terminal_output| {
                        match (
                            terminal_output.get("terminal_id").and_then(|v| v.as_str()),
                            terminal_output.get("data").and_then(|v| v.as_str()),
                        ) {
                            (Some(terminal_id_str), Some(data_str)) => {
                                let data = data_str.as_bytes().to_vec();
                                let terminal_id = acp::TerminalId(terminal_id_str.into());
                                Some((terminal_id, data))
                            }
                            _ => None,
                        }
                    });

                // Captured before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, &self.terminals, cx)?;

                if let Some((terminal_id, data)) = terminal_output_result {
                    // Silently ignore errors - terminal output streaming is best-effort
                    let _ = self.write_terminal_output(terminal_id, &data, cx);
                }
                if location_updated {
                    self.resolve_locations(update.id, cx);
                }
            }
            // Diff and terminal updates replace all existing content with a single item.
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1194
1195 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1196 pub fn upsert_tool_call(
1197 &mut self,
1198 tool_call: acp::ToolCall,
1199 cx: &mut Context<Self>,
1200 ) -> Result<(), acp::Error> {
1201 let status = tool_call.status.into();
1202 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1203 }
1204
1205 /// Fails if id does not match an existing entry.
1206 pub fn upsert_tool_call_inner(
1207 &mut self,
1208 update: acp::ToolCallUpdate,
1209 status: ToolCallStatus,
1210 cx: &mut Context<Self>,
1211 ) -> Result<(), acp::Error> {
1212 let language_registry = self.project.read(cx).languages().clone();
1213 let id = update.id.clone();
1214
1215 if let Some(ix) = self.index_for_tool_call(&id) {
1216 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1217 unreachable!()
1218 };
1219
1220 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1221 call.status = status;
1222
1223 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224 } else {
1225 let call = ToolCall::from_acp(
1226 update.try_into()?,
1227 status,
1228 language_registry,
1229 &self.terminals,
1230 cx,
1231 )?;
1232 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1233 };
1234
1235 self.resolve_locations(id, cx);
1236 Ok(())
1237 }
1238
1239 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1240 self.entries
1241 .iter()
1242 .enumerate()
1243 .rev()
1244 .find_map(|(index, entry)| {
1245 if let AgentThreadEntry::ToolCall(tool_call) = entry
1246 && &tool_call.id == id
1247 {
1248 Some(index)
1249 } else {
1250 None
1251 }
1252 })
1253 }
1254
1255 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1256 // The tool call we are looking for is typically the last one, or very close to the end.
1257 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1258 self.entries
1259 .iter_mut()
1260 .enumerate()
1261 .rev()
1262 .find_map(|(index, tool_call)| {
1263 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1264 && &tool_call.id == id
1265 {
1266 Some((index, tool_call))
1267 } else {
1268 None
1269 }
1270 })
1271 }
1272
1273 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1274 self.entries
1275 .iter()
1276 .enumerate()
1277 .rev()
1278 .find_map(|(index, tool_call)| {
1279 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1280 && &tool_call.id == id
1281 {
1282 Some((index, tool_call))
1283 } else {
1284 None
1285 }
1286 })
1287 }
1288
1289 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1290 let project = self.project.clone();
1291 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1292 return;
1293 };
1294 let task = tool_call.resolve_locations(project, cx);
1295 cx.spawn(async move |this, cx| {
1296 let resolved_locations = task.await;
1297 this.update(cx, |this, cx| {
1298 let project = this.project.clone();
1299 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1300 return;
1301 };
1302 if let Some(Some(location)) = resolved_locations.last() {
1303 project.update(cx, |project, cx| {
1304 if let Some(agent_location) = project.agent_location() {
1305 let should_ignore = agent_location.buffer == location.buffer
1306 && location
1307 .buffer
1308 .update(cx, |buffer, _| {
1309 let snapshot = buffer.snapshot();
1310 let old_position =
1311 agent_location.position.to_point(&snapshot);
1312 let new_position = location.position.to_point(&snapshot);
1313 // ignore this so that when we get updates from the edit tool
1314 // the position doesn't reset to the startof line
1315 old_position.row == new_position.row
1316 && old_position.column > new_position.column
1317 })
1318 .ok()
1319 .unwrap_or_default();
1320 if !should_ignore {
1321 project.set_agent_location(Some(location.clone()), cx);
1322 }
1323 }
1324 });
1325 }
1326 if tool_call.resolved_locations != resolved_locations {
1327 tool_call.resolved_locations = resolved_locations;
1328 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1329 }
1330 })
1331 })
1332 .detach();
1333 }
1334
1335 pub fn request_tool_call_authorization(
1336 &mut self,
1337 tool_call: acp::ToolCallUpdate,
1338 options: Vec<acp::PermissionOption>,
1339 respect_always_allow_setting: bool,
1340 cx: &mut Context<Self>,
1341 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1342 let (tx, rx) = oneshot::channel();
1343
1344 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1345 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1346 // some tools would (incorrectly) continue to auto-accept.
1347 if let Some(allow_once_option) = options.iter().find_map(|option| {
1348 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1349 Some(option.id.clone())
1350 } else {
1351 None
1352 }
1353 }) {
1354 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1355 return Ok(async {
1356 acp::RequestPermissionOutcome::Selected {
1357 option_id: allow_once_option,
1358 }
1359 }
1360 .boxed());
1361 }
1362 }
1363
1364 let status = ToolCallStatus::WaitingForConfirmation {
1365 options,
1366 respond_tx: tx,
1367 };
1368
1369 self.upsert_tool_call_inner(tool_call, status, cx)?;
1370 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1371
1372 let fut = async {
1373 match rx.await {
1374 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1375 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1376 }
1377 }
1378 .boxed();
1379
1380 Ok(fut)
1381 }
1382
1383 pub fn authorize_tool_call(
1384 &mut self,
1385 id: acp::ToolCallId,
1386 option_id: acp::PermissionOptionId,
1387 option_kind: acp::PermissionOptionKind,
1388 cx: &mut Context<Self>,
1389 ) {
1390 let Some((ix, call)) = self.tool_call_mut(&id) else {
1391 return;
1392 };
1393
1394 let new_status = match option_kind {
1395 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1396 ToolCallStatus::Rejected
1397 }
1398 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1399 ToolCallStatus::InProgress
1400 }
1401 };
1402
1403 let curr_status = mem::replace(&mut call.status, new_status);
1404
1405 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1406 respond_tx.send(option_id).log_err();
1407 } else if cfg!(debug_assertions) {
1408 panic!("tried to authorize an already authorized tool call");
1409 }
1410
1411 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1412 }
1413
1414 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1415 let mut first_tool_call = None;
1416
1417 for entry in self.entries.iter().rev() {
1418 match &entry {
1419 AgentThreadEntry::ToolCall(call) => {
1420 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1421 first_tool_call = Some(call);
1422 } else {
1423 continue;
1424 }
1425 }
1426 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1427 // Reached the beginning of the turn.
1428 // If we had pending permission requests in the previous turn, they have been cancelled.
1429 break;
1430 }
1431 }
1432 }
1433
1434 first_tool_call
1435 }
1436
    /// Returns the agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1440
1441 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1442 let new_entries_len = request.entries.len();
1443 let mut new_entries = request.entries.into_iter();
1444
1445 // Reuse existing markdown to prevent flickering
1446 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1447 let PlanEntry {
1448 content,
1449 priority,
1450 status,
1451 } = old;
1452 content.update(cx, |old, cx| {
1453 old.replace(new.content, cx);
1454 });
1455 *priority = new.priority;
1456 *status = new.status;
1457 }
1458 for new in new_entries {
1459 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1460 }
1461 self.plan.entries.truncate(new_entries_len);
1462
1463 cx.notify();
1464 }
1465
1466 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1467 self.plan
1468 .entries
1469 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1470 cx.notify();
1471 }
1472
1473 #[cfg(any(test, feature = "test-support"))]
1474 pub fn send_raw(
1475 &mut self,
1476 message: &str,
1477 cx: &mut Context<Self>,
1478 ) -> BoxFuture<'static, Result<()>> {
1479 self.send(
1480 vec![acp::ContentBlock::Text(acp::TextContent {
1481 text: message.to_string(),
1482 annotations: None,
1483 meta: None,
1484 })],
1485 cx,
1486 )
1487 }
1488
    /// Sends `message` to the agent as a new user prompt, starting a new turn.
    ///
    /// A user message entry for `message` is pushed first. The new message gets a
    /// fresh `UserMessageId` only when the connection supports truncation. A git
    /// checkpoint is then taken and attached to the most recent user message,
    /// hidden (`show: false`), before the prompt is sent to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
            meta: None,
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Only truncatable sessions get message ids; `user_message_mut` (used by
        // `rewind` and `restore_checkpoint`) can only find messages that have one.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Failure (e.g. the thread was dropped) is ignored here.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint is logged and leaves the message without one.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                // Attaches to the latest user message, which is expected to be the
                // one pushed above.
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1543
1544 pub fn can_resume(&self, cx: &App) -> bool {
1545 self.connection.resume(&self.session_id, cx).is_some()
1546 }
1547
1548 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1549 self.run_turn(cx, async move |this, cx| {
1550 this.update(cx, |this, cx| {
1551 this.connection
1552 .resume(&this.session_id, cx)
1553 .map(|resume| resume.run(cx))
1554 })?
1555 .context("resuming a session is not supported")?
1556 .await
1557 })
1558 }
1559
1560 fn run_turn(
1561 &mut self,
1562 cx: &mut Context<Self>,
1563 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1564 ) -> BoxFuture<'static, Result<()>> {
1565 self.clear_completed_plan_entries(cx);
1566
1567 let (tx, rx) = oneshot::channel();
1568 let cancel_task = self.cancel(cx);
1569
1570 self.send_task = Some(cx.spawn(async move |this, cx| {
1571 cancel_task.await;
1572 tx.send(f(this, cx).await).ok();
1573 }));
1574
1575 cx.spawn(async move |this, cx| {
1576 let response = rx.await;
1577
1578 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1579 .await?;
1580
1581 this.update(cx, |this, cx| {
1582 this.project
1583 .update(cx, |project, cx| project.set_agent_location(None, cx));
1584 match response {
1585 Ok(Err(e)) => {
1586 this.send_task.take();
1587 cx.emit(AcpThreadEvent::Error);
1588 Err(e)
1589 }
1590 result => {
1591 let canceled = matches!(
1592 result,
1593 Ok(Ok(acp::PromptResponse {
1594 stop_reason: acp::StopReason::Cancelled,
1595 meta: None,
1596 }))
1597 );
1598
1599 // We only take the task if the current prompt wasn't canceled.
1600 //
1601 // This prompt may have been canceled because another one was sent
1602 // while it was still generating. In these cases, dropping `send_task`
1603 // would cause the next generation to be canceled.
1604 if !canceled {
1605 this.send_task.take();
1606 }
1607
1608 // Handle refusal - distinguish between user prompt and tool call refusals
1609 if let Ok(Ok(acp::PromptResponse {
1610 stop_reason: acp::StopReason::Refusal,
1611 meta: _,
1612 })) = result
1613 {
1614 if let Some((user_msg_ix, _)) = this.last_user_message() {
1615 // Check if there's a completed tool call with results after the last user message
1616 // This indicates the refusal is in response to tool output, not the user's prompt
1617 let has_completed_tool_call_after_user_msg =
1618 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1619 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1620 // Check if the tool call has completed and has output
1621 matches!(tool_call.status, ToolCallStatus::Completed)
1622 && tool_call.raw_output.is_some()
1623 } else {
1624 false
1625 }
1626 });
1627
1628 if has_completed_tool_call_after_user_msg {
1629 // Refusal is due to tool output - don't truncate, just notify
1630 // The model refused based on what the tool returned
1631 cx.emit(AcpThreadEvent::Refusal);
1632 } else {
1633 // User prompt was refused - truncate back to before the user message
1634 let range = user_msg_ix..this.entries.len();
1635 if range.start < range.end {
1636 this.entries.truncate(user_msg_ix);
1637 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1638 }
1639 cx.emit(AcpThreadEvent::Refusal);
1640 }
1641 } else {
1642 // No user message found, treat as general refusal
1643 cx.emit(AcpThreadEvent::Refusal);
1644 }
1645 }
1646
1647 cx.emit(AcpThreadEvent::Stopped);
1648 Ok(())
1649 }
1650 }
1651 })?
1652 })
1653 .boxed()
1654 }
1655
1656 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1657 let Some(send_task) = self.send_task.take() else {
1658 return Task::ready(());
1659 };
1660
1661 for entry in self.entries.iter_mut() {
1662 if let AgentThreadEntry::ToolCall(call) = entry {
1663 let cancel = matches!(
1664 call.status,
1665 ToolCallStatus::Pending
1666 | ToolCallStatus::WaitingForConfirmation { .. }
1667 | ToolCallStatus::InProgress
1668 );
1669
1670 if cancel {
1671 call.status = ToolCallStatus::Canceled;
1672 }
1673 }
1674 }
1675
1676 self.connection.cancel(&self.session_id, cx);
1677
1678 // Wait for the send task to complete
1679 cx.foreground_executor().spawn(send_task)
1680 }
1681
1682 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1683 pub fn restore_checkpoint(
1684 &mut self,
1685 id: UserMessageId,
1686 cx: &mut Context<Self>,
1687 ) -> Task<Result<()>> {
1688 let Some((_, message)) = self.user_message_mut(&id) else {
1689 return Task::ready(Err(anyhow!("message not found")));
1690 };
1691
1692 let checkpoint = message
1693 .checkpoint
1694 .as_ref()
1695 .map(|c| c.git_checkpoint.clone());
1696 let rewind = self.rewind(id.clone(), cx);
1697 let git_store = self.project.read(cx).git_store().clone();
1698
1699 cx.spawn(async move |_, cx| {
1700 rewind.await?;
1701 if let Some(checkpoint) = checkpoint {
1702 git_store
1703 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1704 .await?;
1705 }
1706
1707 Ok(())
1708 })
1709 }
1710
1711 /// Rewinds this thread to before the entry at `index`, removing it and all
1712 /// subsequent entries while rejecting any action_log changes made from that point.
1713 /// Unlike `restore_checkpoint`, this method does not restore from git.
1714 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1715 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1716 return Task::ready(Err(anyhow!("not supported")));
1717 };
1718
1719 cx.spawn(async move |this, cx| {
1720 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1721 this.update(cx, |this, cx| {
1722 if let Some((ix, _)) = this.user_message_mut(&id) {
1723 let range = ix..this.entries.len();
1724 this.entries.truncate(ix);
1725 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1726 }
1727 this.action_log()
1728 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1729 })?
1730 .await;
1731 Ok(())
1732 })
1733 }
1734
1735 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1736 let git_store = self.project.read(cx).git_store().clone();
1737
1738 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1739 if let Some(checkpoint) = message.checkpoint.as_ref() {
1740 checkpoint.git_checkpoint.clone()
1741 } else {
1742 return Task::ready(Ok(()));
1743 }
1744 } else {
1745 return Task::ready(Ok(()));
1746 };
1747
1748 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1749 cx.spawn(async move |this, cx| {
1750 let new_checkpoint = new_checkpoint
1751 .await
1752 .context("failed to get new checkpoint")
1753 .log_err();
1754 if let Some(new_checkpoint) = new_checkpoint {
1755 let equal = git_store
1756 .update(cx, |git, cx| {
1757 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1758 })?
1759 .await
1760 .unwrap_or(true);
1761 this.update(cx, |this, cx| {
1762 let (ix, message) = this.last_user_message().context("no user message")?;
1763 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1764 checkpoint.show = !equal;
1765 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1766 anyhow::Ok(())
1767 })??;
1768 }
1769
1770 Ok(())
1771 })
1772 }
1773
1774 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1775 self.entries
1776 .iter_mut()
1777 .enumerate()
1778 .rev()
1779 .find_map(|(ix, entry)| {
1780 if let AgentThreadEntry::UserMessage(message) = entry {
1781 Some((ix, message))
1782 } else {
1783 None
1784 }
1785 })
1786 }
1787
1788 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1789 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1790 if let AgentThreadEntry::UserMessage(message) = entry {
1791 if message.id.as_ref() == Some(id) {
1792 Some((ix, message))
1793 } else {
1794 None
1795 }
1796 } else {
1797 None
1798 }
1799 })
1800 }
1801
1802 pub fn read_text_file(
1803 &self,
1804 path: PathBuf,
1805 line: Option<u32>,
1806 limit: Option<u32>,
1807 reuse_shared_snapshot: bool,
1808 cx: &mut Context<Self>,
1809 ) -> Task<Result<String>> {
1810 // Args are 1-based, move to 0-based
1811 let line = line.unwrap_or_default().saturating_sub(1);
1812 let limit = limit.unwrap_or(u32::MAX);
1813 let project = self.project.clone();
1814 let action_log = self.action_log.clone();
1815 cx.spawn(async move |this, cx| {
1816 let load = project.update(cx, |project, cx| {
1817 let path = project
1818 .project_path_for_absolute_path(&path, cx)
1819 .context("invalid path")?;
1820 anyhow::Ok(project.open_buffer(path, cx))
1821 });
1822 let buffer = load??.await?;
1823
1824 let snapshot = if reuse_shared_snapshot {
1825 this.read_with(cx, |this, _| {
1826 this.shared_buffers.get(&buffer.clone()).cloned()
1827 })
1828 .log_err()
1829 .flatten()
1830 } else {
1831 None
1832 };
1833
1834 let snapshot = if let Some(snapshot) = snapshot {
1835 snapshot
1836 } else {
1837 action_log.update(cx, |action_log, cx| {
1838 action_log.buffer_read(buffer.clone(), cx);
1839 })?;
1840
1841 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1842 this.update(cx, |this, _| {
1843 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1844 })?;
1845 snapshot
1846 };
1847
1848 let max_point = snapshot.max_point();
1849 if line >= max_point.row {
1850 anyhow::bail!(
1851 "Attempting to read beyond the end of the file, line {}:{}",
1852 max_point.row + 1,
1853 max_point.column
1854 );
1855 }
1856
1857 let start = snapshot.anchor_before(Point::new(line, 0));
1858 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1859
1860 project.update(cx, |project, cx| {
1861 project.set_agent_location(
1862 Some(AgentLocation {
1863 buffer: buffer.downgrade(),
1864 position: start,
1865 }),
1866 cx,
1867 );
1868 })?;
1869
1870 Ok(snapshot.text_for_range(start..end).collect::<String>())
1871 })
1872 }
1873
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// Rather than overwriting the whole buffer, the new text is diffed against
    /// this thread's shared snapshot of the buffer, or the buffer's current
    /// snapshot if there is none. Only the differing ranges are edited. The action
    /// log records both the read and the edit. The agent location moves to the end
    /// of the last edited range, or to `Anchor::MIN` when nothing changed. If the
    /// buffer's language settings have format-on-save enabled, the buffer is
    /// formatted before saving; formatting errors are logged and otherwise ignored.
    ///
    /// # Errors
    /// Fails if `path` doesn't map to a project path, or if opening or saving the
    /// buffer fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot already shared with the agent, presumably so the
            // diff is relative to the text the agent last read.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor; the result is a list of anchored
            // ranges plus their replacement text.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Apply the edits, bracketed by action-log read/edit notifications, and
            // report whether format-on-save is enabled for this buffer.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Record the formatter's changes as agent edits too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1970
    /// Creates a terminal for `command` with `args` and registers it in this
    /// thread under a freshly generated `acp::TerminalId`.
    ///
    /// The environment is built from the directory environment of `cwd` (when
    /// given), plus `PAGER=cat` on Unix, plus `extra_env`, applied last so it
    /// overrides earlier values. The command is wrapped with `ShellBuilder`, which
    /// uses the remote client's default system shell when one is available and
    /// the local default shell otherwise, and has `redirect_stdin_to_dev_null`
    /// applied.
    ///
    /// When `is_display_only` is true, a display-only terminal is built via
    /// `TerminalBuilder::new_display_only` and the computed environment is not
    /// used. Otherwise a terminal task is spawned through the project.
    /// `output_byte_limit`, if any, is passed to the `Terminal` entity.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        is_display_only: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.directory_environment(dir.as_path().into(), cx)
            }),
            None => Task::ready(None).shared(),
        };

        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Keep pager-using commands from blocking on interactive output.
            if cfg!(unix) {
                env.insert("PAGER".into(), "cat".into());
            }
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();

        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                let (command, args) = ShellBuilder::new(
                    project
                        .update(cx, |project, cx| {
                            project
                                .remote_client()
                                .and_then(|r| r.read(cx).default_system_shell())
                        })?
                        .as_deref(),
                    &Shell::Program(get_default_system_shell()),
                )
                .redirect_stdin_to_dev_null()
                .build(Some(command), &args);

                let terminal = if is_display_only {
                    // NOTE(review): `Some(10_000)` looks like a scrollback/line limit
                    // for the display terminal; confirm against
                    // `TerminalBuilder::new_display_only`.
                    cx.update(|cx| {
                        TerminalBuilder::new_display_only(
                            Some(format!("Display: {}", command).into()),
                            CursorShape::Block,
                            AlternateScroll::On,
                            Some(10_000),
                            cx,
                        )
                    })??
                } else {
                    project
                        .update(cx, |project, cx| {
                            project.create_terminal_task(
                                task::SpawnInTerminal {
                                    command: Some(command.clone()),
                                    args: args.clone(),
                                    cwd: cwd.clone(),
                                    env,
                                    ..Default::default()
                                },
                                cx,
                            )
                        })?
                        .await?
                };

                if is_display_only {
                    // For display-only terminals, we need special handling
                    cx.new(|cx| {
                        Terminal::new_display_only(
                            terminal_id,
                            &format!("{} {}", command, args.join(" ")),
                            cwd,
                            output_byte_limit.map(|l| l as usize),
                            terminal,
                            cx,
                        )
                    })
                } else {
                    cx.new(|cx| {
                        Terminal::new(
                            terminal_id,
                            &format!("{} {}", command, args.join(" ")),
                            cwd,
                            output_byte_limit.map(|l| l as usize),
                            terminal,
                            language_registry,
                            cx,
                        )
                    })
                }
            }
        });

        // Register the terminal only once it has been created successfully.
        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2083
2084 pub fn kill_terminal(
2085 &mut self,
2086 terminal_id: acp::TerminalId,
2087 cx: &mut Context<Self>,
2088 ) -> Result<()> {
2089 self.terminals
2090 .get(&terminal_id)
2091 .context("Terminal not found")?
2092 .update(cx, |terminal, cx| {
2093 terminal.kill(cx);
2094 });
2095
2096 Ok(())
2097 }
2098
2099 pub fn release_terminal(
2100 &mut self,
2101 terminal_id: acp::TerminalId,
2102 cx: &mut Context<Self>,
2103 ) -> Result<()> {
2104 self.terminals
2105 .remove(&terminal_id)
2106 .context("Terminal not found")?
2107 .update(cx, |terminal, cx| {
2108 terminal.kill(cx);
2109 });
2110
2111 Ok(())
2112 }
2113
2114 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2115 self.terminals
2116 .get(&terminal_id)
2117 .cloned()
2118 .context("Terminal not found")
2119 }
2120
2121 pub fn write_terminal_output(
2122 &mut self,
2123 terminal_id: acp::TerminalId,
2124 output: &[u8],
2125 cx: &mut Context<Self>,
2126 ) -> Result<()> {
2127 let terminal = self
2128 .terminals
2129 .get(&terminal_id)
2130 .context("Terminal not found")?;
2131
2132 terminal.update(cx, |terminal, cx| {
2133 terminal.write_output(output, cx);
2134 });
2135
2136 Ok(())
2137 }
2138
2139 pub fn to_markdown(&self, cx: &App) -> String {
2140 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2141 }
2142
2143 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2144 cx.emit(AcpThreadEvent::LoadError(error));
2145 }
2146}
2147
2148fn markdown_for_raw_output(
2149 raw_output: &serde_json::Value,
2150 language_registry: &Arc<LanguageRegistry>,
2151 cx: &mut App,
2152) -> Option<Entity<Markdown>> {
2153 match raw_output {
2154 serde_json::Value::Null => None,
2155 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2156 Markdown::new(
2157 value.to_string().into(),
2158 Some(language_registry.clone()),
2159 None,
2160 cx,
2161 )
2162 })),
2163 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2164 Markdown::new(
2165 value.to_string().into(),
2166 Some(language_registry.clone()),
2167 None,
2168 cx,
2169 )
2170 })),
2171 serde_json::Value::String(value) => Some(cx.new(|cx| {
2172 Markdown::new(
2173 value.clone().into(),
2174 Some(language_registry.clone()),
2175 None,
2176 cx,
2177 )
2178 })),
2179 value => Some(cx.new(|cx| {
2180 Markdown::new(
2181 format!("```json\n{}\n```", value).into(),
2182 Some(language_registry.clone()),
2183 None,
2184 cx,
2185 )
2186 })),
2187 }
2188}
2189
2190#[cfg(test)]
2191mod tests {
2192 use super::*;
2193 use anyhow::anyhow;
2194 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2195 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2196 use indoc::indoc;
2197 use project::{FakeFs, Fs};
2198 use rand::{distr, prelude::*};
2199 use serde_json::json;
2200 use settings::SettingsStore;
2201 use smol::stream::StreamExt as _;
2202 use std::{
2203 any::Any,
2204 cell::RefCell,
2205 path::Path,
2206 rc::Rc,
2207 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2208 time::Duration,
2209 };
2210 use util::path;
2211
2212 fn init_test(cx: &mut TestAppContext) {
2213 env_logger::try_init().ok();
2214 cx.update(|cx| {
2215 let settings_store = SettingsStore::test(cx);
2216 cx.set_global(settings_store);
2217 Project::init_settings(cx);
2218 language::init(cx);
2219 });
2220 }
2221
2222 #[gpui::test]
2223 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2224 init_test(cx);
2225
2226 let fs = FakeFs::new(cx.executor());
2227 let project = Project::test(fs, [], cx).await;
2228 let connection = Rc::new(FakeAgentConnection::new());
2229 let thread = cx
2230 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2231 .await
2232 .unwrap();
2233
2234 // Test creating a new user message
2235 thread.update(cx, |thread, cx| {
2236 thread.push_user_content_block(
2237 None,
2238 acp::ContentBlock::Text(acp::TextContent {
2239 annotations: None,
2240 text: "Hello, ".to_string(),
2241 meta: None,
2242 }),
2243 cx,
2244 );
2245 });
2246
2247 thread.update(cx, |thread, cx| {
2248 assert_eq!(thread.entries.len(), 1);
2249 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2250 assert_eq!(user_msg.id, None);
2251 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2252 } else {
2253 panic!("Expected UserMessage");
2254 }
2255 });
2256
2257 // Test appending to existing user message
2258 let message_1_id = UserMessageId::new();
2259 thread.update(cx, |thread, cx| {
2260 thread.push_user_content_block(
2261 Some(message_1_id.clone()),
2262 acp::ContentBlock::Text(acp::TextContent {
2263 annotations: None,
2264 text: "world!".to_string(),
2265 meta: None,
2266 }),
2267 cx,
2268 );
2269 });
2270
2271 thread.update(cx, |thread, cx| {
2272 assert_eq!(thread.entries.len(), 1);
2273 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2274 assert_eq!(user_msg.id, Some(message_1_id));
2275 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2276 } else {
2277 panic!("Expected UserMessage");
2278 }
2279 });
2280
2281 // Test creating new user message after assistant message
2282 thread.update(cx, |thread, cx| {
2283 thread.push_assistant_content_block(
2284 acp::ContentBlock::Text(acp::TextContent {
2285 annotations: None,
2286 text: "Assistant response".to_string(),
2287 meta: None,
2288 }),
2289 false,
2290 cx,
2291 );
2292 });
2293
2294 let message_2_id = UserMessageId::new();
2295 thread.update(cx, |thread, cx| {
2296 thread.push_user_content_block(
2297 Some(message_2_id.clone()),
2298 acp::ContentBlock::Text(acp::TextContent {
2299 annotations: None,
2300 text: "New user message".to_string(),
2301 meta: None,
2302 }),
2303 cx,
2304 );
2305 });
2306
2307 thread.update(cx, |thread, cx| {
2308 assert_eq!(thread.entries.len(), 3);
2309 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2310 assert_eq!(user_msg.id, Some(message_2_id));
2311 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2312 } else {
2313 panic!("Expected UserMessage at index 2");
2314 }
2315 });
2316 }
2317
2318 #[gpui::test]
2319 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2320 init_test(cx);
2321
2322 let fs = FakeFs::new(cx.executor());
2323 let project = Project::test(fs, [], cx).await;
2324 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2325 |_, thread, mut cx| {
2326 async move {
2327 thread.update(&mut cx, |thread, cx| {
2328 thread
2329 .handle_session_update(
2330 acp::SessionUpdate::AgentThoughtChunk {
2331 content: "Thinking ".into(),
2332 },
2333 cx,
2334 )
2335 .unwrap();
2336 thread
2337 .handle_session_update(
2338 acp::SessionUpdate::AgentThoughtChunk {
2339 content: "hard!".into(),
2340 },
2341 cx,
2342 )
2343 .unwrap();
2344 })?;
2345 Ok(acp::PromptResponse {
2346 stop_reason: acp::StopReason::EndTurn,
2347 meta: None,
2348 })
2349 }
2350 .boxed_local()
2351 },
2352 ));
2353
2354 let thread = cx
2355 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2356 .await
2357 .unwrap();
2358
2359 thread
2360 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2361 .await
2362 .unwrap();
2363
2364 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2365 assert_eq!(
2366 output,
2367 indoc! {r#"
2368 ## User
2369
2370 Hello from Zed!
2371
2372 ## Assistant
2373
2374 <thinking>
2375 Thinking hard!
2376 </thinking>
2377
2378 "#}
2379 );
2380 }
2381
    /// Checks that an agent write does not clobber a user edit made after the
    /// agent read the file.
    ///
    /// The fake agent reads `/tmp/foo`, signals the test, and then writes the
    /// file back with two extra lines. After the signal the test prepends
    /// `zero\n` to the open buffer. Both the buffer and the file on disk are
    /// expected to contain the user's line as well as the agent's lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the agent once its read has completed. The `RefCell<Option<_>>`
        // lets the `Fn` handler move the one-shot sender out (at most once).
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // The agent first reads the file...
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // ...tells the test that the read is done...
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // ...and then writes back its extension of the content it read.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file in a buffer so the test can edit it the way a user would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Once the agent has read the file, make a user edit at the top of the buffer.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's line and the agent's appended lines must survive,
        // in memory and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2461
2462 #[gpui::test]
2463 async fn test_reading_from_line(cx: &mut TestAppContext) {
2464 init_test(cx);
2465
2466 let fs = FakeFs::new(cx.executor());
2467 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2468 .await;
2469 let project = Project::test(fs.clone(), [], cx).await;
2470 project
2471 .update(cx, |project, cx| {
2472 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2473 })
2474 .await
2475 .unwrap();
2476
2477 let connection = Rc::new(FakeAgentConnection::new());
2478
2479 let thread = cx
2480 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2481 .await
2482 .unwrap();
2483
2484 // Whole file
2485 let content = thread
2486 .update(cx, |thread, cx| {
2487 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2488 })
2489 .await
2490 .unwrap();
2491
2492 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2493
2494 // Only start line
2495 let content = thread
2496 .update(cx, |thread, cx| {
2497 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2498 })
2499 .await
2500 .unwrap();
2501
2502 assert_eq!(content, "three\nfour\n");
2503
2504 // Only limit
2505 let content = thread
2506 .update(cx, |thread, cx| {
2507 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2508 })
2509 .await
2510 .unwrap();
2511
2512 assert_eq!(content, "one\ntwo\n");
2513
2514 // Range
2515 let content = thread
2516 .update(cx, |thread, cx| {
2517 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2518 })
2519 .await
2520 .unwrap();
2521
2522 assert_eq!(content, "two\nthree\n");
2523
2524 // Invalid
2525 let err = thread
2526 .update(cx, |thread, cx| {
2527 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2528 })
2529 .await
2530 .unwrap_err();
2531
2532 assert_eq!(
2533 err.to_string(),
2534 "Attempting to read beyond the end of the file, line 5:0"
2535 );
2536 }
2537
    /// A tool call marked `Canceled` by `cancel` can still move to `Completed`
    /// when the agent later reports completion for it.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The agent starts a single in-progress fetch tool call, then ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        // Canceling the turn marks the in-flight tool call as canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion update from the agent still applies to the canceled call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                        meta: None,
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2640
2641 #[gpui::test]
2642 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2643 init_test(cx);
2644 let fs = FakeFs::new(cx.background_executor.clone());
2645 fs.insert_tree(path!("/test"), json!({})).await;
2646 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2647
2648 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2649 move |_, thread, mut cx| {
2650 async move {
2651 thread
2652 .update(&mut cx, |thread, cx| {
2653 thread.handle_session_update(
2654 acp::SessionUpdate::ToolCall(acp::ToolCall {
2655 id: acp::ToolCallId("test".into()),
2656 title: "Label".into(),
2657 kind: acp::ToolKind::Edit,
2658 status: acp::ToolCallStatus::Completed,
2659 content: vec![acp::ToolCallContent::Diff {
2660 diff: acp::Diff {
2661 path: "/test/test.txt".into(),
2662 old_text: None,
2663 new_text: "foo".into(),
2664 meta: None,
2665 },
2666 }],
2667 locations: vec![],
2668 raw_input: None,
2669 raw_output: None,
2670 meta: None,
2671 }),
2672 cx,
2673 )
2674 })
2675 .unwrap()
2676 .unwrap();
2677 Ok(acp::PromptResponse {
2678 stop_reason: acp::StopReason::EndTurn,
2679 meta: None,
2680 })
2681 }
2682 .boxed_local()
2683 }
2684 }));
2685
2686 let thread = cx
2687 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2688 .await
2689 .unwrap();
2690
2691 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2692 .await
2693 .unwrap();
2694
2695 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2696 }
2697
    /// Exercises checkpoints attached to user messages.
    ///
    /// A prompt during which the agent changes files gets a checkpoint, rendered
    /// as `## User (checkpoint)`. A prompt without changes does not. Restoring the
    /// checkpoint of a later message truncates the thread back to before that
    /// message and removes the files created since.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // `.git` presumably makes `/test` a repository that checkpoints can be
        // taken of (assumption based on `GitStoreCheckpoint`).
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt makes the agent create a
        // new empty file `/test/file-N`, with N taken from `next_filename`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The agent replies with the prompt text upper-cased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First prompt: a file is created, so the message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second prompt: another file, another checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // `entries[2]` is the "ipsum" user message; restoring it removes that
        // exchange and everything after, plus `file-1` created during it.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2878
2879 #[gpui::test]
2880 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2881 use std::sync::atomic::AtomicUsize;
2882 init_test(cx);
2883
2884 let fs = FakeFs::new(cx.executor());
2885 let project = Project::test(fs, None, cx).await;
2886
2887 // Create a connection that simulates refusal after tool result
2888 let prompt_count = Arc::new(AtomicUsize::new(0));
2889 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2890 let prompt_count = prompt_count.clone();
2891 move |_request, thread, mut cx| {
2892 let count = prompt_count.fetch_add(1, SeqCst);
2893 async move {
2894 if count == 0 {
2895 // First prompt: Generate a tool call with result
2896 thread.update(&mut cx, |thread, cx| {
2897 thread
2898 .handle_session_update(
2899 acp::SessionUpdate::ToolCall(acp::ToolCall {
2900 id: acp::ToolCallId("tool1".into()),
2901 title: "Test Tool".into(),
2902 kind: acp::ToolKind::Fetch,
2903 status: acp::ToolCallStatus::Completed,
2904 content: vec![],
2905 locations: vec![],
2906 raw_input: Some(serde_json::json!({"query": "test"})),
2907 raw_output: Some(
2908 serde_json::json!({"result": "inappropriate content"}),
2909 ),
2910 meta: None,
2911 }),
2912 cx,
2913 )
2914 .unwrap();
2915 })?;
2916
2917 // Now return refusal because of the tool result
2918 Ok(acp::PromptResponse {
2919 stop_reason: acp::StopReason::Refusal,
2920 meta: None,
2921 })
2922 } else {
2923 Ok(acp::PromptResponse {
2924 stop_reason: acp::StopReason::EndTurn,
2925 meta: None,
2926 })
2927 }
2928 }
2929 .boxed_local()
2930 }
2931 }));
2932
2933 let thread = cx
2934 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2935 .await
2936 .unwrap();
2937
2938 // Track if we see a Refusal event
2939 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2940 let saw_refusal_event_captured = saw_refusal_event.clone();
2941 thread.update(cx, |_thread, cx| {
2942 cx.subscribe(
2943 &thread,
2944 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2945 if matches!(event, AcpThreadEvent::Refusal) {
2946 *saw_refusal_event_captured.lock().unwrap() = true;
2947 }
2948 },
2949 )
2950 .detach();
2951 });
2952
2953 // Send a user message - this will trigger tool call and then refusal
2954 let send_task = thread.update(cx, |thread, cx| {
2955 thread.send(
2956 vec![acp::ContentBlock::Text(acp::TextContent {
2957 text: "Hello".into(),
2958 annotations: None,
2959 meta: None,
2960 })],
2961 cx,
2962 )
2963 });
2964 cx.background_executor.spawn(send_task).detach();
2965 cx.run_until_parked();
2966
2967 // Verify that:
2968 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2969 // 2. The user message was NOT truncated
2970 assert!(
2971 *saw_refusal_event.lock().unwrap(),
2972 "Refusal event should be emitted for tool result refusals"
2973 );
2974
2975 thread.read_with(cx, |thread, _| {
2976 let entries = thread.entries();
2977 assert!(entries.len() >= 2, "Should have user message and tool call");
2978
2979 // Verify user message is still there
2980 assert!(
2981 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2982 "User message should not be truncated"
2983 );
2984
2985 // Verify tool call is there with result
2986 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2987 assert!(
2988 tool_call.raw_output.is_some(),
2989 "Tool call should have output"
2990 );
2991 } else {
2992 panic!("Expected tool call at index 1");
2993 }
2994 });
2995 }
2996
2997 #[gpui::test]
2998 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2999 init_test(cx);
3000
3001 let fs = FakeFs::new(cx.executor());
3002 let project = Project::test(fs, None, cx).await;
3003
3004 let refuse_next = Arc::new(AtomicBool::new(false));
3005 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3006 let refuse_next = refuse_next.clone();
3007 move |_request, _thread, _cx| {
3008 if refuse_next.load(SeqCst) {
3009 async move {
3010 Ok(acp::PromptResponse {
3011 stop_reason: acp::StopReason::Refusal,
3012 meta: None,
3013 })
3014 }
3015 .boxed_local()
3016 } else {
3017 async move {
3018 Ok(acp::PromptResponse {
3019 stop_reason: acp::StopReason::EndTurn,
3020 meta: None,
3021 })
3022 }
3023 .boxed_local()
3024 }
3025 }
3026 }));
3027
3028 let thread = cx
3029 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3030 .await
3031 .unwrap();
3032
3033 // Track if we see a Refusal event
3034 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3035 let saw_refusal_event_captured = saw_refusal_event.clone();
3036 thread.update(cx, |_thread, cx| {
3037 cx.subscribe(
3038 &thread,
3039 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3040 if matches!(event, AcpThreadEvent::Refusal) {
3041 *saw_refusal_event_captured.lock().unwrap() = true;
3042 }
3043 },
3044 )
3045 .detach();
3046 });
3047
3048 // Send a message that will be refused
3049 refuse_next.store(true, SeqCst);
3050 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3051 .await
3052 .unwrap();
3053
3054 // Verify that a Refusal event WAS emitted for user prompt refusal
3055 assert!(
3056 *saw_refusal_event.lock().unwrap(),
3057 "Refusal event should be emitted for user prompt refusals"
3058 );
3059
3060 // Verify the message was truncated (user prompt refusal)
3061 thread.read_with(cx, |thread, cx| {
3062 assert_eq!(thread.to_markdown(cx), "");
3063 });
3064 }
3065
3066 #[gpui::test]
3067 async fn test_refusal(cx: &mut TestAppContext) {
3068 init_test(cx);
3069 let fs = FakeFs::new(cx.background_executor.clone());
3070 fs.insert_tree(path!("/"), json!({})).await;
3071 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3072
3073 let refuse_next = Arc::new(AtomicBool::new(false));
3074 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3075 let refuse_next = refuse_next.clone();
3076 move |request, thread, mut cx| {
3077 let refuse_next = refuse_next.clone();
3078 async move {
3079 if refuse_next.load(SeqCst) {
3080 return Ok(acp::PromptResponse {
3081 stop_reason: acp::StopReason::Refusal,
3082 meta: None,
3083 });
3084 }
3085
3086 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3087 panic!("expected text content block");
3088 };
3089 thread.update(&mut cx, |thread, cx| {
3090 thread
3091 .handle_session_update(
3092 acp::SessionUpdate::AgentMessageChunk {
3093 content: content.text.to_uppercase().into(),
3094 },
3095 cx,
3096 )
3097 .unwrap();
3098 })?;
3099 Ok(acp::PromptResponse {
3100 stop_reason: acp::StopReason::EndTurn,
3101 meta: None,
3102 })
3103 }
3104 .boxed_local()
3105 }
3106 }));
3107 let thread = cx
3108 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3109 .await
3110 .unwrap();
3111
3112 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3113 .await
3114 .unwrap();
3115 thread.read_with(cx, |thread, cx| {
3116 assert_eq!(
3117 thread.to_markdown(cx),
3118 indoc! {"
3119 ## User
3120
3121 hello
3122
3123 ## Assistant
3124
3125 HELLO
3126
3127 "}
3128 );
3129 });
3130
3131 // Simulate refusing the second message. The message should be truncated
3132 // when a user prompt is refused.
3133 refuse_next.store(true, SeqCst);
3134 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3135 .await
3136 .unwrap();
3137 thread.read_with(cx, |thread, cx| {
3138 assert_eq!(
3139 thread.to_markdown(cx),
3140 indoc! {"
3141 ## User
3142
3143 hello
3144
3145 ## Assistant
3146
3147 HELLO
3148
3149 "}
3150 );
3151 });
3152 }
3153
3154 async fn run_until_first_tool_call(
3155 thread: &Entity<AcpThread>,
3156 cx: &mut TestAppContext,
3157 ) -> usize {
3158 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3159
3160 let subscription = cx.update(|cx| {
3161 cx.subscribe(thread, move |thread, _, cx| {
3162 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3163 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3164 return tx.try_send(ix).unwrap();
3165 }
3166 }
3167 })
3168 });
3169
3170 select! {
3171 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3172 panic!("Timeout waiting for tool call")
3173 }
3174 ix = rx.next().fuse() => {
3175 drop(subscription);
3176 ix.unwrap()
3177 }
3178 }
3179 }
3180
3181 #[derive(Clone, Default)]
3182 struct FakeAgentConnection {
3183 auth_methods: Vec<acp::AuthMethod>,
3184 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3185 on_user_message: Option<
3186 Rc<
3187 dyn Fn(
3188 acp::PromptRequest,
3189 WeakEntity<AcpThread>,
3190 AsyncApp,
3191 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3192 + 'static,
3193 >,
3194 >,
3195 }
3196
3197 impl FakeAgentConnection {
3198 fn new() -> Self {
3199 Self {
3200 auth_methods: Vec::new(),
3201 on_user_message: None,
3202 sessions: Arc::default(),
3203 }
3204 }
3205
3206 #[expect(unused)]
3207 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3208 self.auth_methods = auth_methods;
3209 self
3210 }
3211
3212 fn on_user_message(
3213 mut self,
3214 handler: impl Fn(
3215 acp::PromptRequest,
3216 WeakEntity<AcpThread>,
3217 AsyncApp,
3218 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3219 + 'static,
3220 ) -> Self {
3221 self.on_user_message.replace(Rc::new(handler));
3222 self
3223 }
3224 }
3225
3226 impl AgentConnection for FakeAgentConnection {
3227 fn auth_methods(&self) -> &[acp::AuthMethod] {
3228 &self.auth_methods
3229 }
3230
3231 fn new_thread(
3232 self: Rc<Self>,
3233 project: Entity<Project>,
3234 _cwd: &Path,
3235 cx: &mut App,
3236 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3237 let session_id = acp::SessionId(
3238 rand::rng()
3239 .sample_iter(&distr::Alphanumeric)
3240 .take(7)
3241 .map(char::from)
3242 .collect::<String>()
3243 .into(),
3244 );
3245 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3246 let thread = cx.new(|cx| {
3247 AcpThread::new(
3248 "Test",
3249 self.clone(),
3250 project,
3251 action_log,
3252 session_id.clone(),
3253 watch::Receiver::constant(acp::PromptCapabilities {
3254 image: true,
3255 audio: true,
3256 embedded_context: true,
3257 meta: None,
3258 }),
3259 cx,
3260 )
3261 });
3262 self.sessions.lock().insert(session_id, thread.downgrade());
3263 Task::ready(Ok(thread))
3264 }
3265
3266 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3267 if self.auth_methods().iter().any(|m| m.id == method) {
3268 Task::ready(Ok(()))
3269 } else {
3270 Task::ready(Err(anyhow!("Invalid Auth Method")))
3271 }
3272 }
3273
3274 fn prompt(
3275 &self,
3276 _id: Option<UserMessageId>,
3277 params: acp::PromptRequest,
3278 cx: &mut App,
3279 ) -> Task<gpui::Result<acp::PromptResponse>> {
3280 let sessions = self.sessions.lock();
3281 let thread = sessions.get(¶ms.session_id).unwrap();
3282 if let Some(handler) = &self.on_user_message {
3283 let handler = handler.clone();
3284 let thread = thread.clone();
3285 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3286 } else {
3287 Task::ready(Ok(acp::PromptResponse {
3288 stop_reason: acp::StopReason::EndTurn,
3289 meta: None,
3290 }))
3291 }
3292 }
3293
3294 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3295 let sessions = self.sessions.lock();
3296 let thread = sessions.get(session_id).unwrap().clone();
3297
3298 cx.spawn(async move |cx| {
3299 thread
3300 .update(cx, |thread, cx| thread.cancel(cx))
3301 .unwrap()
3302 .await
3303 })
3304 .detach();
3305 }
3306
3307 fn truncate(
3308 &self,
3309 session_id: &acp::SessionId,
3310 _cx: &App,
3311 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3312 Some(Rc::new(FakeAgentSessionEditor {
3313 _session_id: session_id.clone(),
3314 }))
3315 }
3316
3317 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3318 self
3319 }
3320 }
3321
    /// Truncation handle handed out by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        // Session the handle was created for; `run` does not read it.
        _session_id: acp::SessionId,
    }
3325
3326 impl AgentSessionTruncate for FakeAgentSessionEditor {
3327 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3328 Task::ready(Ok(()))
3329 }
3330 }
3331
3332 #[gpui::test]
3333 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3334 init_test(cx);
3335
3336 let fs = FakeFs::new(cx.executor());
3337 let project = Project::test(fs, [], cx).await;
3338 let connection = Rc::new(FakeAgentConnection::new());
3339 let thread = cx
3340 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3341 .await
3342 .unwrap();
3343
3344 // Try to update a tool call that doesn't exist
3345 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3346 thread.update(cx, |thread, cx| {
3347 let result = thread.handle_session_update(
3348 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3349 id: nonexistent_id.clone(),
3350 fields: acp::ToolCallUpdateFields {
3351 status: Some(acp::ToolCallStatus::Completed),
3352 ..Default::default()
3353 },
3354 meta: None,
3355 }),
3356 cx,
3357 );
3358
3359 // The update should succeed (not return an error)
3360 assert!(result.is_ok());
3361
3362 // There should now be exactly one entry in the thread
3363 assert_eq!(thread.entries.len(), 1);
3364
3365 // The entry should be a failed tool call
3366 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3367 assert_eq!(tool_call.id, nonexistent_id);
3368 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3369 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3370
3371 // Check that the content contains the error message
3372 assert_eq!(tool_call.content.len(), 1);
3373 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3374 match content_block {
3375 ContentBlock::Markdown { markdown } => {
3376 let markdown_text = markdown.read(cx).source();
3377 assert!(markdown_text.contains("Tool call not found"));
3378 }
3379 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3380 ContentBlock::ResourceLink { .. } => {
3381 panic!("Expected markdown content, got resource link")
3382 }
3383 }
3384 } else {
3385 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3386 }
3387 } else {
3388 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3389 }
3390 });
3391 }
3392}