1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use ::terminal::terminal_settings::TerminalSettings;
7use agent_settings::AgentSettings;
8use collections::HashSet;
9pub use connection::*;
10pub use diff::*;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::ActionLog;
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell};
39use uuid::Uuid;
40
41#[derive(Debug)]
42pub struct UserMessage {
43 pub id: Option<UserMessageId>,
44 pub content: ContentBlock,
45 pub chunks: Vec<acp::ContentBlock>,
46 pub checkpoint: Option<Checkpoint>,
47}
48
49#[derive(Debug)]
50pub struct Checkpoint {
51 git_checkpoint: GitStoreCheckpoint,
52 pub show: bool,
53}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
74#[derive(Debug, PartialEq)]
75pub struct AssistantMessage {
76 pub chunks: Vec<AssistantMessageChunk>,
77}
78
79impl AssistantMessage {
80 pub fn to_markdown(&self, cx: &App) -> String {
81 format!(
82 "## Assistant\n\n{}\n\n",
83 self.chunks
84 .iter()
85 .map(|chunk| chunk.to_markdown(cx))
86 .join("\n\n")
87 )
88 }
89}
90
91#[derive(Debug, PartialEq)]
92pub enum AssistantMessageChunk {
93 Message { block: ContentBlock },
94 Thought { block: ContentBlock },
95}
96
97impl AssistantMessageChunk {
98 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Message {
100 block: ContentBlock::new(chunk.into(), language_registry, cx),
101 }
102 }
103
104 fn to_markdown(&self, cx: &App) -> String {
105 match self {
106 Self::Message { block } => block.to_markdown(cx).to_string(),
107 Self::Thought { block } => {
108 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
109 }
110 }
111 }
112}
113
114#[derive(Debug)]
115pub enum AgentThreadEntry {
116 UserMessage(UserMessage),
117 AssistantMessage(AssistantMessage),
118 ToolCall(ToolCall),
119}
120
121impl AgentThreadEntry {
122 pub fn to_markdown(&self, cx: &App) -> String {
123 match self {
124 Self::UserMessage(message) => message.to_markdown(cx),
125 Self::AssistantMessage(message) => message.to_markdown(cx),
126 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
127 }
128 }
129
130 pub fn user_message(&self) -> Option<&UserMessage> {
131 if let AgentThreadEntry::UserMessage(message) = self {
132 Some(message)
133 } else {
134 None
135 }
136 }
137
138 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.diffs())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
147 if let AgentThreadEntry::ToolCall(call) = self {
148 itertools::Either::Left(call.terminals())
149 } else {
150 itertools::Either::Right(std::iter::empty())
151 }
152 }
153
154 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
155 if let AgentThreadEntry::ToolCall(ToolCall {
156 locations,
157 resolved_locations,
158 ..
159 }) = self
160 {
161 Some((
162 locations.get(ix)?.clone(),
163 resolved_locations.get(ix)?.clone()?,
164 ))
165 } else {
166 None
167 }
168 }
169}
170
171#[derive(Debug)]
172pub struct ToolCall {
173 pub id: acp::ToolCallId,
174 pub label: Entity<Markdown>,
175 pub kind: acp::ToolKind,
176 pub content: Vec<ToolCallContent>,
177 pub status: ToolCallStatus,
178 pub locations: Vec<acp::ToolCallLocation>,
179 pub resolved_locations: Vec<Option<AgentLocation>>,
180 pub raw_input: Option<serde_json::Value>,
181 pub raw_output: Option<serde_json::Value>,
182}
183
184impl ToolCall {
185 fn from_acp(
186 tool_call: acp::ToolCall,
187 status: ToolCallStatus,
188 language_registry: Arc<LanguageRegistry>,
189 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
190 cx: &mut App,
191 ) -> Result<Self> {
192 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
193 first_line.to_owned() + "…"
194 } else {
195 tool_call.title
196 };
197 let mut content = Vec::with_capacity(tool_call.content.len());
198 for item in tool_call.content {
199 content.push(ToolCallContent::from_acp(
200 item,
201 language_registry.clone(),
202 terminals,
203 cx,
204 )?);
205 }
206
207 let result = Self {
208 id: tool_call.id,
209 label: cx
210 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
211 kind: tool_call.kind,
212 content,
213 locations: tool_call.locations,
214 resolved_locations: Vec::default(),
215 status,
216 raw_input: tool_call.raw_input,
217 raw_output: tool_call.raw_output,
218 };
219 Ok(result)
220 }
221
222 fn update_fields(
223 &mut self,
224 fields: acp::ToolCallUpdateFields,
225 language_registry: Arc<LanguageRegistry>,
226 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
227 cx: &mut App,
228 ) -> Result<()> {
229 let acp::ToolCallUpdateFields {
230 kind,
231 status,
232 title,
233 content,
234 locations,
235 raw_input,
236 raw_output,
237 } = fields;
238
239 if let Some(kind) = kind {
240 self.kind = kind;
241 }
242
243 if let Some(status) = status {
244 self.status = status.into();
245 }
246
247 if let Some(title) = title {
248 self.label.update(cx, |label, cx| {
249 if let Some((first_line, _)) = title.split_once("\n") {
250 label.replace(first_line.to_owned() + "…", cx)
251 } else {
252 label.replace(title, cx);
253 }
254 });
255 }
256
257 if let Some(content) = content {
258 let new_content_len = content.len();
259 let mut content = content.into_iter();
260
261 // Reuse existing content if we can
262 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
263 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
264 }
265 for new in content {
266 self.content.push(ToolCallContent::from_acp(
267 new,
268 language_registry.clone(),
269 terminals,
270 cx,
271 )?)
272 }
273 self.content.truncate(new_content_len);
274 }
275
276 if let Some(locations) = locations {
277 self.locations = locations;
278 }
279
280 if let Some(raw_input) = raw_input {
281 self.raw_input = Some(raw_input);
282 }
283
284 if let Some(raw_output) = raw_output {
285 if self.content.is_empty()
286 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
287 {
288 self.content
289 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
290 markdown,
291 }));
292 }
293 self.raw_output = Some(raw_output);
294 }
295 Ok(())
296 }
297
298 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
299 self.content.iter().filter_map(|content| match content {
300 ToolCallContent::Diff(diff) => Some(diff),
301 ToolCallContent::ContentBlock(_) => None,
302 ToolCallContent::Terminal(_) => None,
303 })
304 }
305
306 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
307 self.content.iter().filter_map(|content| match content {
308 ToolCallContent::Terminal(terminal) => Some(terminal),
309 ToolCallContent::ContentBlock(_) => None,
310 ToolCallContent::Diff(_) => None,
311 })
312 }
313
314 fn to_markdown(&self, cx: &App) -> String {
315 let mut markdown = format!(
316 "**Tool Call: {}**\nStatus: {}\n\n",
317 self.label.read(cx).source(),
318 self.status
319 );
320 for content in &self.content {
321 markdown.push_str(content.to_markdown(cx).as_str());
322 markdown.push_str("\n\n");
323 }
324 markdown
325 }
326
327 async fn resolve_location(
328 location: acp::ToolCallLocation,
329 project: WeakEntity<Project>,
330 cx: &mut AsyncApp,
331 ) -> Option<AgentLocation> {
332 let buffer = project
333 .update(cx, |project, cx| {
334 project
335 .project_path_for_absolute_path(&location.path, cx)
336 .map(|path| project.open_buffer(path, cx))
337 })
338 .ok()??;
339 let buffer = buffer.await.log_err()?;
340 let position = buffer
341 .update(cx, |buffer, _| {
342 if let Some(row) = location.line {
343 let snapshot = buffer.snapshot();
344 let column = snapshot.indent_size_for_line(row).len;
345 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
346 snapshot.anchor_before(point)
347 } else {
348 Anchor::MIN
349 }
350 })
351 .ok()?;
352
353 Some(AgentLocation {
354 buffer: buffer.downgrade(),
355 position,
356 })
357 }
358
359 fn resolve_locations(
360 &self,
361 project: Entity<Project>,
362 cx: &mut App,
363 ) -> Task<Vec<Option<AgentLocation>>> {
364 let locations = self.locations.clone();
365 project.update(cx, |_, cx| {
366 cx.spawn(async move |project, cx| {
367 let mut new_locations = Vec::new();
368 for location in locations {
369 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
370 }
371 new_locations
372 })
373 })
374 }
375}
376
377#[derive(Debug)]
378pub enum ToolCallStatus {
379 /// The tool call hasn't started running yet, but we start showing it to
380 /// the user.
381 Pending,
382 /// The tool call is waiting for confirmation from the user.
383 WaitingForConfirmation {
384 options: Vec<acp::PermissionOption>,
385 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
386 },
387 /// The tool call is currently running.
388 InProgress,
389 /// The tool call completed successfully.
390 Completed,
391 /// The tool call failed.
392 Failed,
393 /// The user rejected the tool call.
394 Rejected,
395 /// The user canceled generation so the tool call was canceled.
396 Canceled,
397}
398
399impl From<acp::ToolCallStatus> for ToolCallStatus {
400 fn from(status: acp::ToolCallStatus) -> Self {
401 match status {
402 acp::ToolCallStatus::Pending => Self::Pending,
403 acp::ToolCallStatus::InProgress => Self::InProgress,
404 acp::ToolCallStatus::Completed => Self::Completed,
405 acp::ToolCallStatus::Failed => Self::Failed,
406 }
407 }
408}
409
410impl Display for ToolCallStatus {
411 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
412 write!(
413 f,
414 "{}",
415 match self {
416 ToolCallStatus::Pending => "Pending",
417 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
418 ToolCallStatus::InProgress => "In Progress",
419 ToolCallStatus::Completed => "Completed",
420 ToolCallStatus::Failed => "Failed",
421 ToolCallStatus::Rejected => "Rejected",
422 ToolCallStatus::Canceled => "Canceled",
423 }
424 )
425 }
426}
427
428#[derive(Debug, PartialEq, Clone)]
429pub enum ContentBlock {
430 Empty,
431 Markdown { markdown: Entity<Markdown> },
432 ResourceLink { resource_link: acp::ResourceLink },
433}
434
435impl ContentBlock {
436 pub fn new(
437 block: acp::ContentBlock,
438 language_registry: &Arc<LanguageRegistry>,
439 cx: &mut App,
440 ) -> Self {
441 let mut this = Self::Empty;
442 this.append(block, language_registry, cx);
443 this
444 }
445
446 pub fn new_combined(
447 blocks: impl IntoIterator<Item = acp::ContentBlock>,
448 language_registry: Arc<LanguageRegistry>,
449 cx: &mut App,
450 ) -> Self {
451 let mut this = Self::Empty;
452 for block in blocks {
453 this.append(block, &language_registry, cx);
454 }
455 this
456 }
457
458 pub fn append(
459 &mut self,
460 block: acp::ContentBlock,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) {
464 if matches!(self, ContentBlock::Empty)
465 && let acp::ContentBlock::ResourceLink(resource_link) = block
466 {
467 *self = ContentBlock::ResourceLink { resource_link };
468 return;
469 }
470
471 let new_content = self.block_string_contents(block);
472
473 match self {
474 ContentBlock::Empty => {
475 *self = Self::create_markdown_block(new_content, language_registry, cx);
476 }
477 ContentBlock::Markdown { markdown } => {
478 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
479 }
480 ContentBlock::ResourceLink { resource_link } => {
481 let existing_content = Self::resource_link_md(&resource_link.uri);
482 let combined = format!("{}\n{}", existing_content, new_content);
483
484 *self = Self::create_markdown_block(combined, language_registry, cx);
485 }
486 }
487 }
488
489 fn create_markdown_block(
490 content: String,
491 language_registry: &Arc<LanguageRegistry>,
492 cx: &mut App,
493 ) -> ContentBlock {
494 ContentBlock::Markdown {
495 markdown: cx
496 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
497 }
498 }
499
500 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
501 match block {
502 acp::ContentBlock::Text(text_content) => text_content.text,
503 acp::ContentBlock::ResourceLink(resource_link) => {
504 Self::resource_link_md(&resource_link.uri)
505 }
506 acp::ContentBlock::Resource(acp::EmbeddedResource {
507 resource:
508 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
509 uri,
510 ..
511 }),
512 ..
513 }) => Self::resource_link_md(&uri),
514 acp::ContentBlock::Image(image) => Self::image_md(&image),
515 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
516 }
517 }
518
519 fn resource_link_md(uri: &str) -> String {
520 if let Some(uri) = MentionUri::parse(uri).log_err() {
521 uri.as_link().to_string()
522 } else {
523 uri.to_string()
524 }
525 }
526
527 fn image_md(_image: &acp::ImageContent) -> String {
528 "`Image`".into()
529 }
530
531 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
532 match self {
533 ContentBlock::Empty => "",
534 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
535 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
536 }
537 }
538
539 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
540 match self {
541 ContentBlock::Empty => None,
542 ContentBlock::Markdown { markdown } => Some(markdown),
543 ContentBlock::ResourceLink { .. } => None,
544 }
545 }
546
547 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
548 match self {
549 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
550 _ => None,
551 }
552 }
553}
554
555#[derive(Debug)]
556pub enum ToolCallContent {
557 ContentBlock(ContentBlock),
558 Diff(Entity<Diff>),
559 Terminal(Entity<Terminal>),
560}
561
562impl ToolCallContent {
563 pub fn from_acp(
564 content: acp::ToolCallContent,
565 language_registry: Arc<LanguageRegistry>,
566 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
567 cx: &mut App,
568 ) -> Result<Self> {
569 match content {
570 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
571 content,
572 &language_registry,
573 cx,
574 ))),
575 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
576 Diff::finalized(
577 diff.path.to_string_lossy().into_owned(),
578 diff.old_text,
579 diff.new_text,
580 language_registry,
581 cx,
582 )
583 }))),
584 acp::ToolCallContent::Terminal { terminal_id } => terminals
585 .get(&terminal_id)
586 .cloned()
587 .map(Self::Terminal)
588 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
589 }
590 }
591
592 pub fn update_from_acp(
593 &mut self,
594 new: acp::ToolCallContent,
595 language_registry: Arc<LanguageRegistry>,
596 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
597 cx: &mut App,
598 ) -> Result<()> {
599 let needs_update = match (&self, &new) {
600 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
601 old_diff.read(cx).needs_update(
602 new_diff.old_text.as_deref().unwrap_or(""),
603 &new_diff.new_text,
604 cx,
605 )
606 }
607 _ => true,
608 };
609
610 if needs_update {
611 *self = Self::from_acp(new, language_registry, terminals, cx)?;
612 }
613 Ok(())
614 }
615
616 pub fn to_markdown(&self, cx: &App) -> String {
617 match self {
618 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
619 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
620 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
621 }
622 }
623}
624
625#[derive(Debug, PartialEq)]
626pub enum ToolCallUpdate {
627 UpdateFields(acp::ToolCallUpdate),
628 UpdateDiff(ToolCallUpdateDiff),
629 UpdateTerminal(ToolCallUpdateTerminal),
630}
631
632impl ToolCallUpdate {
633 fn id(&self) -> &acp::ToolCallId {
634 match self {
635 Self::UpdateFields(update) => &update.id,
636 Self::UpdateDiff(diff) => &diff.id,
637 Self::UpdateTerminal(terminal) => &terminal.id,
638 }
639 }
640}
641
642impl From<acp::ToolCallUpdate> for ToolCallUpdate {
643 fn from(update: acp::ToolCallUpdate) -> Self {
644 Self::UpdateFields(update)
645 }
646}
647
648impl From<ToolCallUpdateDiff> for ToolCallUpdate {
649 fn from(diff: ToolCallUpdateDiff) -> Self {
650 Self::UpdateDiff(diff)
651 }
652}
653
654#[derive(Debug, PartialEq)]
655pub struct ToolCallUpdateDiff {
656 pub id: acp::ToolCallId,
657 pub diff: Entity<Diff>,
658}
659
660impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
661 fn from(terminal: ToolCallUpdateTerminal) -> Self {
662 Self::UpdateTerminal(terminal)
663 }
664}
665
666#[derive(Debug, PartialEq)]
667pub struct ToolCallUpdateTerminal {
668 pub id: acp::ToolCallId,
669 pub terminal: Entity<Terminal>,
670}
671
672#[derive(Debug, Default)]
673pub struct Plan {
674 pub entries: Vec<PlanEntry>,
675}
676
677#[derive(Debug)]
678pub struct PlanStats<'a> {
679 pub in_progress_entry: Option<&'a PlanEntry>,
680 pub pending: u32,
681 pub completed: u32,
682}
683
684impl Plan {
685 pub fn is_empty(&self) -> bool {
686 self.entries.is_empty()
687 }
688
689 pub fn stats(&self) -> PlanStats<'_> {
690 let mut stats = PlanStats {
691 in_progress_entry: None,
692 pending: 0,
693 completed: 0,
694 };
695
696 for entry in &self.entries {
697 match &entry.status {
698 acp::PlanEntryStatus::Pending => {
699 stats.pending += 1;
700 }
701 acp::PlanEntryStatus::InProgress => {
702 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
703 }
704 acp::PlanEntryStatus::Completed => {
705 stats.completed += 1;
706 }
707 }
708 }
709
710 stats
711 }
712}
713
714#[derive(Debug)]
715pub struct PlanEntry {
716 pub content: Entity<Markdown>,
717 pub priority: acp::PlanEntryPriority,
718 pub status: acp::PlanEntryStatus,
719}
720
721impl PlanEntry {
722 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
723 Self {
724 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
725 priority: entry.priority,
726 status: entry.status,
727 }
728 }
729}
730
/// Token accounting for a thread: tokens used so far out of the allowed maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "limit unknown" by `ratio`.
    pub max_tokens: u64,
    /// Number of tokens consumed so far.
    pub used_tokens: u64,
}
736
737impl TokenUsage {
738 pub fn ratio(&self) -> TokenUsageRatio {
739 #[cfg(debug_assertions)]
740 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
741 .unwrap_or("0.8".to_string())
742 .parse()
743 .unwrap();
744 #[cfg(not(debug_assertions))]
745 let warning_threshold: f32 = 0.8;
746
747 // When the maximum is unknown because there is no selected model,
748 // avoid showing the token limit warning.
749 if self.max_tokens == 0 {
750 TokenUsageRatio::Normal
751 } else if self.used_tokens >= self.max_tokens {
752 TokenUsageRatio::Exceeded
753 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
754 TokenUsageRatio::Warning
755 } else {
756 TokenUsageRatio::Normal
757 }
758 }
759}
760
761#[derive(Debug, Clone, PartialEq, Eq)]
762pub enum TokenUsageRatio {
763 Normal,
764 Warning,
765 Exceeded,
766}
767
/// State of a retry, forwarded to observers as `AcpThreadEvent::Retry` by
/// `AcpThread::update_retry_status`.
// NOTE(review): whether `attempt` is 0- or 1-based and whether `duration` is the delay
// before the next attempt are not visible here — confirm with the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the error that caused the retry — confirm.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}
776
/// An agent conversation thread backed by an ACP connection.
pub struct AcpThread {
    /// Thread title; changed through `set_title`.
    title: SharedString,
    /// Thread entries in the order they were pushed.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in this section; presumably buffer snapshots shared with
    // the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while a turn is running; `status()` then reports `Generating`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest usage set via `update_token_usage`.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the connection.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task spawned in `new` that mirrors prompt-capability updates into
    /// `prompt_capabilities` (held here, presumably so it is not dropped).
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals by id, used to resolve terminal content in tool calls.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
792
/// Events emitted by [`AcpThread`] to notify observers of state changes.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended (emitted by `push_entry`).
    NewEntry,
    /// The title changed (emitted by `set_title`).
    TitleUpdated,
    /// Token usage was replaced (emitted by `update_token_usage`).
    TokenUsageUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    /// Presumably the index range of entries that were removed — confirm with emitter.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    /// Retry status forwarded by `update_retry_status`.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Prompt capabilities changed (emitted by the observer task spawned in `new`).
    PromptCapabilitiesUpdated,
    Refusal,
    /// The agent sent a new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The agent reported a new current session mode.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
812
/// Coarse state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in progress.
    Generating,
}
818
819#[derive(Debug, Clone)]
820pub enum LoadError {
821 Unsupported {
822 command: SharedString,
823 current_version: SharedString,
824 minimum_version: SharedString,
825 },
826 FailedToInstall(SharedString),
827 Exited {
828 status: ExitStatus,
829 },
830 Other(SharedString),
831}
832
833impl Display for LoadError {
834 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
835 match self {
836 LoadError::Unsupported {
837 command: path,
838 current_version,
839 minimum_version,
840 } => {
841 write!(
842 f,
843 "version {current_version} from {path} is not supported (need at least {minimum_version})"
844 )
845 }
846 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
847 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
848 LoadError::Other(msg) => write!(f, "{msg}"),
849 }
850 }
851}
852
853impl Error for LoadError {}
854
855impl AcpThread {
856 pub fn new(
857 title: impl Into<SharedString>,
858 connection: Rc<dyn AgentConnection>,
859 project: Entity<Project>,
860 action_log: Entity<ActionLog>,
861 session_id: acp::SessionId,
862 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
863 cx: &mut Context<Self>,
864 ) -> Self {
865 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
866 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
867 loop {
868 let caps = prompt_capabilities_rx.recv().await?;
869 this.update(cx, |this, cx| {
870 this.prompt_capabilities = caps;
871 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
872 })?;
873 }
874 });
875
876 Self {
877 action_log,
878 shared_buffers: Default::default(),
879 entries: Default::default(),
880 plan: Default::default(),
881 title: title.into(),
882 project,
883 send_task: None,
884 connection,
885 session_id,
886 token_usage: None,
887 prompt_capabilities,
888 _observe_prompt_capabilities: task,
889 terminals: HashMap::default(),
890 }
891 }
892
893 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
894 self.prompt_capabilities.clone()
895 }
896
897 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
898 &self.connection
899 }
900
901 pub fn action_log(&self) -> &Entity<ActionLog> {
902 &self.action_log
903 }
904
905 pub fn project(&self) -> &Entity<Project> {
906 &self.project
907 }
908
909 pub fn title(&self) -> SharedString {
910 self.title.clone()
911 }
912
913 pub fn entries(&self) -> &[AgentThreadEntry] {
914 &self.entries
915 }
916
917 pub fn session_id(&self) -> &acp::SessionId {
918 &self.session_id
919 }
920
921 pub fn status(&self) -> ThreadStatus {
922 if self.send_task.is_some() {
923 ThreadStatus::Generating
924 } else {
925 ThreadStatus::Idle
926 }
927 }
928
929 pub fn token_usage(&self) -> Option<&TokenUsage> {
930 self.token_usage.as_ref()
931 }
932
933 pub fn has_pending_edit_tool_calls(&self) -> bool {
934 for entry in self.entries.iter().rev() {
935 match entry {
936 AgentThreadEntry::UserMessage(_) => return false,
937 AgentThreadEntry::ToolCall(
938 call @ ToolCall {
939 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
940 ..
941 },
942 ) if call.diffs().next().is_some() => {
943 return true;
944 }
945 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
946 }
947 }
948
949 false
950 }
951
952 pub fn used_tools_since_last_user_message(&self) -> bool {
953 for entry in self.entries.iter().rev() {
954 match entry {
955 AgentThreadEntry::UserMessage(..) => return false,
956 AgentThreadEntry::AssistantMessage(..) => continue,
957 AgentThreadEntry::ToolCall(..) => return true,
958 }
959 }
960
961 false
962 }
963
964 pub fn handle_session_update(
965 &mut self,
966 update: acp::SessionUpdate,
967 cx: &mut Context<Self>,
968 ) -> Result<(), acp::Error> {
969 match update {
970 acp::SessionUpdate::UserMessageChunk { content } => {
971 self.push_user_content_block(None, content, cx);
972 }
973 acp::SessionUpdate::AgentMessageChunk { content } => {
974 self.push_assistant_content_block(content, false, cx);
975 }
976 acp::SessionUpdate::AgentThoughtChunk { content } => {
977 self.push_assistant_content_block(content, true, cx);
978 }
979 acp::SessionUpdate::ToolCall(tool_call) => {
980 self.upsert_tool_call(tool_call, cx)?;
981 }
982 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
983 self.update_tool_call(tool_call_update, cx)?;
984 }
985 acp::SessionUpdate::Plan(plan) => {
986 self.update_plan(plan, cx);
987 }
988 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
989 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
990 }
991 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
992 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
993 }
994 }
995 Ok(())
996 }
997
998 pub fn push_user_content_block(
999 &mut self,
1000 message_id: Option<UserMessageId>,
1001 chunk: acp::ContentBlock,
1002 cx: &mut Context<Self>,
1003 ) {
1004 let language_registry = self.project.read(cx).languages().clone();
1005 let entries_len = self.entries.len();
1006
1007 if let Some(last_entry) = self.entries.last_mut()
1008 && let AgentThreadEntry::UserMessage(UserMessage {
1009 id,
1010 content,
1011 chunks,
1012 ..
1013 }) = last_entry
1014 {
1015 *id = message_id.or(id.take());
1016 content.append(chunk.clone(), &language_registry, cx);
1017 chunks.push(chunk);
1018 let idx = entries_len - 1;
1019 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1020 } else {
1021 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1022 self.push_entry(
1023 AgentThreadEntry::UserMessage(UserMessage {
1024 id: message_id,
1025 content,
1026 chunks: vec![chunk],
1027 checkpoint: None,
1028 }),
1029 cx,
1030 );
1031 }
1032 }
1033
1034 pub fn push_assistant_content_block(
1035 &mut self,
1036 chunk: acp::ContentBlock,
1037 is_thought: bool,
1038 cx: &mut Context<Self>,
1039 ) {
1040 let language_registry = self.project.read(cx).languages().clone();
1041 let entries_len = self.entries.len();
1042 if let Some(last_entry) = self.entries.last_mut()
1043 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1044 {
1045 let idx = entries_len - 1;
1046 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1047 match (chunks.last_mut(), is_thought) {
1048 (Some(AssistantMessageChunk::Message { block }), false)
1049 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1050 block.append(chunk, &language_registry, cx)
1051 }
1052 _ => {
1053 let block = ContentBlock::new(chunk, &language_registry, cx);
1054 if is_thought {
1055 chunks.push(AssistantMessageChunk::Thought { block })
1056 } else {
1057 chunks.push(AssistantMessageChunk::Message { block })
1058 }
1059 }
1060 }
1061 } else {
1062 let block = ContentBlock::new(chunk, &language_registry, cx);
1063 let chunk = if is_thought {
1064 AssistantMessageChunk::Thought { block }
1065 } else {
1066 AssistantMessageChunk::Message { block }
1067 };
1068
1069 self.push_entry(
1070 AgentThreadEntry::AssistantMessage(AssistantMessage {
1071 chunks: vec![chunk],
1072 }),
1073 cx,
1074 );
1075 }
1076 }
1077
1078 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1079 self.entries.push(entry);
1080 cx.emit(AcpThreadEvent::NewEntry);
1081 }
1082
1083 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1084 self.connection.set_title(&self.session_id, cx).is_some()
1085 }
1086
1087 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1088 if title != self.title {
1089 self.title = title.clone();
1090 cx.emit(AcpThreadEvent::TitleUpdated);
1091 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1092 return set_title.run(title, cx);
1093 }
1094 }
1095 Task::ready(Ok(()))
1096 }
1097
1098 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1099 self.token_usage = usage;
1100 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1101 }
1102
1103 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1104 cx.emit(AcpThreadEvent::Retry(status));
1105 }
1106
1107 pub fn update_tool_call(
1108 &mut self,
1109 update: impl Into<ToolCallUpdate>,
1110 cx: &mut Context<Self>,
1111 ) -> Result<()> {
1112 let update = update.into();
1113 let languages = self.project.read(cx).languages().clone();
1114
1115 let ix = match self.index_for_tool_call(update.id()) {
1116 Some(ix) => ix,
1117 None => {
1118 // Tool call not found - create a failed tool call entry
1119 let failed_tool_call = ToolCall {
1120 id: update.id().clone(),
1121 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1122 kind: acp::ToolKind::Fetch,
1123 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1124 acp::ContentBlock::Text(acp::TextContent {
1125 text: "Tool call not found".to_string(),
1126 annotations: None,
1127 meta: None,
1128 }),
1129 &languages,
1130 cx,
1131 ))],
1132 status: ToolCallStatus::Failed,
1133 locations: Vec::new(),
1134 resolved_locations: Vec::new(),
1135 raw_input: None,
1136 raw_output: None,
1137 };
1138 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1139 return Ok(());
1140 }
1141 };
1142 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1143 unreachable!()
1144 };
1145
1146 match update {
1147 ToolCallUpdate::UpdateFields(update) => {
1148 let location_updated = update.fields.locations.is_some();
1149 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1150 if location_updated {
1151 self.resolve_locations(update.id, cx);
1152 }
1153 }
1154 ToolCallUpdate::UpdateDiff(update) => {
1155 call.content.clear();
1156 call.content.push(ToolCallContent::Diff(update.diff));
1157 }
1158 ToolCallUpdate::UpdateTerminal(update) => {
1159 call.content.clear();
1160 call.content
1161 .push(ToolCallContent::Terminal(update.terminal));
1162 }
1163 }
1164
1165 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1166
1167 Ok(())
1168 }
1169
1170 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1171 pub fn upsert_tool_call(
1172 &mut self,
1173 tool_call: acp::ToolCall,
1174 cx: &mut Context<Self>,
1175 ) -> Result<(), acp::Error> {
1176 let status = tool_call.status.into();
1177 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1178 }
1179
1180 /// Fails if id does not match an existing entry.
1181 pub fn upsert_tool_call_inner(
1182 &mut self,
1183 update: acp::ToolCallUpdate,
1184 status: ToolCallStatus,
1185 cx: &mut Context<Self>,
1186 ) -> Result<(), acp::Error> {
1187 let language_registry = self.project.read(cx).languages().clone();
1188 let id = update.id.clone();
1189
1190 if let Some(ix) = self.index_for_tool_call(&id) {
1191 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1192 unreachable!()
1193 };
1194
1195 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1196 call.status = status;
1197
1198 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1199 } else {
1200 let call = ToolCall::from_acp(
1201 update.try_into()?,
1202 status,
1203 language_registry,
1204 &self.terminals,
1205 cx,
1206 )?;
1207 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1208 };
1209
1210 self.resolve_locations(id, cx);
1211 Ok(())
1212 }
1213
1214 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1215 self.entries
1216 .iter()
1217 .enumerate()
1218 .rev()
1219 .find_map(|(index, entry)| {
1220 if let AgentThreadEntry::ToolCall(tool_call) = entry
1221 && &tool_call.id == id
1222 {
1223 Some(index)
1224 } else {
1225 None
1226 }
1227 })
1228 }
1229
1230 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1231 // The tool call we are looking for is typically the last one, or very close to the end.
1232 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1233 self.entries
1234 .iter_mut()
1235 .enumerate()
1236 .rev()
1237 .find_map(|(index, tool_call)| {
1238 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1239 && &tool_call.id == id
1240 {
1241 Some((index, tool_call))
1242 } else {
1243 None
1244 }
1245 })
1246 }
1247
1248 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1249 self.entries
1250 .iter()
1251 .enumerate()
1252 .rev()
1253 .find_map(|(index, tool_call)| {
1254 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1255 && &tool_call.id == id
1256 {
1257 Some((index, tool_call))
1258 } else {
1259 None
1260 }
1261 })
1262 }
1263
    /// Resolves the locations reported by the tool call with the given id in a detached
    /// task, then stores them on the tool call (emitting `EntryUpdated` if they changed).
    ///
    /// If the last location resolved and the project already has an agent location, the
    /// agent location is moved there, except for a same-row move to an earlier column in
    /// the same buffer. Does nothing if no tool call with `id` exists.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again by id: entries may have changed while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                // Only the last location is followed, and only if it resolved.
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        // Note: when there is no current agent location, it is not set here.
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of the line
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    // If the buffer was released, don't ignore the update.
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                // Avoid a redundant re-render when nothing changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1309
1310 pub fn request_tool_call_authorization(
1311 &mut self,
1312 tool_call: acp::ToolCallUpdate,
1313 options: Vec<acp::PermissionOption>,
1314 respect_always_allow_setting: bool,
1315 cx: &mut Context<Self>,
1316 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1317 let (tx, rx) = oneshot::channel();
1318
1319 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1320 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1321 // some tools would (incorrectly) continue to auto-accept.
1322 if let Some(allow_once_option) = options.iter().find_map(|option| {
1323 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1324 Some(option.id.clone())
1325 } else {
1326 None
1327 }
1328 }) {
1329 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1330 return Ok(async {
1331 acp::RequestPermissionOutcome::Selected {
1332 option_id: allow_once_option,
1333 }
1334 }
1335 .boxed());
1336 }
1337 }
1338
1339 let status = ToolCallStatus::WaitingForConfirmation {
1340 options,
1341 respond_tx: tx,
1342 };
1343
1344 self.upsert_tool_call_inner(tool_call, status, cx)?;
1345 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1346
1347 let fut = async {
1348 match rx.await {
1349 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1350 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1351 }
1352 }
1353 .boxed();
1354
1355 Ok(fut)
1356 }
1357
1358 pub fn authorize_tool_call(
1359 &mut self,
1360 id: acp::ToolCallId,
1361 option_id: acp::PermissionOptionId,
1362 option_kind: acp::PermissionOptionKind,
1363 cx: &mut Context<Self>,
1364 ) {
1365 let Some((ix, call)) = self.tool_call_mut(&id) else {
1366 return;
1367 };
1368
1369 let new_status = match option_kind {
1370 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1371 ToolCallStatus::Rejected
1372 }
1373 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1374 ToolCallStatus::InProgress
1375 }
1376 };
1377
1378 let curr_status = mem::replace(&mut call.status, new_status);
1379
1380 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1381 respond_tx.send(option_id).log_err();
1382 } else if cfg!(debug_assertions) {
1383 panic!("tried to authorize an already authorized tool call");
1384 }
1385
1386 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1387 }
1388
1389 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1390 let mut first_tool_call = None;
1391
1392 for entry in self.entries.iter().rev() {
1393 match &entry {
1394 AgentThreadEntry::ToolCall(call) => {
1395 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1396 first_tool_call = Some(call);
1397 } else {
1398 continue;
1399 }
1400 }
1401 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1402 // Reached the beginning of the turn.
1403 // If we had pending permission requests in the previous turn, they have been cancelled.
1404 break;
1405 }
1406 }
1407 }
1408
1409 first_tool_call
1410 }
1411
1412 pub fn plan(&self) -> &Plan {
1413 &self.plan
1414 }
1415
1416 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1417 let new_entries_len = request.entries.len();
1418 let mut new_entries = request.entries.into_iter();
1419
1420 // Reuse existing markdown to prevent flickering
1421 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1422 let PlanEntry {
1423 content,
1424 priority,
1425 status,
1426 } = old;
1427 content.update(cx, |old, cx| {
1428 old.replace(new.content, cx);
1429 });
1430 *priority = new.priority;
1431 *status = new.status;
1432 }
1433 for new in new_entries {
1434 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1435 }
1436 self.plan.entries.truncate(new_entries_len);
1437
1438 cx.notify();
1439 }
1440
1441 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1442 self.plan
1443 .entries
1444 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1445 cx.notify();
1446 }
1447
1448 #[cfg(any(test, feature = "test-support"))]
1449 pub fn send_raw(
1450 &mut self,
1451 message: &str,
1452 cx: &mut Context<Self>,
1453 ) -> BoxFuture<'static, Result<()>> {
1454 self.send(
1455 vec![acp::ContentBlock::Text(acp::TextContent {
1456 text: message.to_string(),
1457 annotations: None,
1458 meta: None,
1459 })],
1460 cx,
1461 )
1462 }
1463
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended as a user entry (with a fresh id only when the connection
    /// supports truncation). A git checkpoint of the working tree is taken and attached to
    /// that entry, and then the prompt is forwarded to the connection. The returned future
    /// resolves when the turn ends (see `run_turn`).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
            meta: None,
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only needed to rewind/truncate, so only assign one when supported.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A checkpoint failure is logged and the turn proceeds without one.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` sees the tree actually changed.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1518
1519 pub fn can_resume(&self, cx: &App) -> bool {
1520 self.connection.resume(&self.session_id, cx).is_some()
1521 }
1522
1523 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1524 self.run_turn(cx, async move |this, cx| {
1525 this.update(cx, |this, cx| {
1526 this.connection
1527 .resume(&this.session_id, cx)
1528 .map(|resume| resume.run(cx))
1529 })?
1530 .context("resuming a session is not supported")?
1531 .await
1532 })
1533 }
1534
1535 fn run_turn(
1536 &mut self,
1537 cx: &mut Context<Self>,
1538 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1539 ) -> BoxFuture<'static, Result<()>> {
1540 self.clear_completed_plan_entries(cx);
1541
1542 let (tx, rx) = oneshot::channel();
1543 let cancel_task = self.cancel(cx);
1544
1545 self.send_task = Some(cx.spawn(async move |this, cx| {
1546 cancel_task.await;
1547 tx.send(f(this, cx).await).ok();
1548 }));
1549
1550 cx.spawn(async move |this, cx| {
1551 let response = rx.await;
1552
1553 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1554 .await?;
1555
1556 this.update(cx, |this, cx| {
1557 this.project
1558 .update(cx, |project, cx| project.set_agent_location(None, cx));
1559 match response {
1560 Ok(Err(e)) => {
1561 this.send_task.take();
1562 cx.emit(AcpThreadEvent::Error);
1563 Err(e)
1564 }
1565 result => {
1566 let canceled = matches!(
1567 result,
1568 Ok(Ok(acp::PromptResponse {
1569 stop_reason: acp::StopReason::Cancelled,
1570 meta: None,
1571 }))
1572 );
1573
1574 // We only take the task if the current prompt wasn't canceled.
1575 //
1576 // This prompt may have been canceled because another one was sent
1577 // while it was still generating. In these cases, dropping `send_task`
1578 // would cause the next generation to be canceled.
1579 if !canceled {
1580 this.send_task.take();
1581 }
1582
1583 // Handle refusal - distinguish between user prompt and tool call refusals
1584 if let Ok(Ok(acp::PromptResponse {
1585 stop_reason: acp::StopReason::Refusal,
1586 meta: _,
1587 })) = result
1588 {
1589 if let Some((user_msg_ix, _)) = this.last_user_message() {
1590 // Check if there's a completed tool call with results after the last user message
1591 // This indicates the refusal is in response to tool output, not the user's prompt
1592 let has_completed_tool_call_after_user_msg =
1593 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1594 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1595 // Check if the tool call has completed and has output
1596 matches!(tool_call.status, ToolCallStatus::Completed)
1597 && tool_call.raw_output.is_some()
1598 } else {
1599 false
1600 }
1601 });
1602
1603 if has_completed_tool_call_after_user_msg {
1604 // Refusal is due to tool output - don't truncate, just notify
1605 // The model refused based on what the tool returned
1606 cx.emit(AcpThreadEvent::Refusal);
1607 } else {
1608 // User prompt was refused - truncate back to before the user message
1609 let range = user_msg_ix..this.entries.len();
1610 if range.start < range.end {
1611 this.entries.truncate(user_msg_ix);
1612 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1613 }
1614 cx.emit(AcpThreadEvent::Refusal);
1615 }
1616 } else {
1617 // No user message found, treat as general refusal
1618 cx.emit(AcpThreadEvent::Refusal);
1619 }
1620 }
1621
1622 cx.emit(AcpThreadEvent::Stopped);
1623 Ok(())
1624 }
1625 }
1626 })?
1627 })
1628 .boxed()
1629 }
1630
1631 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1632 let Some(send_task) = self.send_task.take() else {
1633 return Task::ready(());
1634 };
1635
1636 for entry in self.entries.iter_mut() {
1637 if let AgentThreadEntry::ToolCall(call) = entry {
1638 let cancel = matches!(
1639 call.status,
1640 ToolCallStatus::Pending
1641 | ToolCallStatus::WaitingForConfirmation { .. }
1642 | ToolCallStatus::InProgress
1643 );
1644
1645 if cancel {
1646 call.status = ToolCallStatus::Canceled;
1647 }
1648 }
1649 }
1650
1651 self.connection.cancel(&self.session_id, cx);
1652
1653 // Wait for the send task to complete
1654 cx.foreground_executor().spawn(send_task)
1655 }
1656
1657 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1658 pub fn restore_checkpoint(
1659 &mut self,
1660 id: UserMessageId,
1661 cx: &mut Context<Self>,
1662 ) -> Task<Result<()>> {
1663 let Some((_, message)) = self.user_message_mut(&id) else {
1664 return Task::ready(Err(anyhow!("message not found")));
1665 };
1666
1667 let checkpoint = message
1668 .checkpoint
1669 .as_ref()
1670 .map(|c| c.git_checkpoint.clone());
1671 let rewind = self.rewind(id.clone(), cx);
1672 let git_store = self.project.read(cx).git_store().clone();
1673
1674 cx.spawn(async move |_, cx| {
1675 rewind.await?;
1676 if let Some(checkpoint) = checkpoint {
1677 git_store
1678 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1679 .await?;
1680 }
1681
1682 Ok(())
1683 })
1684 }
1685
1686 /// Rewinds this thread to before the entry at `index`, removing it and all
1687 /// subsequent entries while rejecting any action_log changes made from that point.
1688 /// Unlike `restore_checkpoint`, this method does not restore from git.
1689 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1690 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1691 return Task::ready(Err(anyhow!("not supported")));
1692 };
1693
1694 cx.spawn(async move |this, cx| {
1695 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1696 this.update(cx, |this, cx| {
1697 if let Some((ix, _)) = this.user_message_mut(&id) {
1698 let range = ix..this.entries.len();
1699 this.entries.truncate(ix);
1700 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1701 }
1702 this.action_log()
1703 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1704 })?
1705 .await;
1706 Ok(())
1707 })
1708 }
1709
1710 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1711 let git_store = self.project.read(cx).git_store().clone();
1712
1713 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1714 if let Some(checkpoint) = message.checkpoint.as_ref() {
1715 checkpoint.git_checkpoint.clone()
1716 } else {
1717 return Task::ready(Ok(()));
1718 }
1719 } else {
1720 return Task::ready(Ok(()));
1721 };
1722
1723 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1724 cx.spawn(async move |this, cx| {
1725 let new_checkpoint = new_checkpoint
1726 .await
1727 .context("failed to get new checkpoint")
1728 .log_err();
1729 if let Some(new_checkpoint) = new_checkpoint {
1730 let equal = git_store
1731 .update(cx, |git, cx| {
1732 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1733 })?
1734 .await
1735 .unwrap_or(true);
1736 this.update(cx, |this, cx| {
1737 let (ix, message) = this.last_user_message().context("no user message")?;
1738 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1739 checkpoint.show = !equal;
1740 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1741 anyhow::Ok(())
1742 })??;
1743 }
1744
1745 Ok(())
1746 })
1747 }
1748
1749 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1750 self.entries
1751 .iter_mut()
1752 .enumerate()
1753 .rev()
1754 .find_map(|(ix, entry)| {
1755 if let AgentThreadEntry::UserMessage(message) = entry {
1756 Some((ix, message))
1757 } else {
1758 None
1759 }
1760 })
1761 }
1762
1763 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1764 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1765 if let AgentThreadEntry::UserMessage(message) = entry {
1766 if message.id.as_ref() == Some(id) {
1767 Some((ix, message))
1768 } else {
1769 None
1770 }
1771 } else {
1772 None
1773 }
1774 })
1775 }
1776
    /// Reads up to `limit` lines of the file at the absolute `path`, starting at the
    /// 1-based `line`.
    ///
    /// `line` defaults to the first line (`Some(0)` is also treated as the first line), and
    /// `limit` defaults to no limit. When `reuse_shared_snapshot` is true and a snapshot of
    /// the buffer was previously shared, that snapshot is read. Otherwise the read is
    /// reported to the action log and the buffer's current snapshot is recorded as shared.
    /// The project's agent location is moved to the start of the range read.
    ///
    /// # Errors
    ///
    /// `resource_not_found` if `path` is not in the project; `invalid_params` if `line` is
    /// past the end of the file. Errors from opening the buffer, or from a released
    /// thread/project, are propagated.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                // Record what the agent has seen, so later writes can diff against it.
                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().with_data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            let start = snapshot.anchor_before(start_position);
            // `saturating_add` keeps the default `u32::MAX` limit from overflowing.
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
1856
    /// Replaces the contents of the file at the absolute `path` with `content` and saves it.
    ///
    /// `content` is diffed against the snapshot previously shared with the agent (or the
    /// buffer's current snapshot if none was shared). Only the differing ranges are edited,
    /// using anchors from that snapshot, so user edits made in between are kept (see
    /// `test_edits_concurrently_to_user`). The agent location is moved to the end of the
    /// last edit, and the read and edit are reported to the action log. The buffer is
    /// formatted if format-on-save is not `Off` for it, and then saved.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" if `path` is not in the project. Errors from opening or
    /// saving the buffer, or from a released thread/project, are propagated. Formatting
    /// errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, not necessarily the current buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // The text diff runs off the main thread; the resulting anchors stay valid
            // against the live buffer even if it changed meanwhile.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Report the formatter's changes as agent edits too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1953
1954 pub fn create_terminal(
1955 &self,
1956 command: String,
1957 args: Vec<String>,
1958 extra_env: Vec<acp::EnvVariable>,
1959 cwd: Option<PathBuf>,
1960 output_byte_limit: Option<u64>,
1961 cx: &mut Context<Self>,
1962 ) -> Task<Result<Entity<Terminal>>> {
1963 let env = match &cwd {
1964 Some(dir) => self.project.update(cx, |project, cx| {
1965 let shell = TerminalSettings::get_global(cx).shell.clone();
1966 project.directory_environment(&shell, dir.as_path().into(), cx)
1967 }),
1968 None => Task::ready(None).shared(),
1969 };
1970 let env = cx.spawn(async move |_, _| {
1971 let mut env = env.await.unwrap_or_default();
1972 // Disables paging for `git` and hopefully other commands
1973 env.insert("PAGER".into(), "".into());
1974 for var in extra_env {
1975 env.insert(var.name, var.value);
1976 }
1977 env
1978 });
1979
1980 let project = self.project.clone();
1981 let language_registry = project.read(cx).languages().clone();
1982
1983 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1984 let terminal_task = cx.spawn({
1985 let terminal_id = terminal_id.clone();
1986 async move |_this, cx| {
1987 let env = env.await;
1988 let (task_command, task_args) = ShellBuilder::new(
1989 project
1990 .update(cx, |project, cx| {
1991 project
1992 .remote_client()
1993 .and_then(|r| r.read(cx).default_system_shell())
1994 })?
1995 .as_deref(),
1996 &Shell::Program(get_default_system_shell()),
1997 )
1998 .redirect_stdin_to_dev_null()
1999 .build(Some(command.clone()), &args);
2000 let terminal = project
2001 .update(cx, |project, cx| {
2002 project.create_terminal_task(
2003 task::SpawnInTerminal {
2004 command: Some(task_command),
2005 args: task_args,
2006 cwd: cwd.clone(),
2007 env,
2008 ..Default::default()
2009 },
2010 cx,
2011 )
2012 })?
2013 .await?;
2014
2015 cx.new(|cx| {
2016 Terminal::new(
2017 terminal_id,
2018 &format!("{} {}", command, args.join(" ")),
2019 cwd,
2020 output_byte_limit.map(|l| l as usize),
2021 terminal,
2022 language_registry,
2023 cx,
2024 )
2025 })
2026 }
2027 });
2028
2029 cx.spawn(async move |this, cx| {
2030 let terminal = terminal_task.await?;
2031 this.update(cx, |this, _cx| {
2032 this.terminals.insert(terminal_id, terminal.clone());
2033 terminal
2034 })
2035 })
2036 }
2037
2038 pub fn kill_terminal(
2039 &mut self,
2040 terminal_id: acp::TerminalId,
2041 cx: &mut Context<Self>,
2042 ) -> Result<()> {
2043 self.terminals
2044 .get(&terminal_id)
2045 .context("Terminal not found")?
2046 .update(cx, |terminal, cx| {
2047 terminal.kill(cx);
2048 });
2049
2050 Ok(())
2051 }
2052
2053 pub fn release_terminal(
2054 &mut self,
2055 terminal_id: acp::TerminalId,
2056 cx: &mut Context<Self>,
2057 ) -> Result<()> {
2058 self.terminals
2059 .remove(&terminal_id)
2060 .context("Terminal not found")?
2061 .update(cx, |terminal, cx| {
2062 terminal.kill(cx);
2063 });
2064
2065 Ok(())
2066 }
2067
2068 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2069 self.terminals
2070 .get(&terminal_id)
2071 .context("Terminal not found")
2072 .cloned()
2073 }
2074
2075 pub fn to_markdown(&self, cx: &App) -> String {
2076 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2077 }
2078
2079 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2080 cx.emit(AcpThreadEvent::LoadError(error));
2081 }
2082}
2083
2084fn markdown_for_raw_output(
2085 raw_output: &serde_json::Value,
2086 language_registry: &Arc<LanguageRegistry>,
2087 cx: &mut App,
2088) -> Option<Entity<Markdown>> {
2089 match raw_output {
2090 serde_json::Value::Null => None,
2091 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2092 Markdown::new(
2093 value.to_string().into(),
2094 Some(language_registry.clone()),
2095 None,
2096 cx,
2097 )
2098 })),
2099 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2100 Markdown::new(
2101 value.to_string().into(),
2102 Some(language_registry.clone()),
2103 None,
2104 cx,
2105 )
2106 })),
2107 serde_json::Value::String(value) => Some(cx.new(|cx| {
2108 Markdown::new(
2109 value.clone().into(),
2110 Some(language_registry.clone()),
2111 None,
2112 cx,
2113 )
2114 })),
2115 value => Some(cx.new(|cx| {
2116 Markdown::new(
2117 format!("```json\n{}\n```", value).into(),
2118 Some(language_registry.clone()),
2119 None,
2120 cx,
2121 )
2122 })),
2123 }
2124}
2125
2126#[cfg(test)]
2127mod tests {
2128 use super::*;
2129 use anyhow::anyhow;
2130 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2131 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2132 use indoc::indoc;
2133 use project::{FakeFs, Fs};
2134 use rand::{distr, prelude::*};
2135 use serde_json::json;
2136 use settings::SettingsStore;
2137 use smol::stream::StreamExt as _;
2138 use std::{
2139 any::Any,
2140 cell::RefCell,
2141 path::Path,
2142 rc::Rc,
2143 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2144 time::Duration,
2145 };
2146 use util::path;
2147
2148 fn init_test(cx: &mut TestAppContext) {
2149 env_logger::try_init().ok();
2150 cx.update(|cx| {
2151 let settings_store = SettingsStore::test(cx);
2152 cx.set_global(settings_store);
2153 Project::init_settings(cx);
2154 language::init(cx);
2155 });
2156 }
2157
    /// Checks how `push_user_content_block` groups chunks into user messages.
    ///
    /// Consecutive chunks merge into one message (adopting a later id). A chunk arriving
    /// after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Still one entry: the chunk was appended, and the message took on the new id.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                    meta: None,
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2253
    /// Checks that consecutive `AgentThoughtChunk` updates are concatenated into a single
    /// `<thinking>` block in the assistant message's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks and then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2317
    /// Checks that an agent write based on an earlier read keeps an edit the user made to
    /// the buffer after that read, both in the buffer and in the saved file.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent has finished reading, so the "user" can edit.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Written relative to the content read above, without the user's "zero".
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The "user" edits the buffer after the agent's read but before its write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's and the agent's edits survive, in memory and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2397
2398 #[gpui::test]
2399 async fn test_reading_from_line(cx: &mut TestAppContext) {
2400 init_test(cx);
2401
2402 let fs = FakeFs::new(cx.executor());
2403 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2404 .await;
2405 let project = Project::test(fs.clone(), [], cx).await;
2406 project
2407 .update(cx, |project, cx| {
2408 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2409 })
2410 .await
2411 .unwrap();
2412
2413 let connection = Rc::new(FakeAgentConnection::new());
2414
2415 let thread = cx
2416 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2417 .await
2418 .unwrap();
2419
2420 // Whole file
2421 let content = thread
2422 .update(cx, |thread, cx| {
2423 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2424 })
2425 .await
2426 .unwrap();
2427
2428 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2429
2430 // Only start line
2431 let content = thread
2432 .update(cx, |thread, cx| {
2433 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2434 })
2435 .await
2436 .unwrap();
2437
2438 assert_eq!(content, "three\nfour\n");
2439
2440 // Only limit
2441 let content = thread
2442 .update(cx, |thread, cx| {
2443 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2444 })
2445 .await
2446 .unwrap();
2447
2448 assert_eq!(content, "one\ntwo\n");
2449
2450 // Range
2451 let content = thread
2452 .update(cx, |thread, cx| {
2453 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2454 })
2455 .await
2456 .unwrap();
2457
2458 assert_eq!(content, "two\nthree\n");
2459
2460 // Invalid
2461 let err = thread
2462 .update(cx, |thread, cx| {
2463 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2464 })
2465 .await
2466 .unwrap_err();
2467
2468 assert_eq!(
2469 err.to_string(),
2470 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2471 );
2472 }
2473
2474 #[gpui::test]
2475 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2476 init_test(cx);
2477
2478 let fs = FakeFs::new(cx.executor());
2479 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2480 let project = Project::test(fs.clone(), [], cx).await;
2481 project
2482 .update(cx, |project, cx| {
2483 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2484 })
2485 .await
2486 .unwrap();
2487
2488 let connection = Rc::new(FakeAgentConnection::new());
2489
2490 let thread = cx
2491 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2492 .await
2493 .unwrap();
2494
2495 // Whole file
2496 let content = thread
2497 .update(cx, |thread, cx| {
2498 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2499 })
2500 .await
2501 .unwrap();
2502
2503 assert_eq!(content, "");
2504
2505 // Only start line
2506 let content = thread
2507 .update(cx, |thread, cx| {
2508 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2509 })
2510 .await
2511 .unwrap();
2512
2513 assert_eq!(content, "");
2514
2515 // Only limit
2516 let content = thread
2517 .update(cx, |thread, cx| {
2518 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2519 })
2520 .await
2521 .unwrap();
2522
2523 assert_eq!(content, "");
2524
2525 // Range
2526 let content = thread
2527 .update(cx, |thread, cx| {
2528 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2529 })
2530 .await
2531 .unwrap();
2532
2533 assert_eq!(content, "");
2534
2535 // Invalid
2536 let err = thread
2537 .update(cx, |thread, cx| {
2538 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2539 })
2540 .await
2541 .unwrap_err();
2542
2543 assert_eq!(
2544 err.to_string(),
2545 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2546 );
2547 }
2548 #[gpui::test]
2549 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2550 init_test(cx);
2551
2552 let fs = FakeFs::new(cx.executor());
2553 fs.insert_tree(path!("/tmp"), json!({})).await;
2554 let project = Project::test(fs.clone(), [], cx).await;
2555 project
2556 .update(cx, |project, cx| {
2557 project.find_or_create_worktree(path!("/tmp"), true, cx)
2558 })
2559 .await
2560 .unwrap();
2561
2562 let connection = Rc::new(FakeAgentConnection::new());
2563
2564 let thread = cx
2565 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2566 .await
2567 .unwrap();
2568
2569 // Out of project file
2570 let err = thread
2571 .update(cx, |thread, cx| {
2572 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2573 })
2574 .await
2575 .unwrap_err();
2576
2577 assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2578 }
2579
    /// Checks that a tool call which was marked `Canceled` by `cancel` can still be
    /// moved to `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // Fake agent: report a single in-progress tool call with id `id`, then end the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call, still in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Cancelling the thread marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancel; the update must still be accepted.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                        meta: None,
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2682
2683 #[gpui::test]
2684 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2685 init_test(cx);
2686 let fs = FakeFs::new(cx.background_executor.clone());
2687 fs.insert_tree(path!("/test"), json!({})).await;
2688 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2689
2690 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2691 move |_, thread, mut cx| {
2692 async move {
2693 thread
2694 .update(&mut cx, |thread, cx| {
2695 thread.handle_session_update(
2696 acp::SessionUpdate::ToolCall(acp::ToolCall {
2697 id: acp::ToolCallId("test".into()),
2698 title: "Label".into(),
2699 kind: acp::ToolKind::Edit,
2700 status: acp::ToolCallStatus::Completed,
2701 content: vec![acp::ToolCallContent::Diff {
2702 diff: acp::Diff {
2703 path: "/test/test.txt".into(),
2704 old_text: None,
2705 new_text: "foo".into(),
2706 meta: None,
2707 },
2708 }],
2709 locations: vec![],
2710 raw_input: None,
2711 raw_output: None,
2712 meta: None,
2713 }),
2714 cx,
2715 )
2716 })
2717 .unwrap()
2718 .unwrap();
2719 Ok(acp::PromptResponse {
2720 stop_reason: acp::StopReason::EndTurn,
2721 meta: None,
2722 })
2723 }
2724 .boxed_local()
2725 }
2726 }));
2727
2728 let thread = cx
2729 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2730 .await
2731 .unwrap();
2732
2733 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2734 .await
2735 .unwrap();
2736
2737 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2738 }
2739
    /// Checks that a user message gets a visible git checkpoint ("## User (checkpoint)")
    /// only when the agent's turn changed files. Also checks that restoring a checkpoint
    /// truncates the thread from that message on and removes the files created after it.
    ///
    /// While `simulate_changes` is set, the fake agent creates `/test/file-N` on every
    /// prompt, where N is taken from `next_filename`. It always replies with the prompt
    /// text upper-cased.
    // NOTE(review): `iterations = 10` reruns the test, presumably to exercise different
    // task-scheduling orders. TODO confirm against gpui's test macro docs.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // An empty `.git` directory, so the project has a git repository to checkpoint.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Touch the working tree so this turn has changes to checkpoint.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the first prompt block back upper-cased as the agent's reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Turn 1 creates file-0, so the message is shown with a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Turn 2 creates file-1, so this message also gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message; restoring it drops that turn and every later
        // one, and removes file-1, which was created during the "ipsum" turn.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2920
    /// The agent reports a completed tool call and then refuses in the same turn.
    /// Checks that a `Refusal` event is emitted and that the user message and the tool
    /// call survive. Contrast `test_user_prompt_refusal_emits_event`, where a bare
    /// refusal removes the user message.
    #[gpui::test]
    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
        use std::sync::atomic::AtomicUsize;
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None, cx).await;

        // Create a connection that simulates a refusal after a tool result.
        // Only the first prompt (count == 0) produces the tool call and the refusal;
        // any later prompt simply ends the turn.
        let prompt_count = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let prompt_count = prompt_count.clone();
            move |_request, thread, mut cx| {
                let count = prompt_count.fetch_add(1, SeqCst);
                async move {
                    if count == 0 {
                        // First prompt: generate a completed tool call with output.
                        thread.update(&mut cx, |thread, cx| {
                            thread
                                .handle_session_update(
                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
                                        id: acp::ToolCallId("tool1".into()),
                                        title: "Test Tool".into(),
                                        kind: acp::ToolKind::Fetch,
                                        status: acp::ToolCallStatus::Completed,
                                        content: vec![],
                                        locations: vec![],
                                        raw_input: Some(serde_json::json!({"query": "test"})),
                                        raw_output: Some(
                                            serde_json::json!({"result": "inappropriate content"}),
                                        ),
                                        meta: None,
                                    }),
                                    cx,
                                )
                                .unwrap();
                        })?;

                        // Then end the turn with a refusal, after the tool result.
                        Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            meta: None,
                        })
                    } else {
                        Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::EndTurn,
                            meta: None,
                        })
                    }
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Record whether the thread ever emits a Refusal event.
        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
        let saw_refusal_event_captured = saw_refusal_event.clone();
        thread.update(cx, |_thread, cx| {
            cx.subscribe(
                &thread,
                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
                    if matches!(event, AcpThreadEvent::Refusal) {
                        *saw_refusal_event_captured.lock().unwrap() = true;
                    }
                },
            )
            .detach();
        });

        // Send a user message. This triggers the tool call and then the refusal.
        // The send future is detached and driven by `run_until_parked`, so its result is
        // never checked here.
        let send_task = thread.update(cx, |thread, cx| {
            thread.send(
                vec![acp::ContentBlock::Text(acp::TextContent {
                    text: "Hello".into(),
                    annotations: None,
                    meta: None,
                })],
                cx,
            )
        });
        cx.background_executor.spawn(send_task).detach();
        cx.run_until_parked();

        // Verify that:
        // 1. A Refusal event was emitted.
        // 2. The user message and the tool call that preceded the refusal are still present
        //    (not truncated).
        assert!(
            *saw_refusal_event.lock().unwrap(),
            "Refusal event should be emitted for tool result refusals"
        );

        thread.read_with(cx, |thread, _| {
            let entries = thread.entries();
            assert!(entries.len() >= 2, "Should have user message and tool call");

            // Verify the user message is still there.
            assert!(
                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
                "User message should not be truncated"
            );

            // Verify the tool call is still there and kept its output.
            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
                assert!(
                    tool_call.raw_output.is_some(),
                    "Tool call should have output"
                );
            } else {
                panic!("Expected tool call at index 1");
            }
        });
    }
