1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15use task::{Shell, ShellBuilder};
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_default_system_shell};
38use uuid::Uuid;
39
/// A user turn in the thread transcript.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one was supplied when its chunks were pushed.
    pub id: Option<UserMessageId>,
    /// All chunks received so far, merged into a single renderable block.
    pub content: ContentBlock,
    /// The raw ACP content blocks, in the order they were received.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
47
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when set, the message's markdown export is headed
    /// "## User (checkpoint)".
    pub show: bool,
}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
/// An agent turn in the thread transcript, made of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in the order they were produced.
    pub chunks: Vec<AssistantMessageChunk>,
}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
/// One contiguous piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Response text shown to the user.
    Message { block: ContentBlock },
    /// Agent "thinking" output; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
/// A single entry in an [`AcpThread`] transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// Output produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation made by the agent.
    ToolCall(ToolCall),
}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
/// A tool call made by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id used to match later updates to this entry.
    pub id: acp::ToolCallId,
    /// Markdown label, built from the first line of the tool call's title.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content items (text blocks, diffs, terminals) attached to the call.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions resolved from `locations`, indexed the same way; `None` where a
    /// location could not be resolved. Starts empty and is filled in asynchronously.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input as reported by the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output as reported by the agent.
    pub raw_output: Option<serde_json::Value>,
}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
221 fn update_fields(
222 &mut self,
223 fields: acp::ToolCallUpdateFields,
224 language_registry: Arc<LanguageRegistry>,
225 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
226 cx: &mut App,
227 ) -> Result<()> {
228 let acp::ToolCallUpdateFields {
229 kind,
230 status,
231 title,
232 content,
233 locations,
234 raw_input,
235 raw_output,
236 } = fields;
237
238 if let Some(kind) = kind {
239 self.kind = kind;
240 }
241
242 if let Some(status) = status {
243 self.status = status.into();
244 }
245
246 if let Some(title) = title {
247 self.label.update(cx, |label, cx| {
248 if let Some((first_line, _)) = title.split_once("\n") {
249 label.replace(first_line.to_owned() + "…", cx)
250 } else {
251 label.replace(title, cx);
252 }
253 });
254 }
255
256 if let Some(content) = content {
257 let new_content_len = content.len();
258 let mut content = content.into_iter();
259
260 // Reuse existing content if we can
261 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
262 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
263 }
264 for new in content {
265 self.content.push(ToolCallContent::from_acp(
266 new,
267 language_registry.clone(),
268 terminals,
269 cx,
270 )?)
271 }
272 self.content.truncate(new_content_len);
273 }
274
275 if let Some(locations) = locations {
276 self.locations = locations;
277 }
278
279 if let Some(raw_input) = raw_input {
280 self.raw_input = Some(raw_input);
281 }
282
283 if let Some(raw_output) = raw_output {
284 if self.content.is_empty()
285 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
286 {
287 self.content
288 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
289 markdown,
290 }));
291 }
292 self.raw_output = Some(raw_output);
293 }
294 Ok(())
295 }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but is already shown to the user.
    Pending,
    /// The tool call is waiting for the user to pick one of the permission options.
    WaitingForConfirmation {
        /// Permission options the user can choose from.
        options: Vec<acp::PermissionOption>,
        /// One-shot channel on which the id of the chosen option is sent.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation, so the tool call was canceled.
    Canceled,
}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
/// Renderable content assembled from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A block consisting solely of a resource link; it is converted to `Markdown` when more
    /// content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
457 pub fn append(
458 &mut self,
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) {
463 if matches!(self, ContentBlock::Empty)
464 && let acp::ContentBlock::ResourceLink(resource_link) = block
465 {
466 *self = ContentBlock::ResourceLink { resource_link };
467 return;
468 }
469
470 let new_content = self.block_string_contents(block);
471
472 match self {
473 ContentBlock::Empty => {
474 *self = Self::create_markdown_block(new_content, language_registry, cx);
475 }
476 ContentBlock::Markdown { markdown } => {
477 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
478 }
479 ContentBlock::ResourceLink { resource_link } => {
480 let existing_content = Self::resource_link_md(&resource_link.uri);
481 let combined = format!("{}\n{}", existing_content, new_content);
482
483 *self = Self::create_markdown_block(combined, language_registry, cx);
484 }
485 }
486 }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
499 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
500 match block {
501 acp::ContentBlock::Text(text_content) => text_content.text,
502 acp::ContentBlock::ResourceLink(resource_link) => {
503 Self::resource_link_md(&resource_link.uri)
504 }
505 acp::ContentBlock::Resource(acp::EmbeddedResource {
506 resource:
507 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
508 uri,
509 ..
510 }),
511 ..
512 }) => Self::resource_link_md(&uri),
513 acp::ContentBlock::Image(image) => Self::image_md(&image),
514 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
515 }
516 }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
/// One content item attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, markdown or a resource link.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal owned by the thread, referenced by id in the protocol.
    Terminal(Entity<Terminal>),
}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path.to_string_lossy().into_owned(),
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
/// An update targeting an existing tool call, identified by its id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level field update.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
/// Sets a tool call's content to a single diff.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content item.
    pub diff: Entity<Diff>,
}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
/// Sets a tool call's content to a single terminal.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content item.
    pub terminal: Entity<Terminal>,
}
670
/// The agent's plan for the session, as a list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
675
/// Summary of a [`Plan`]'s progress, as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
/// A single step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Description of the step, rendered as markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
/// Token consumption of the thread relative to the model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// The model's token limit; 0 when the maximum is unknown (e.g. no model selected).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
/// How close token usage is to the limit, as computed by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
766
/// Details of a retry, reported through [`AcpThreadEvent::Retry`].
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the most recent error (presumably the one that triggered the retry — confirm).
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    /// NOTE(review): meaning not visible here; presumably the delay before the next attempt.
    pub duration: Duration,
}
775
/// A conversation with an agent over the Agent Client Protocol.
pub struct AcpThread {
    /// Display title; changed through `set_title`.
    title: SharedString,
    /// Transcript entries, oldest first.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): purpose not visible in this chunk; maps buffers to snapshots.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Present while a send is in flight; `status()` reports `Generating` while set.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Most recently reported token usage.
    token_usage: Option<TokenUsage>,
    /// Kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals by id; tool call content referring to a terminal id is looked up here.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
791
/// Events emitted by [`AcpThread`] to notify observers of changes.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The title changed.
    TitleUpdated,
    /// The token usage was updated.
    TokenUsageUpdated,
    /// The entry at this index was modified.
    EntryUpdated(usize),
    /// The entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    /// NOTE(review): emitted outside this chunk; presumably a tool call needs authorization.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    /// NOTE(review): emitted outside this chunk; presumably generation stopped.
    Stopped,
    /// NOTE(review): emitted outside this chunk; presumably an error occurred.
    Error,
    /// The agent failed to load.
    LoadError(LoadError),
    /// The session's prompt capabilities changed.
    PromptCapabilitiesUpdated,
    /// NOTE(review): emitted outside this chunk; presumably the agent refused the request.
    Refusal,
    /// The agent published its list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The session's current mode changed to this mode id.
    ModeUpdated(acp::SessionModeId),
}

// Lets `AcpThread` entities emit `AcpThreadEvent`s to subscribers.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
811
/// Whether a thread is currently processing a send.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is active.
    Idle,
    /// A send task is active.
    Generating,
}
817
/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent's version is below the minimum this client supports.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installing the agent failed; carries the failure message.
    FailedToInstall(SharedString),
    /// The agent server process exited with this status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure; carries the message shown to the user.
    Other(SharedString),
}
831
832impl Display for LoadError {
833 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
834 match self {
835 LoadError::Unsupported {
836 command: path,
837 current_version,
838 minimum_version,
839 } => {
840 write!(
841 f,
842 "version {current_version} from {path} is not supported (need at least {minimum_version})"
843 )
844 }
845 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
846 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
847 LoadError::Other(msg) => write!(f, "{msg}"),
848 }
849 }
850}
851
// Allows `LoadError` wherever a `std::error::Error` is expected; the message comes from `Display`.
impl Error for LoadError {}
853
854impl AcpThread {
855 pub fn new(
856 title: impl Into<SharedString>,
857 connection: Rc<dyn AgentConnection>,
858 project: Entity<Project>,
859 action_log: Entity<ActionLog>,
860 session_id: acp::SessionId,
861 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
862 cx: &mut Context<Self>,
863 ) -> Self {
864 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
865 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
866 loop {
867 let caps = prompt_capabilities_rx.recv().await?;
868 this.update(cx, |this, cx| {
869 this.prompt_capabilities = caps;
870 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
871 })?;
872 }
873 });
874
875 Self {
876 action_log,
877 shared_buffers: Default::default(),
878 entries: Default::default(),
879 plan: Default::default(),
880 title: title.into(),
881 project,
882 send_task: None,
883 connection,
884 session_id,
885 token_usage: None,
886 prompt_capabilities,
887 _observe_prompt_capabilities: task,
888 terminals: HashMap::default(),
889 }
890 }
891
892 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
893 self.prompt_capabilities.clone()
894 }
895
896 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
897 &self.connection
898 }
899
900 pub fn action_log(&self) -> &Entity<ActionLog> {
901 &self.action_log
902 }
903
904 pub fn project(&self) -> &Entity<Project> {
905 &self.project
906 }
907
908 pub fn title(&self) -> SharedString {
909 self.title.clone()
910 }
911
912 pub fn entries(&self) -> &[AgentThreadEntry] {
913 &self.entries
914 }
915
916 pub fn session_id(&self) -> &acp::SessionId {
917 &self.session_id
918 }
919
920 pub fn status(&self) -> ThreadStatus {
921 if self.send_task.is_some() {
922 ThreadStatus::Generating
923 } else {
924 ThreadStatus::Idle
925 }
926 }
927
928 pub fn token_usage(&self) -> Option<&TokenUsage> {
929 self.token_usage.as_ref()
930 }
931
932 pub fn has_pending_edit_tool_calls(&self) -> bool {
933 for entry in self.entries.iter().rev() {
934 match entry {
935 AgentThreadEntry::UserMessage(_) => return false,
936 AgentThreadEntry::ToolCall(
937 call @ ToolCall {
938 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
939 ..
940 },
941 ) if call.diffs().next().is_some() => {
942 return true;
943 }
944 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
945 }
946 }
947
948 false
949 }
950
951 pub fn used_tools_since_last_user_message(&self) -> bool {
952 for entry in self.entries.iter().rev() {
953 match entry {
954 AgentThreadEntry::UserMessage(..) => return false,
955 AgentThreadEntry::AssistantMessage(..) => continue,
956 AgentThreadEntry::ToolCall(..) => return true,
957 }
958 }
959
960 false
961 }
962
963 pub fn handle_session_update(
964 &mut self,
965 update: acp::SessionUpdate,
966 cx: &mut Context<Self>,
967 ) -> Result<(), acp::Error> {
968 match update {
969 acp::SessionUpdate::UserMessageChunk { content } => {
970 self.push_user_content_block(None, content, cx);
971 }
972 acp::SessionUpdate::AgentMessageChunk { content } => {
973 self.push_assistant_content_block(content, false, cx);
974 }
975 acp::SessionUpdate::AgentThoughtChunk { content } => {
976 self.push_assistant_content_block(content, true, cx);
977 }
978 acp::SessionUpdate::ToolCall(tool_call) => {
979 self.upsert_tool_call(tool_call, cx)?;
980 }
981 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
982 self.update_tool_call(tool_call_update, cx)?;
983 }
984 acp::SessionUpdate::Plan(plan) => {
985 self.update_plan(plan, cx);
986 }
987 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
988 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
989 }
990 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
991 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
992 }
993 }
994 Ok(())
995 }
996
997 pub fn push_user_content_block(
998 &mut self,
999 message_id: Option<UserMessageId>,
1000 chunk: acp::ContentBlock,
1001 cx: &mut Context<Self>,
1002 ) {
1003 let language_registry = self.project.read(cx).languages().clone();
1004 let entries_len = self.entries.len();
1005
1006 if let Some(last_entry) = self.entries.last_mut()
1007 && let AgentThreadEntry::UserMessage(UserMessage {
1008 id,
1009 content,
1010 chunks,
1011 ..
1012 }) = last_entry
1013 {
1014 *id = message_id.or(id.take());
1015 content.append(chunk.clone(), &language_registry, cx);
1016 chunks.push(chunk);
1017 let idx = entries_len - 1;
1018 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1019 } else {
1020 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1021 self.push_entry(
1022 AgentThreadEntry::UserMessage(UserMessage {
1023 id: message_id,
1024 content,
1025 chunks: vec![chunk],
1026 checkpoint: None,
1027 }),
1028 cx,
1029 );
1030 }
1031 }
1032
1033 pub fn push_assistant_content_block(
1034 &mut self,
1035 chunk: acp::ContentBlock,
1036 is_thought: bool,
1037 cx: &mut Context<Self>,
1038 ) {
1039 let language_registry = self.project.read(cx).languages().clone();
1040 let entries_len = self.entries.len();
1041 if let Some(last_entry) = self.entries.last_mut()
1042 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1043 {
1044 let idx = entries_len - 1;
1045 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1046 match (chunks.last_mut(), is_thought) {
1047 (Some(AssistantMessageChunk::Message { block }), false)
1048 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1049 block.append(chunk, &language_registry, cx)
1050 }
1051 _ => {
1052 let block = ContentBlock::new(chunk, &language_registry, cx);
1053 if is_thought {
1054 chunks.push(AssistantMessageChunk::Thought { block })
1055 } else {
1056 chunks.push(AssistantMessageChunk::Message { block })
1057 }
1058 }
1059 }
1060 } else {
1061 let block = ContentBlock::new(chunk, &language_registry, cx);
1062 let chunk = if is_thought {
1063 AssistantMessageChunk::Thought { block }
1064 } else {
1065 AssistantMessageChunk::Message { block }
1066 };
1067
1068 self.push_entry(
1069 AgentThreadEntry::AssistantMessage(AssistantMessage {
1070 chunks: vec![chunk],
1071 }),
1072 cx,
1073 );
1074 }
1075 }
1076
1077 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1078 self.entries.push(entry);
1079 cx.emit(AcpThreadEvent::NewEntry);
1080 }
1081
1082 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1083 self.connection.set_title(&self.session_id, cx).is_some()
1084 }
1085
1086 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1087 if title != self.title {
1088 self.title = title.clone();
1089 cx.emit(AcpThreadEvent::TitleUpdated);
1090 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1091 return set_title.run(title, cx);
1092 }
1093 }
1094 Task::ready(Ok(()))
1095 }
1096
1097 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1098 self.token_usage = usage;
1099 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1100 }
1101
1102 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1103 cx.emit(AcpThreadEvent::Retry(status));
1104 }
1105
1106 pub fn update_tool_call(
1107 &mut self,
1108 update: impl Into<ToolCallUpdate>,
1109 cx: &mut Context<Self>,
1110 ) -> Result<()> {
1111 let update = update.into();
1112 let languages = self.project.read(cx).languages().clone();
1113
1114 let ix = match self.index_for_tool_call(update.id()) {
1115 Some(ix) => ix,
1116 None => {
1117 // Tool call not found - create a failed tool call entry
1118 let failed_tool_call = ToolCall {
1119 id: update.id().clone(),
1120 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1121 kind: acp::ToolKind::Fetch,
1122 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1123 acp::ContentBlock::Text(acp::TextContent {
1124 text: "Tool call not found".to_string(),
1125 annotations: None,
1126 meta: None,
1127 }),
1128 &languages,
1129 cx,
1130 ))],
1131 status: ToolCallStatus::Failed,
1132 locations: Vec::new(),
1133 resolved_locations: Vec::new(),
1134 raw_input: None,
1135 raw_output: None,
1136 };
1137 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1138 return Ok(());
1139 }
1140 };
1141 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1142 unreachable!()
1143 };
1144
1145 match update {
1146 ToolCallUpdate::UpdateFields(update) => {
1147 let location_updated = update.fields.locations.is_some();
1148 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1149 if location_updated {
1150 self.resolve_locations(update.id, cx);
1151 }
1152 }
1153 ToolCallUpdate::UpdateDiff(update) => {
1154 call.content.clear();
1155 call.content.push(ToolCallContent::Diff(update.diff));
1156 }
1157 ToolCallUpdate::UpdateTerminal(update) => {
1158 call.content.clear();
1159 call.content
1160 .push(ToolCallContent::Terminal(update.terminal));
1161 }
1162 }
1163
1164 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1165
1166 Ok(())
1167 }
1168
1169 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1170 pub fn upsert_tool_call(
1171 &mut self,
1172 tool_call: acp::ToolCall,
1173 cx: &mut Context<Self>,
1174 ) -> Result<(), acp::Error> {
1175 let status = tool_call.status.into();
1176 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1177 }
1178
1179 /// Fails if id does not match an existing entry.
1180 pub fn upsert_tool_call_inner(
1181 &mut self,
1182 update: acp::ToolCallUpdate,
1183 status: ToolCallStatus,
1184 cx: &mut Context<Self>,
1185 ) -> Result<(), acp::Error> {
1186 let language_registry = self.project.read(cx).languages().clone();
1187 let id = update.id.clone();
1188
1189 if let Some(ix) = self.index_for_tool_call(&id) {
1190 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1191 unreachable!()
1192 };
1193
1194 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1195 call.status = status;
1196
1197 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1198 } else {
1199 let call = ToolCall::from_acp(
1200 update.try_into()?,
1201 status,
1202 language_registry,
1203 &self.terminals,
1204 cx,
1205 )?;
1206 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1207 };
1208
1209 self.resolve_locations(id, cx);
1210 Ok(())
1211 }
1212
1213 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1214 self.entries
1215 .iter()
1216 .enumerate()
1217 .rev()
1218 .find_map(|(index, entry)| {
1219 if let AgentThreadEntry::ToolCall(tool_call) = entry
1220 && &tool_call.id == id
1221 {
1222 Some(index)
1223 } else {
1224 None
1225 }
1226 })
1227 }
1228
1229 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1230 // The tool call we are looking for is typically the last one, or very close to the end.
1231 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1232 self.entries
1233 .iter_mut()
1234 .enumerate()
1235 .rev()
1236 .find_map(|(index, tool_call)| {
1237 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1238 && &tool_call.id == id
1239 {
1240 Some((index, tool_call))
1241 } else {
1242 None
1243 }
1244 })
1245 }
1246
1247 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1248 self.entries
1249 .iter()
1250 .enumerate()
1251 .rev()
1252 .find_map(|(index, tool_call)| {
1253 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1254 && &tool_call.id == id
1255 {
1256 Some((index, tool_call))
1257 } else {
1258 None
1259 }
1260 })
1261 }
1262
    /// Re-resolves the locations of the tool call `id` in the background.
    ///
    /// Does nothing if no tool call with `id` exists. When resolution finishes:
    /// - if the last resolved location is `Some` and the project already has an
    ///   agent location, the agent location moves to it, unless both are in the
    ///   same buffer on the same row and the move would go to an earlier column;
    /// - the tool call's `resolved_locations` are replaced, emitting
    ///   [`AcpThreadEvent::EntryUpdated`] only if they changed.
    ///
    /// NOTE(review): when the project has no agent location yet, this never sets
    /// one. Confirm that is intended and not meant to be the non-ignored default.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again: it may have been removed while the
                // locations were being resolved.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of line
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    // A released buffer can't be compared; don't ignore then.
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1308
1309 pub fn request_tool_call_authorization(
1310 &mut self,
1311 tool_call: acp::ToolCallUpdate,
1312 options: Vec<acp::PermissionOption>,
1313 respect_always_allow_setting: bool,
1314 cx: &mut Context<Self>,
1315 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1316 let (tx, rx) = oneshot::channel();
1317
1318 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1319 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1320 // some tools would (incorrectly) continue to auto-accept.
1321 if let Some(allow_once_option) = options.iter().find_map(|option| {
1322 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1323 Some(option.id.clone())
1324 } else {
1325 None
1326 }
1327 }) {
1328 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1329 return Ok(async {
1330 acp::RequestPermissionOutcome::Selected {
1331 option_id: allow_once_option,
1332 }
1333 }
1334 .boxed());
1335 }
1336 }
1337
1338 let status = ToolCallStatus::WaitingForConfirmation {
1339 options,
1340 respond_tx: tx,
1341 };
1342
1343 self.upsert_tool_call_inner(tool_call, status, cx)?;
1344 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1345
1346 let fut = async {
1347 match rx.await {
1348 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1349 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1350 }
1351 }
1352 .boxed();
1353
1354 Ok(fut)
1355 }
1356
1357 pub fn authorize_tool_call(
1358 &mut self,
1359 id: acp::ToolCallId,
1360 option_id: acp::PermissionOptionId,
1361 option_kind: acp::PermissionOptionKind,
1362 cx: &mut Context<Self>,
1363 ) {
1364 let Some((ix, call)) = self.tool_call_mut(&id) else {
1365 return;
1366 };
1367
1368 let new_status = match option_kind {
1369 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1370 ToolCallStatus::Rejected
1371 }
1372 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1373 ToolCallStatus::InProgress
1374 }
1375 };
1376
1377 let curr_status = mem::replace(&mut call.status, new_status);
1378
1379 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1380 respond_tx.send(option_id).log_err();
1381 } else if cfg!(debug_assertions) {
1382 panic!("tried to authorize an already authorized tool call");
1383 }
1384
1385 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1386 }
1387
1388 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1389 let mut first_tool_call = None;
1390
1391 for entry in self.entries.iter().rev() {
1392 match &entry {
1393 AgentThreadEntry::ToolCall(call) => {
1394 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1395 first_tool_call = Some(call);
1396 } else {
1397 continue;
1398 }
1399 }
1400 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1401 // Reached the beginning of the turn.
1402 // If we had pending permission requests in the previous turn, they have been cancelled.
1403 break;
1404 }
1405 }
1406 }
1407
1408 first_tool_call
1409 }
1410
1411 pub fn plan(&self) -> &Plan {
1412 &self.plan
1413 }
1414
1415 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1416 let new_entries_len = request.entries.len();
1417 let mut new_entries = request.entries.into_iter();
1418
1419 // Reuse existing markdown to prevent flickering
1420 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1421 let PlanEntry {
1422 content,
1423 priority,
1424 status,
1425 } = old;
1426 content.update(cx, |old, cx| {
1427 old.replace(new.content, cx);
1428 });
1429 *priority = new.priority;
1430 *status = new.status;
1431 }
1432 for new in new_entries {
1433 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1434 }
1435 self.plan.entries.truncate(new_entries_len);
1436
1437 cx.notify();
1438 }
1439
1440 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1441 self.plan
1442 .entries
1443 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1444 cx.notify();
1445 }
1446
1447 #[cfg(any(test, feature = "test-support"))]
1448 pub fn send_raw(
1449 &mut self,
1450 message: &str,
1451 cx: &mut Context<Self>,
1452 ) -> BoxFuture<'static, Result<()>> {
1453 self.send(
1454 vec![acp::ContentBlock::Text(acp::TextContent {
1455 text: message.to_string(),
1456 annotations: None,
1457 meta: None,
1458 })],
1459 cx,
1460 )
1461 }
1462
1463 pub fn send(
1464 &mut self,
1465 message: Vec<acp::ContentBlock>,
1466 cx: &mut Context<Self>,
1467 ) -> BoxFuture<'static, Result<()>> {
1468 let block = ContentBlock::new_combined(
1469 message.clone(),
1470 self.project.read(cx).languages().clone(),
1471 cx,
1472 );
1473 let request = acp::PromptRequest {
1474 prompt: message.clone(),
1475 session_id: self.session_id.clone(),
1476 meta: None,
1477 };
1478 let git_store = self.project.read(cx).git_store().clone();
1479
1480 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1481 Some(UserMessageId::new())
1482 } else {
1483 None
1484 };
1485
1486 self.run_turn(cx, async move |this, cx| {
1487 this.update(cx, |this, cx| {
1488 this.push_entry(
1489 AgentThreadEntry::UserMessage(UserMessage {
1490 id: message_id.clone(),
1491 content: block,
1492 chunks: message,
1493 checkpoint: None,
1494 }),
1495 cx,
1496 );
1497 })
1498 .ok();
1499
1500 let old_checkpoint = git_store
1501 .update(cx, |git, cx| git.checkpoint(cx))?
1502 .await
1503 .context("failed to get old checkpoint")
1504 .log_err();
1505 this.update(cx, |this, cx| {
1506 if let Some((_ix, message)) = this.last_user_message() {
1507 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1508 git_checkpoint,
1509 show: false,
1510 });
1511 }
1512 this.connection.prompt(message_id, request, cx)
1513 })?
1514 .await
1515 })
1516 }
1517
1518 pub fn can_resume(&self, cx: &App) -> bool {
1519 self.connection.resume(&self.session_id, cx).is_some()
1520 }
1521
1522 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1523 self.run_turn(cx, async move |this, cx| {
1524 this.update(cx, |this, cx| {
1525 this.connection
1526 .resume(&this.session_id, cx)
1527 .map(|resume| resume.run(cx))
1528 })?
1529 .context("resuming a session is not supported")?
1530 .await
1531 })
1532 }
1533
1534 fn run_turn(
1535 &mut self,
1536 cx: &mut Context<Self>,
1537 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1538 ) -> BoxFuture<'static, Result<()>> {
1539 self.clear_completed_plan_entries(cx);
1540
1541 let (tx, rx) = oneshot::channel();
1542 let cancel_task = self.cancel(cx);
1543
1544 self.send_task = Some(cx.spawn(async move |this, cx| {
1545 cancel_task.await;
1546 tx.send(f(this, cx).await).ok();
1547 }));
1548
1549 cx.spawn(async move |this, cx| {
1550 let response = rx.await;
1551
1552 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1553 .await?;
1554
1555 this.update(cx, |this, cx| {
1556 this.project
1557 .update(cx, |project, cx| project.set_agent_location(None, cx));
1558 match response {
1559 Ok(Err(e)) => {
1560 this.send_task.take();
1561 cx.emit(AcpThreadEvent::Error);
1562 Err(e)
1563 }
1564 result => {
1565 let canceled = matches!(
1566 result,
1567 Ok(Ok(acp::PromptResponse {
1568 stop_reason: acp::StopReason::Cancelled,
1569 meta: None,
1570 }))
1571 );
1572
1573 // We only take the task if the current prompt wasn't canceled.
1574 //
1575 // This prompt may have been canceled because another one was sent
1576 // while it was still generating. In these cases, dropping `send_task`
1577 // would cause the next generation to be canceled.
1578 if !canceled {
1579 this.send_task.take();
1580 }
1581
1582 // Handle refusal - distinguish between user prompt and tool call refusals
1583 if let Ok(Ok(acp::PromptResponse {
1584 stop_reason: acp::StopReason::Refusal,
1585 meta: _,
1586 })) = result
1587 {
1588 if let Some((user_msg_ix, _)) = this.last_user_message() {
1589 // Check if there's a completed tool call with results after the last user message
1590 // This indicates the refusal is in response to tool output, not the user's prompt
1591 let has_completed_tool_call_after_user_msg =
1592 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1593 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1594 // Check if the tool call has completed and has output
1595 matches!(tool_call.status, ToolCallStatus::Completed)
1596 && tool_call.raw_output.is_some()
1597 } else {
1598 false
1599 }
1600 });
1601
1602 if has_completed_tool_call_after_user_msg {
1603 // Refusal is due to tool output - don't truncate, just notify
1604 // The model refused based on what the tool returned
1605 cx.emit(AcpThreadEvent::Refusal);
1606 } else {
1607 // User prompt was refused - truncate back to before the user message
1608 let range = user_msg_ix..this.entries.len();
1609 if range.start < range.end {
1610 this.entries.truncate(user_msg_ix);
1611 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1612 }
1613 cx.emit(AcpThreadEvent::Refusal);
1614 }
1615 } else {
1616 // No user message found, treat as general refusal
1617 cx.emit(AcpThreadEvent::Refusal);
1618 }
1619 }
1620
1621 cx.emit(AcpThreadEvent::Stopped);
1622 Ok(())
1623 }
1624 }
1625 })?
1626 })
1627 .boxed()
1628 }
1629
1630 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1631 let Some(send_task) = self.send_task.take() else {
1632 return Task::ready(());
1633 };
1634
1635 for entry in self.entries.iter_mut() {
1636 if let AgentThreadEntry::ToolCall(call) = entry {
1637 let cancel = matches!(
1638 call.status,
1639 ToolCallStatus::Pending
1640 | ToolCallStatus::WaitingForConfirmation { .. }
1641 | ToolCallStatus::InProgress
1642 );
1643
1644 if cancel {
1645 call.status = ToolCallStatus::Canceled;
1646 }
1647 }
1648 }
1649
1650 self.connection.cancel(&self.session_id, cx);
1651
1652 // Wait for the send task to complete
1653 cx.foreground_executor().spawn(send_task)
1654 }
1655
1656 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1657 pub fn restore_checkpoint(
1658 &mut self,
1659 id: UserMessageId,
1660 cx: &mut Context<Self>,
1661 ) -> Task<Result<()>> {
1662 let Some((_, message)) = self.user_message_mut(&id) else {
1663 return Task::ready(Err(anyhow!("message not found")));
1664 };
1665
1666 let checkpoint = message
1667 .checkpoint
1668 .as_ref()
1669 .map(|c| c.git_checkpoint.clone());
1670 let rewind = self.rewind(id.clone(), cx);
1671 let git_store = self.project.read(cx).git_store().clone();
1672
1673 cx.spawn(async move |_, cx| {
1674 rewind.await?;
1675 if let Some(checkpoint) = checkpoint {
1676 git_store
1677 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1678 .await?;
1679 }
1680
1681 Ok(())
1682 })
1683 }
1684
1685 /// Rewinds this thread to before the entry at `index`, removing it and all
1686 /// subsequent entries while rejecting any action_log changes made from that point.
1687 /// Unlike `restore_checkpoint`, this method does not restore from git.
1688 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1689 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1690 return Task::ready(Err(anyhow!("not supported")));
1691 };
1692
1693 cx.spawn(async move |this, cx| {
1694 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1695 this.update(cx, |this, cx| {
1696 if let Some((ix, _)) = this.user_message_mut(&id) {
1697 let range = ix..this.entries.len();
1698 this.entries.truncate(ix);
1699 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1700 }
1701 this.action_log()
1702 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1703 })?
1704 .await;
1705 Ok(())
1706 })
1707 }
1708
1709 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1710 let git_store = self.project.read(cx).git_store().clone();
1711
1712 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1713 if let Some(checkpoint) = message.checkpoint.as_ref() {
1714 checkpoint.git_checkpoint.clone()
1715 } else {
1716 return Task::ready(Ok(()));
1717 }
1718 } else {
1719 return Task::ready(Ok(()));
1720 };
1721
1722 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1723 cx.spawn(async move |this, cx| {
1724 let new_checkpoint = new_checkpoint
1725 .await
1726 .context("failed to get new checkpoint")
1727 .log_err();
1728 if let Some(new_checkpoint) = new_checkpoint {
1729 let equal = git_store
1730 .update(cx, |git, cx| {
1731 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1732 })?
1733 .await
1734 .unwrap_or(true);
1735 this.update(cx, |this, cx| {
1736 let (ix, message) = this.last_user_message().context("no user message")?;
1737 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1738 checkpoint.show = !equal;
1739 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1740 anyhow::Ok(())
1741 })??;
1742 }
1743
1744 Ok(())
1745 })
1746 }
1747
1748 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1749 self.entries
1750 .iter_mut()
1751 .enumerate()
1752 .rev()
1753 .find_map(|(ix, entry)| {
1754 if let AgentThreadEntry::UserMessage(message) = entry {
1755 Some((ix, message))
1756 } else {
1757 None
1758 }
1759 })
1760 }
1761
1762 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1763 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1764 if let AgentThreadEntry::UserMessage(message) = entry {
1765 if message.id.as_ref() == Some(id) {
1766 Some((ix, message))
1767 } else {
1768 None
1769 }
1770 } else {
1771 None
1772 }
1773 })
1774 }
1775
    /// Reads text from the project file at absolute `path`.
    ///
    /// Reading starts at 1-based `line` (`None` and `Some(0)` both mean the first
    /// line) and covers at most `limit` lines (`None` means `u32::MAX` lines).
    ///
    /// When `reuse_shared_snapshot` is set and this thread already holds a
    /// snapshot of the buffer in `shared_buffers`, that snapshot is read. Otherwise:
    /// - the read is reported to the action log;
    /// - a fresh snapshot is taken and stored in `shared_buffers`, where
    ///   `write_text_file` later picks it up as its diff base.
    ///
    /// The agent location is moved to the start of the returned range.
    ///
    /// # Errors
    ///
    /// - `resource_not_found` if `path` doesn't map to a project path;
    /// - `invalid_params` if `line` is past the end of the file;
    /// - other failures (opening the buffer, a dropped thread or project) are
    ///   propagated.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            // Outer error: the project entity is gone. Inner error: the path is
            // not part of the project. `flatten` merges both into one `Result`.
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                // Remember exactly what the agent saw so a later write can diff
                // against it.
                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().with_data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            let start = snapshot.anchor_before(start_position);
            // `saturating_add` keeps the default `u32::MAX` limit from overflowing.
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
1855
    /// Replaces the contents of the project file at absolute `path` with `content`
    /// and saves it.
    ///
    /// The new text is applied as a diff against a base snapshot: the one this
    /// thread last stored in `shared_buffers` for the buffer, or the buffer's
    /// current snapshot if there is none. Because the edits are anchored in that
    /// snapshot, changes made to the buffer since it was taken are kept (see
    /// `test_edits_concurrently_to_user`).
    ///
    /// Steps, in order:
    /// 1. Move the agent location to the end of the last edit.
    /// 2. Apply the edit and record it in the action log.
    /// 3. If the buffer's language settings enable format-on-save, format the
    ///    buffer. A formatting failure is logged, not returned.
    /// 4. Save the buffer.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" if `path` doesn't map to a project path. Also
    /// fails if the buffer can't be opened or saved, or if an entity has been
    /// dropped.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            // First `?`: project dropped. Second `?`: invalid path. Then await the open.
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diffing is done off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                // NOTE(review): presumably biased inward so text
                                // inserted at the range boundaries since the
                                // snapshot stays outside the replaced range.
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Reported as a read before editing. NOTE(review): presumably this
                // gives the action log a baseline for the edit below; confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                // NOTE(review): meaning of the `false` argument to `project.format`
                // is not visible here; confirm against its signature.
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Report the formatter's changes to the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1952
1953 pub fn create_terminal(
1954 &self,
1955 command: String,
1956 args: Vec<String>,
1957 extra_env: Vec<acp::EnvVariable>,
1958 cwd: Option<PathBuf>,
1959 output_byte_limit: Option<u64>,
1960 cx: &mut Context<Self>,
1961 ) -> Task<Result<Entity<Terminal>>> {
1962 let env = match &cwd {
1963 Some(dir) => self.project.update(cx, |project, cx| {
1964 project.directory_environment(dir.as_path().into(), cx)
1965 }),
1966 None => Task::ready(None).shared(),
1967 };
1968
1969 let env = cx.spawn(async move |_, _| {
1970 let mut env = env.await.unwrap_or_default();
1971 // Disables paging for `git` and hopefully other commands
1972 env.insert("PAGER".into(), "".into());
1973 for var in extra_env {
1974 env.insert(var.name, var.value);
1975 }
1976 env
1977 });
1978
1979 let project = self.project.clone();
1980 let language_registry = project.read(cx).languages().clone();
1981
1982 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1983 let terminal_task = cx.spawn({
1984 let terminal_id = terminal_id.clone();
1985 async move |_this, cx| {
1986 let env = env.await;
1987 let (task_command, task_args) = ShellBuilder::new(
1988 project
1989 .update(cx, |project, cx| {
1990 project
1991 .remote_client()
1992 .and_then(|r| r.read(cx).default_system_shell())
1993 })?
1994 .as_deref(),
1995 &Shell::Program(get_default_system_shell()),
1996 )
1997 .redirect_stdin_to_dev_null()
1998 .build(Some(command.clone()), &args);
1999 let terminal = project
2000 .update(cx, |project, cx| {
2001 project.create_terminal_task(
2002 task::SpawnInTerminal {
2003 command: Some(task_command),
2004 args: task_args,
2005 cwd: cwd.clone(),
2006 env,
2007 ..Default::default()
2008 },
2009 cx,
2010 )
2011 })?
2012 .await?;
2013
2014 cx.new(|cx| {
2015 Terminal::new(
2016 terminal_id,
2017 &format!("{} {}", command, args.join(" ")),
2018 cwd,
2019 output_byte_limit.map(|l| l as usize),
2020 terminal,
2021 language_registry,
2022 cx,
2023 )
2024 })
2025 }
2026 });
2027
2028 cx.spawn(async move |this, cx| {
2029 let terminal = terminal_task.await?;
2030 this.update(cx, |this, _cx| {
2031 this.terminals.insert(terminal_id, terminal.clone());
2032 terminal
2033 })
2034 })
2035 }
2036
2037 pub fn kill_terminal(
2038 &mut self,
2039 terminal_id: acp::TerminalId,
2040 cx: &mut Context<Self>,
2041 ) -> Result<()> {
2042 self.terminals
2043 .get(&terminal_id)
2044 .context("Terminal not found")?
2045 .update(cx, |terminal, cx| {
2046 terminal.kill(cx);
2047 });
2048
2049 Ok(())
2050 }
2051
2052 pub fn release_terminal(
2053 &mut self,
2054 terminal_id: acp::TerminalId,
2055 cx: &mut Context<Self>,
2056 ) -> Result<()> {
2057 self.terminals
2058 .remove(&terminal_id)
2059 .context("Terminal not found")?
2060 .update(cx, |terminal, cx| {
2061 terminal.kill(cx);
2062 });
2063
2064 Ok(())
2065 }
2066
2067 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2068 self.terminals
2069 .get(&terminal_id)
2070 .context("Terminal not found")
2071 .cloned()
2072 }
2073
2074 pub fn to_markdown(&self, cx: &App) -> String {
2075 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2076 }
2077
2078 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2079 cx.emit(AcpThreadEvent::LoadError(error));
2080 }
2081}
2082
2083fn markdown_for_raw_output(
2084 raw_output: &serde_json::Value,
2085 language_registry: &Arc<LanguageRegistry>,
2086 cx: &mut App,
2087) -> Option<Entity<Markdown>> {
2088 match raw_output {
2089 serde_json::Value::Null => None,
2090 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2091 Markdown::new(
2092 value.to_string().into(),
2093 Some(language_registry.clone()),
2094 None,
2095 cx,
2096 )
2097 })),
2098 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2099 Markdown::new(
2100 value.to_string().into(),
2101 Some(language_registry.clone()),
2102 None,
2103 cx,
2104 )
2105 })),
2106 serde_json::Value::String(value) => Some(cx.new(|cx| {
2107 Markdown::new(
2108 value.clone().into(),
2109 Some(language_registry.clone()),
2110 None,
2111 cx,
2112 )
2113 })),
2114 value => Some(cx.new(|cx| {
2115 Markdown::new(
2116 format!("```json\n{}\n```", value).into(),
2117 Some(language_registry.clone()),
2118 None,
2119 cx,
2120 )
2121 })),
2122 }
2123}
2124
2125#[cfg(test)]
2126mod tests {
2127 use super::*;
2128 use anyhow::anyhow;
2129 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2130 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2131 use indoc::indoc;
2132 use project::{FakeFs, Fs};
2133 use rand::{distr, prelude::*};
2134 use serde_json::json;
2135 use settings::SettingsStore;
2136 use smol::stream::StreamExt as _;
2137 use std::{
2138 any::Any,
2139 cell::RefCell,
2140 path::Path,
2141 rc::Rc,
2142 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2143 time::Duration,
2144 };
2145 use util::path;
2146
2147 fn init_test(cx: &mut TestAppContext) {
2148 env_logger::try_init().ok();
2149 cx.update(|cx| {
2150 let settings_store = SettingsStore::test(cx);
2151 cx.set_global(settings_store);
2152 Project::init_settings(cx);
2153 language::init(cx);
2154 });
2155 }
2156
    /// Checks how `push_user_content_block` groups chunks into user messages.
    ///
    /// - A chunk pushed right after another user chunk is appended to the same
    ///   message, which then carries the id supplied with the later chunk.
    /// - A chunk pushed after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Still one entry: the chunk was merged, and the message took the new id.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                    meta: None,
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Entries: [user, assistant, user].
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2252
    /// Checks that consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` block in the thread's markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Fake agent that answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks should appear inside one <thinking> block.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2316
    /// Checks that `write_text_file` preserves an edit the user made to the buffer
    /// after the agent read the file.
    ///
    /// The agent reads `/tmp/foo`, the test then inserts a first line into the open
    /// buffer, and the agent writes the file with two extra lines appended. The
    /// user's line must survive, both in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test that the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Let the test edit the buffer before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file up front so the test can edit the same buffer the agent
        // reads and writes.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit: prepend a line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2396
2397 #[gpui::test]
2398 async fn test_reading_from_line(cx: &mut TestAppContext) {
2399 init_test(cx);
2400
2401 let fs = FakeFs::new(cx.executor());
2402 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2403 .await;
2404 let project = Project::test(fs.clone(), [], cx).await;
2405 project
2406 .update(cx, |project, cx| {
2407 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2408 })
2409 .await
2410 .unwrap();
2411
2412 let connection = Rc::new(FakeAgentConnection::new());
2413
2414 let thread = cx
2415 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2416 .await
2417 .unwrap();
2418
2419 // Whole file
2420 let content = thread
2421 .update(cx, |thread, cx| {
2422 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2423 })
2424 .await
2425 .unwrap();
2426
2427 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2428
2429 // Only start line
2430 let content = thread
2431 .update(cx, |thread, cx| {
2432 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2433 })
2434 .await
2435 .unwrap();
2436
2437 assert_eq!(content, "three\nfour\n");
2438
2439 // Only limit
2440 let content = thread
2441 .update(cx, |thread, cx| {
2442 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2443 })
2444 .await
2445 .unwrap();
2446
2447 assert_eq!(content, "one\ntwo\n");
2448
2449 // Range
2450 let content = thread
2451 .update(cx, |thread, cx| {
2452 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2453 })
2454 .await
2455 .unwrap();
2456
2457 assert_eq!(content, "two\nthree\n");
2458
2459 // Invalid
2460 let err = thread
2461 .update(cx, |thread, cx| {
2462 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2463 })
2464 .await
2465 .unwrap_err();
2466
2467 assert_eq!(
2468 err.to_string(),
2469 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2470 );
2471 }
2472
2473 #[gpui::test]
2474 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2475 init_test(cx);
2476
2477 let fs = FakeFs::new(cx.executor());
2478 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2479 let project = Project::test(fs.clone(), [], cx).await;
2480 project
2481 .update(cx, |project, cx| {
2482 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2483 })
2484 .await
2485 .unwrap();
2486
2487 let connection = Rc::new(FakeAgentConnection::new());
2488
2489 let thread = cx
2490 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2491 .await
2492 .unwrap();
2493
2494 // Whole file
2495 let content = thread
2496 .update(cx, |thread, cx| {
2497 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2498 })
2499 .await
2500 .unwrap();
2501
2502 assert_eq!(content, "");
2503
2504 // Only start line
2505 let content = thread
2506 .update(cx, |thread, cx| {
2507 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2508 })
2509 .await
2510 .unwrap();
2511
2512 assert_eq!(content, "");
2513
2514 // Only limit
2515 let content = thread
2516 .update(cx, |thread, cx| {
2517 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2518 })
2519 .await
2520 .unwrap();
2521
2522 assert_eq!(content, "");
2523
2524 // Range
2525 let content = thread
2526 .update(cx, |thread, cx| {
2527 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2528 })
2529 .await
2530 .unwrap();
2531
2532 assert_eq!(content, "");
2533
2534 // Invalid
2535 let err = thread
2536 .update(cx, |thread, cx| {
2537 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2538 })
2539 .await
2540 .unwrap_err();
2541
2542 assert_eq!(
2543 err.to_string(),
2544 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2545 );
2546 }
    /// Reading a file that does not exist (here `/foo`, which also lies outside
    /// the `/tmp` worktree) fails with `RESOURCE_NOT_FOUND`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
    }