3038
3039 #[gpui::test]
3040 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3041 init_test(cx);
3042
3043 let fs = FakeFs::new(cx.executor());
3044 let project = Project::test(fs, None, cx).await;
3045
3046 let refuse_next = Arc::new(AtomicBool::new(false));
3047 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3048 let refuse_next = refuse_next.clone();
3049 move |_request, _thread, _cx| {
3050 if refuse_next.load(SeqCst) {
3051 async move {
3052 Ok(acp::PromptResponse {
3053 stop_reason: acp::StopReason::Refusal,
3054 meta: None,
3055 })
3056 }
3057 .boxed_local()
3058 } else {
3059 async move {
3060 Ok(acp::PromptResponse {
3061 stop_reason: acp::StopReason::EndTurn,
3062 meta: None,
3063 })
3064 }
3065 .boxed_local()
3066 }
3067 }
3068 }));
3069
3070 let thread = cx
3071 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3072 .await
3073 .unwrap();
3074
3075 // Track if we see a Refusal event
3076 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3077 let saw_refusal_event_captured = saw_refusal_event.clone();
3078 thread.update(cx, |_thread, cx| {
3079 cx.subscribe(
3080 &thread,
3081 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3082 if matches!(event, AcpThreadEvent::Refusal) {
3083 *saw_refusal_event_captured.lock().unwrap() = true;
3084 }
3085 },
3086 )
3087 .detach();
3088 });
3089
3090 // Send a message that will be refused
3091 refuse_next.store(true, SeqCst);
3092 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3093 .await
3094 .unwrap();
3095
3096 // Verify that a Refusal event WAS emitted for user prompt refusal
3097 assert!(
3098 *saw_refusal_event.lock().unwrap(),
3099 "Refusal event should be emitted for user prompt refusals"
3100 );
3101
3102 // Verify the message was truncated (user prompt refusal)
3103 thread.read_with(cx, |thread, cx| {
3104 assert_eq!(thread.to_markdown(cx), "");
3105 });
3106 }
3107
    /// After a normal exchange, a refused follow-up prompt is removed from the thread,
    /// leaving only the earlier exchange in the transcript.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // Fake agent: refuse when `refuse_next` is set. Otherwise echo the first prompt
        // block back upper-cased.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            meta: None,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First prompt is answered normally.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });
    }
3195
3196 async fn run_until_first_tool_call(
3197 thread: &Entity<AcpThread>,
3198 cx: &mut TestAppContext,
3199 ) -> usize {
3200 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3201
3202 let subscription = cx.update(|cx| {
3203 cx.subscribe(thread, move |thread, _, cx| {
3204 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3205 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3206 return tx.try_send(ix).unwrap();
3207 }
3208 }
3209 })
3210 });
3211
3212 select! {
3213 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3214 panic!("Timeout waiting for tool call")
3215 }
3216 ix = rx.next().fuse() => {
3217 drop(subscription);
3218 ix.unwrap()
3219 }
3220 }
3221 }
3222
3223 #[derive(Clone, Default)]
3224 struct FakeAgentConnection {
3225 auth_methods: Vec<acp::AuthMethod>,
3226 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3227 on_user_message: Option<
3228 Rc<
3229 dyn Fn(
3230 acp::PromptRequest,
3231 WeakEntity<AcpThread>,
3232 AsyncApp,
3233 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3234 + 'static,
3235 >,
3236 >,
3237 }
3238
3239 impl FakeAgentConnection {
3240 fn new() -> Self {
3241 Self {
3242 auth_methods: Vec::new(),
3243 on_user_message: None,
3244 sessions: Arc::default(),
3245 }
3246 }
3247
3248 #[expect(unused)]
3249 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3250 self.auth_methods = auth_methods;
3251 self
3252 }
3253
3254 fn on_user_message(
3255 mut self,
3256 handler: impl Fn(
3257 acp::PromptRequest,
3258 WeakEntity<AcpThread>,
3259 AsyncApp,
3260 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3261 + 'static,
3262 ) -> Self {
3263 self.on_user_message.replace(Rc::new(handler));
3264 self
3265 }
3266 }
3267
3268 impl AgentConnection for FakeAgentConnection {
3269 fn auth_methods(&self) -> &[acp::AuthMethod] {
3270 &self.auth_methods
3271 }
3272
3273 fn new_thread(
3274 self: Rc<Self>,
3275 project: Entity<Project>,
3276 _cwd: &Path,
3277 cx: &mut App,
3278 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3279 let session_id = acp::SessionId(
3280 rand::rng()
3281 .sample_iter(&distr::Alphanumeric)
3282 .take(7)
3283 .map(char::from)
3284 .collect::<String>()
3285 .into(),
3286 );
3287 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3288 let thread = cx.new(|cx| {
3289 AcpThread::new(
3290 "Test",
3291 self.clone(),
3292 project,
3293 action_log,
3294 session_id.clone(),
3295 watch::Receiver::constant(acp::PromptCapabilities {
3296 image: true,
3297 audio: true,
3298 embedded_context: true,
3299 meta: None,
3300 }),
3301 cx,
3302 )
3303 });
3304 self.sessions.lock().insert(session_id, thread.downgrade());
3305 Task::ready(Ok(thread))
3306 }
3307
3308 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3309 if self.auth_methods().iter().any(|m| m.id == method) {
3310 Task::ready(Ok(()))
3311 } else {
3312 Task::ready(Err(anyhow!("Invalid Auth Method")))
3313 }
3314 }
3315
3316 fn prompt(
3317 &self,
3318 _id: Option<UserMessageId>,
3319 params: acp::PromptRequest,
3320 cx: &mut App,
3321 ) -> Task<gpui::Result<acp::PromptResponse>> {
3322 let sessions = self.sessions.lock();
3323 let thread = sessions.get(¶ms.session_id).unwrap();
3324 if let Some(handler) = &self.on_user_message {
3325 let handler = handler.clone();
3326 let thread = thread.clone();
3327 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3328 } else {
3329 Task::ready(Ok(acp::PromptResponse {
3330 stop_reason: acp::StopReason::EndTurn,
3331 meta: None,
3332 }))
3333 }
3334 }
3335
3336 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3337 let sessions = self.sessions.lock();
3338 let thread = sessions.get(session_id).unwrap().clone();
3339
3340 cx.spawn(async move |cx| {
3341 thread
3342 .update(cx, |thread, cx| thread.cancel(cx))
3343 .unwrap()
3344 .await
3345 })
3346 .detach();
3347 }
3348
3349 fn truncate(
3350 &self,
3351 session_id: &acp::SessionId,
3352 _cx: &App,
3353 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3354 Some(Rc::new(FakeAgentSessionEditor {
3355 _session_id: session_id.clone(),
3356 }))
3357 }
3358
3359 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3360 self
3361 }
3362 }
3363
    /// Truncation handle returned by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        // Session the handle was created for; stored but never read here.
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// No-op: reports success immediately without touching any agent-side state.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3373
3374 #[gpui::test]
3375 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3376 init_test(cx);
3377
3378 let fs = FakeFs::new(cx.executor());
3379 let project = Project::test(fs, [], cx).await;
3380 let connection = Rc::new(FakeAgentConnection::new());
3381 let thread = cx
3382 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3383 .await
3384 .unwrap();
3385
3386 // Try to update a tool call that doesn't exist
3387 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3388 thread.update(cx, |thread, cx| {
3389 let result = thread.handle_session_update(
3390 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3391 id: nonexistent_id.clone(),
3392 fields: acp::ToolCallUpdateFields {
3393 status: Some(acp::ToolCallStatus::Completed),
3394 ..Default::default()
3395 },
3396 meta: None,
3397 }),
3398 cx,
3399 );
3400
3401 // The update should succeed (not return an error)
3402 assert!(result.is_ok());
3403
3404 // There should now be exactly one entry in the thread
3405 assert_eq!(thread.entries.len(), 1);
3406
3407 // The entry should be a failed tool call
3408 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3409 assert_eq!(tool_call.id, nonexistent_id);
3410 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3411 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3412
3413 // Check that the content contains the error message
3414 assert_eq!(tool_call.content.len(), 1);
3415 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3416 match content_block {
3417 ContentBlock::Markdown { markdown } => {
3418 let markdown_text = markdown.read(cx).source();
3419 assert!(markdown_text.contains("Tool call not found"));
3420 }
3421 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3422 ContentBlock::ResourceLink { .. } => {
3423 panic!("Expected markdown content, got resource link")
3424 }
3425 }
3426 } else {
3427 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3428 }
3429 } else {
3430 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3431 }
3432 });
3433 }
3434}