2578
    /// A tool call that `cancel` marked as `Canceled` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress tool call with `id`, then ends
        // its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // The tool call follows the user message and is still running.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Cancelling the turn marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation...
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                        meta: None,
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // ...and that late status update is applied.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2681
    /// An edit tool call that arrives already `Completed` must not leave the
    /// thread reporting pending edit tool calls once the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent emits a single, already-completed edit tool call with a
        // diff that creates `test.txt`.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            // NOTE(review): not wrapped in `path!` like the
                                            // other paths here — confirm intended on Windows.
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                            meta: None,
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2738
    /// A user message whose turn changes the project is marked as a checkpoint
    /// (rendered as "## User (checkpoint)"); one whose turn changes nothing is
    /// not. `restore_checkpoint` truncates the thread back to before the chosen
    /// message and restores the files to match.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While set, each agent turn writes a fresh `file-N`, so the project
        // changes during the turn.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message, so everything from it onward is
        // dropped and `file-1` (written during that turn) disappears.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2919
2920 #[gpui::test]
2921 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2922 use std::sync::atomic::AtomicUsize;
2923 init_test(cx);
2924
2925 let fs = FakeFs::new(cx.executor());
2926 let project = Project::test(fs, None, cx).await;
2927
2928 // Create a connection that simulates refusal after tool result
2929 let prompt_count = Arc::new(AtomicUsize::new(0));
2930 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2931 let prompt_count = prompt_count.clone();
2932 move |_request, thread, mut cx| {
2933 let count = prompt_count.fetch_add(1, SeqCst);
2934 async move {
2935 if count == 0 {
2936 // First prompt: Generate a tool call with result
2937 thread.update(&mut cx, |thread, cx| {
2938 thread
2939 .handle_session_update(
2940 acp::SessionUpdate::ToolCall(acp::ToolCall {
2941 id: acp::ToolCallId("tool1".into()),
2942 title: "Test Tool".into(),
2943 kind: acp::ToolKind::Fetch,
2944 status: acp::ToolCallStatus::Completed,
2945 content: vec![],
2946 locations: vec![],
2947 raw_input: Some(serde_json::json!({"query": "test"})),
2948 raw_output: Some(
2949 serde_json::json!({"result": "inappropriate content"}),
2950 ),
2951 meta: None,
2952 }),
2953 cx,
2954 )
2955 .unwrap();
2956 })?;
2957
2958 // Now return refusal because of the tool result
2959 Ok(acp::PromptResponse {
2960 stop_reason: acp::StopReason::Refusal,
2961 meta: None,
2962 })
2963 } else {
2964 Ok(acp::PromptResponse {
2965 stop_reason: acp::StopReason::EndTurn,
2966 meta: None,
2967 })
2968 }
2969 }
2970 .boxed_local()
2971 }
2972 }));
2973
2974 let thread = cx
2975 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2976 .await
2977 .unwrap();
2978
2979 // Track if we see a Refusal event
2980 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2981 let saw_refusal_event_captured = saw_refusal_event.clone();
2982 thread.update(cx, |_thread, cx| {
2983 cx.subscribe(
2984 &thread,
2985 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2986 if matches!(event, AcpThreadEvent::Refusal) {
2987 *saw_refusal_event_captured.lock().unwrap() = true;
2988 }
2989 },
2990 )
2991 .detach();
2992 });
2993
2994 // Send a user message - this will trigger tool call and then refusal
2995 let send_task = thread.update(cx, |thread, cx| {
2996 thread.send(
2997 vec![acp::ContentBlock::Text(acp::TextContent {
2998 text: "Hello".into(),
2999 annotations: None,
3000 meta: None,
3001 })],
3002 cx,
3003 )
3004 });
3005 cx.background_executor.spawn(send_task).detach();
3006 cx.run_until_parked();
3007
3008 // Verify that:
3009 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3010 // 2. The user message was NOT truncated
3011 assert!(
3012 *saw_refusal_event.lock().unwrap(),
3013 "Refusal event should be emitted for tool result refusals"
3014 );
3015
3016 thread.read_with(cx, |thread, _| {
3017 let entries = thread.entries();
3018 assert!(entries.len() >= 2, "Should have user message and tool call");
3019
3020 // Verify user message is still there
3021 assert!(
3022 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3023 "User message should not be truncated"
3024 );
3025
3026 // Verify tool call is there with result
3027 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3028 assert!(
3029 tool_call.raw_output.is_some(),
3030 "Tool call should have output"
3031 );
3032 } else {
3033 panic!("Expected tool call at index 1");
3034 }
3035 });
3036 }
3037
3038 #[gpui::test]
3039 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3040 init_test(cx);
3041
3042 let fs = FakeFs::new(cx.executor());
3043 let project = Project::test(fs, None, cx).await;
3044
3045 let refuse_next = Arc::new(AtomicBool::new(false));
3046 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3047 let refuse_next = refuse_next.clone();
3048 move |_request, _thread, _cx| {
3049 if refuse_next.load(SeqCst) {
3050 async move {
3051 Ok(acp::PromptResponse {
3052 stop_reason: acp::StopReason::Refusal,
3053 meta: None,
3054 })
3055 }
3056 .boxed_local()
3057 } else {
3058 async move {
3059 Ok(acp::PromptResponse {
3060 stop_reason: acp::StopReason::EndTurn,
3061 meta: None,
3062 })
3063 }
3064 .boxed_local()
3065 }
3066 }
3067 }));
3068
3069 let thread = cx
3070 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3071 .await
3072 .unwrap();
3073
3074 // Track if we see a Refusal event
3075 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3076 let saw_refusal_event_captured = saw_refusal_event.clone();
3077 thread.update(cx, |_thread, cx| {
3078 cx.subscribe(
3079 &thread,
3080 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3081 if matches!(event, AcpThreadEvent::Refusal) {
3082 *saw_refusal_event_captured.lock().unwrap() = true;
3083 }
3084 },
3085 )
3086 .detach();
3087 });
3088
3089 // Send a message that will be refused
3090 refuse_next.store(true, SeqCst);
3091 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3092 .await
3093 .unwrap();
3094
3095 // Verify that a Refusal event WAS emitted for user prompt refusal
3096 assert!(
3097 *saw_refusal_event.lock().unwrap(),
3098 "Refusal event should be emitted for user prompt refusals"
3099 );
3100
3101 // Verify the message was truncated (user prompt refusal)
3102 thread.read_with(cx, |thread, cx| {
3103 assert_eq!(thread.to_markdown(cx), "");
3104 });
3105 }
3106
    /// When a user prompt is refused, that prompt is removed from the thread
    /// while the earlier, successful turn is left intact.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While set, the fake agent refuses instead of echoing the prompt.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            meta: None,
                        });
                    }

                    // Otherwise echo the prompt back in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });
    }
3194
3195 async fn run_until_first_tool_call(
3196 thread: &Entity<AcpThread>,
3197 cx: &mut TestAppContext,
3198 ) -> usize {
3199 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3200
3201 let subscription = cx.update(|cx| {
3202 cx.subscribe(thread, move |thread, _, cx| {
3203 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3204 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3205 return tx.try_send(ix).unwrap();
3206 }
3207 }
3208 })
3209 });
3210
3211 select! {
3212 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3213 panic!("Timeout waiting for tool call")
3214 }
3215 ix = rx.next().fuse() => {
3216 drop(subscription);
3217 ix.unwrap()
3218 }
3219 }
3220 }
3221
    /// In-process stand-in for an agent connection, used by these tests.
    ///
    /// Threads created through `new_thread` are registered in `sessions` so
    /// that `prompt` and `cancel` can look them up by session id; prompts are
    /// answered by the optional `on_user_message` handler.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Methods that `authenticate` accepts.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by their session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Invoked for every `prompt`; when `None`, prompts end the turn immediately.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
3237
3238 impl FakeAgentConnection {
3239 fn new() -> Self {
3240 Self {
3241 auth_methods: Vec::new(),
3242 on_user_message: None,
3243 sessions: Arc::default(),
3244 }
3245 }
3246
3247 #[expect(unused)]
3248 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3249 self.auth_methods = auth_methods;
3250 self
3251 }
3252
3253 fn on_user_message(
3254 mut self,
3255 handler: impl Fn(
3256 acp::PromptRequest,
3257 WeakEntity<AcpThread>,
3258 AsyncApp,
3259 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3260 + 'static,
3261 ) -> Self {
3262 self.on_user_message.replace(Rc::new(handler));
3263 self
3264 }
3265 }
3266
3267 impl AgentConnection for FakeAgentConnection {
3268 fn auth_methods(&self) -> &[acp::AuthMethod] {
3269 &self.auth_methods
3270 }
3271
3272 fn new_thread(
3273 self: Rc<Self>,
3274 project: Entity<Project>,
3275 _cwd: &Path,
3276 cx: &mut App,
3277 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3278 let session_id = acp::SessionId(
3279 rand::rng()
3280 .sample_iter(&distr::Alphanumeric)
3281 .take(7)
3282 .map(char::from)
3283 .collect::<String>()
3284 .into(),
3285 );
3286 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3287 let thread = cx.new(|cx| {
3288 AcpThread::new(
3289 "Test",
3290 self.clone(),
3291 project,
3292 action_log,
3293 session_id.clone(),
3294 watch::Receiver::constant(acp::PromptCapabilities {
3295 image: true,
3296 audio: true,
3297 embedded_context: true,
3298 meta: None,
3299 }),
3300 cx,
3301 )
3302 });
3303 self.sessions.lock().insert(session_id, thread.downgrade());
3304 Task::ready(Ok(thread))
3305 }
3306
3307 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3308 if self.auth_methods().iter().any(|m| m.id == method) {
3309 Task::ready(Ok(()))
3310 } else {
3311 Task::ready(Err(anyhow!("Invalid Auth Method")))
3312 }
3313 }
3314
3315 fn prompt(
3316 &self,
3317 _id: Option<UserMessageId>,
3318 params: acp::PromptRequest,
3319 cx: &mut App,
3320 ) -> Task<gpui::Result<acp::PromptResponse>> {
3321 let sessions = self.sessions.lock();
3322 let thread = sessions.get(¶ms.session_id).unwrap();
3323 if let Some(handler) = &self.on_user_message {
3324 let handler = handler.clone();
3325 let thread = thread.clone();
3326 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3327 } else {
3328 Task::ready(Ok(acp::PromptResponse {
3329 stop_reason: acp::StopReason::EndTurn,
3330 meta: None,
3331 }))
3332 }
3333 }
3334
3335 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3336 let sessions = self.sessions.lock();
3337 let thread = sessions.get(session_id).unwrap().clone();
3338
3339 cx.spawn(async move |cx| {
3340 thread
3341 .update(cx, |thread, cx| thread.cancel(cx))
3342 .unwrap()
3343 .await
3344 })
3345 .detach();
3346 }
3347
3348 fn truncate(
3349 &self,
3350 session_id: &acp::SessionId,
3351 _cx: &App,
3352 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3353 Some(Rc::new(FakeAgentSessionEditor {
3354 _session_id: session_id.clone(),
3355 }))
3356 }
3357
3358 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3359 self
3360 }
3361 }
3362
    /// Truncation handle returned by `FakeAgentConnection::truncate`; it only
    /// remembers which session it was created for.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }
3366
    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// No-op: the fake agent keeps no history, so truncation always succeeds.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3372
3373 #[gpui::test]
3374 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3375 init_test(cx);
3376
3377 let fs = FakeFs::new(cx.executor());
3378 let project = Project::test(fs, [], cx).await;
3379 let connection = Rc::new(FakeAgentConnection::new());
3380 let thread = cx
3381 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3382 .await
3383 .unwrap();
3384
3385 // Try to update a tool call that doesn't exist
3386 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3387 thread.update(cx, |thread, cx| {
3388 let result = thread.handle_session_update(
3389 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3390 id: nonexistent_id.clone(),
3391 fields: acp::ToolCallUpdateFields {
3392 status: Some(acp::ToolCallStatus::Completed),
3393 ..Default::default()
3394 },
3395 meta: None,
3396 }),
3397 cx,
3398 );
3399
3400 // The update should succeed (not return an error)
3401 assert!(result.is_ok());
3402
3403 // There should now be exactly one entry in the thread
3404 assert_eq!(thread.entries.len(), 1);
3405
3406 // The entry should be a failed tool call
3407 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3408 assert_eq!(tool_call.id, nonexistent_id);
3409 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3410 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3411
3412 // Check that the content contains the error message
3413 assert_eq!(tool_call.content.len(), 1);
3414 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3415 match content_block {
3416 ContentBlock::Markdown { markdown } => {
3417 let markdown_text = markdown.read(cx).source();
3418 assert!(markdown_text.contains("Tool call not found"));
3419 }
3420 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3421 ContentBlock::ResourceLink { .. } => {
3422 panic!("Expected markdown content, got resource link")
3423 }
3424 }
3425 } else {
3426 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3427 }
3428 } else {
3429 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3430 }
3431 });
3432 }
3433}