1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Key in an ACP `ToolCall`'s `meta` map that holds the tool's programmatic name.
///
/// This is a workaround: ACP's `ToolCall` has no dedicated name field. Read the value back
/// with [`tool_name_from_meta`]; build a meta map containing it with [`meta_with_tool_name`].
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Key in an ACP `ToolCall`'s `meta` map that holds the session id of a spawned subagent.
///
/// Read the value back with [`subagent_session_id_from_meta`].
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A prompt sent by the user, stored as an [`AgentThreadEntry::UserMessage`].
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of this message, when one was assigned.
    pub id: Option<UserMessageId>,
    /// Rendered content; this is what `to_markdown` serializes.
    pub content: ContentBlock,
    /// Raw ACP content blocks of the prompt (presumably the ones `content` was built from —
    /// confirm at the construction site).
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any; when its `show` flag is set
    /// the markdown heading becomes `## User (checkpoint)`.
    pub checkpoint: Option<Checkpoint>,
    /// Reported by `AgentThreadEntry::is_indented`.
    pub indented: bool,
}
75
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// Project git state captured for this checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; `UserMessage::to_markdown` marks the heading
    /// with `(checkpoint)` when this is `true`.
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// A message from the agent, stored as an [`AgentThreadEntry::AssistantMessage`].
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in order; `to_markdown` joins them with blank lines.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Reported by `AgentThreadEntry::is_indented`.
    pub indented: bool,
}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text.
    Message { block: ContentBlock },
    /// Agent "thinking" output; `to_markdown` wraps it in `<thinking>` tags.
    Thought { block: ContentBlock },
}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
/// One entry in an [`AcpThread`]'s conversation.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A prompt sent by the user.
    UserMessage(UserMessage),
    /// A message from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool call reported by the agent.
    ToolCall(ToolCall),
}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
/// A tool call reported by the agent, together with the UI state derived from it.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier of this tool call.
    pub id: acp::ToolCallId,
    /// Display title. For every kind except `Execute`, only the first line of the ACP
    /// title is kept, followed by `…`.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content items, in order.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported for the tool call, as sent by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions matched to `locations` by index (see
    /// `AgentThreadEntry::location`). `None` marks a location that could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw tool input as sent by the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, when one could be produced.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw tool output as sent by the agent.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from the tool call's meta (`TOOL_NAME_META_KEY`).
    pub tool_name: Option<SharedString>,
    /// Session id of a spawned subagent, read from meta (`SUBAGENT_SESSION_ID_META_KEY`).
    pub subagent_session_id: Option<acp::SessionId>,
}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 self.label.update(cx, |label, cx| {
318 if self.kind == acp::ToolKind::Execute {
319 label.replace(title, cx);
320 } else if let Some((first_line, _)) = title.split_once("\n") {
321 label.replace(first_line.to_owned() + "…", cx);
322 } else {
323 label.replace(title, cx);
324 }
325 });
326 }
327
328 if let Some(content) = content {
329 let mut new_content_len = content.len();
330 let mut content = content.into_iter();
331
332 // Reuse existing content if we can
333 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
334 let valid_content =
335 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
336 if !valid_content {
337 new_content_len -= 1;
338 }
339 }
340 for new in content {
341 if let Some(new) = ToolCallContent::from_acp(
342 new,
343 language_registry.clone(),
344 path_style,
345 terminals,
346 cx,
347 )? {
348 self.content.push(new);
349 } else {
350 new_content_len -= 1;
351 }
352 }
353 self.content.truncate(new_content_len);
354 }
355
356 if let Some(locations) = locations {
357 self.locations = locations;
358 }
359
360 if let Some(raw_input) = raw_input {
361 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
362 self.raw_input = Some(raw_input);
363 }
364
365 if let Some(raw_output) = raw_output {
366 if self.content.is_empty()
367 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
368 {
369 self.content
370 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
371 markdown,
372 }));
373 }
374 self.raw_output = Some(raw_output);
375 }
376 Ok(())
377 }
378
379 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
380 self.content.iter().filter_map(|content| match content {
381 ToolCallContent::Diff(diff) => Some(diff),
382 ToolCallContent::ContentBlock(_) => None,
383 ToolCallContent::Terminal(_) => None,
384 })
385 }
386
387 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
388 self.content.iter().filter_map(|content| match content {
389 ToolCallContent::Terminal(terminal) => Some(terminal),
390 ToolCallContent::ContentBlock(_) => None,
391 ToolCallContent::Diff(_) => None,
392 })
393 }
394
395 pub fn is_subagent(&self) -> bool {
396 self.tool_name.as_ref().is_some_and(|s| s == "subagent")
397 || self.subagent_session_id.is_some()
398 }
399
400 pub fn to_markdown(&self, cx: &App) -> String {
401 let mut markdown = format!(
402 "**Tool Call: {}**\nStatus: {}\n\n",
403 self.label.read(cx).source(),
404 self.status
405 );
406 for content in &self.content {
407 markdown.push_str(content.to_markdown(cx).as_str());
408 markdown.push_str("\n\n");
409 }
410 markdown
411 }
412
413 async fn resolve_location(
414 location: acp::ToolCallLocation,
415 project: WeakEntity<Project>,
416 cx: &mut AsyncApp,
417 ) -> Option<ResolvedLocation> {
418 let buffer = project
419 .update(cx, |project, cx| {
420 project
421 .project_path_for_absolute_path(&location.path, cx)
422 .map(|path| project.open_buffer(path, cx))
423 })
424 .ok()??;
425 let buffer = buffer.await.log_err()?;
426 let position = buffer.update(cx, |buffer, _| {
427 let snapshot = buffer.snapshot();
428 if let Some(row) = location.line {
429 let column = snapshot.indent_size_for_line(row).len;
430 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
431 snapshot.anchor_before(point)
432 } else {
433 Anchor::min_for_buffer(snapshot.remote_id())
434 }
435 });
436
437 Some(ResolvedLocation { buffer, position })
438 }
439
440 fn resolve_locations(
441 &self,
442 project: Entity<Project>,
443 cx: &mut App,
444 ) -> Task<Vec<Option<ResolvedLocation>>> {
445 let locations = self.locations.clone();
446 project.update(cx, |_, cx| {
447 cx.spawn(async move |project, cx| {
448 let mut new_locations = Vec::new();
449 for location in locations {
450 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
451 }
452 new_locations
453 })
454 })
455 }
456}
457
/// A tool call location resolved to a buffer position.
///
/// This is kept separate from `AgentLocation` (which only holds a weak handle, see the
/// `From` impl below it) so the thread can hold a strong reference to the buffer, e.g.
/// for saving it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    buffer: Entity<Buffer>,
    position: Anchor,
}
465
466impl From<&ResolvedLocation> for AgentLocation {
467 fn from(value: &ResolvedLocation) -> Self {
468 Self {
469 buffer: value.buffer.downgrade(),
470 position: value.position,
471 }
472 }
473}
474
/// Lifecycle state of a [`ToolCall`] as tracked by the client.
///
/// This extends ACP's tool call status with client-side states such as confirmation,
/// rejection, and cancellation.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but it is already shown to the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        options: PermissionOptions,
        /// Presumably receives the id of the permission option the user picks — confirm
        /// at the send site.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation, so the tool call was canceled.
    Canceled,
}
496
497impl From<acp::ToolCallStatus> for ToolCallStatus {
498 fn from(status: acp::ToolCallStatus) -> Self {
499 match status {
500 acp::ToolCallStatus::Pending => Self::Pending,
501 acp::ToolCallStatus::InProgress => Self::InProgress,
502 acp::ToolCallStatus::Completed => Self::Completed,
503 acp::ToolCallStatus::Failed => Self::Failed,
504 _ => Self::Pending,
505 }
506 }
507}
508
509impl Display for ToolCallStatus {
510 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
511 write!(
512 f,
513 "{}",
514 match self {
515 ToolCallStatus::Pending => "Pending",
516 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
517 ToolCallStatus::InProgress => "In Progress",
518 ToolCallStatus::Completed => "Completed",
519 ToolCallStatus::Failed => "Failed",
520 ToolCallStatus::Rejected => "Rejected",
521 ToolCallStatus::Canceled => "Canceled",
522 }
523 )
524 }
525}
526
/// Rendered form of one or more ACP content blocks.
///
/// A block starts as `Empty`. Appending a resource link to an empty block yields
/// `ResourceLink`, and appending a decodable image yields `Image`. Any other combination
/// is rendered as `Markdown`.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing appended yet.
    Empty,
    /// Markdown text, extended in place by later appends.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A single decoded image.
    Image { image: Arc<gpui::Image> },
}
534
535impl ContentBlock {
536 pub fn new(
537 block: acp::ContentBlock,
538 language_registry: &Arc<LanguageRegistry>,
539 path_style: PathStyle,
540 cx: &mut App,
541 ) -> Self {
542 let mut this = Self::Empty;
543 this.append(block, language_registry, path_style, cx);
544 this
545 }
546
547 pub fn new_combined(
548 blocks: impl IntoIterator<Item = acp::ContentBlock>,
549 language_registry: Arc<LanguageRegistry>,
550 path_style: PathStyle,
551 cx: &mut App,
552 ) -> Self {
553 let mut this = Self::Empty;
554 for block in blocks {
555 this.append(block, &language_registry, path_style, cx);
556 }
557 this
558 }
559
560 pub fn append(
561 &mut self,
562 block: acp::ContentBlock,
563 language_registry: &Arc<LanguageRegistry>,
564 path_style: PathStyle,
565 cx: &mut App,
566 ) {
567 match (&mut *self, &block) {
568 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
569 *self = ContentBlock::ResourceLink {
570 resource_link: resource_link.clone(),
571 };
572 }
573 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
574 if let Some(image) = Self::decode_image(image_content) {
575 *self = ContentBlock::Image { image };
576 } else {
577 let new_content = Self::image_md(image_content);
578 *self = Self::create_markdown_block(new_content, language_registry, cx);
579 }
580 }
581 (ContentBlock::Empty, _) => {
582 let new_content = Self::block_string_contents(&block, path_style);
583 *self = Self::create_markdown_block(new_content, language_registry, cx);
584 }
585 (ContentBlock::Markdown { markdown }, _) => {
586 let new_content = Self::block_string_contents(&block, path_style);
587 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
588 }
589 (ContentBlock::ResourceLink { resource_link }, _) => {
590 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
591 let new_content = Self::block_string_contents(&block, path_style);
592 let combined = format!("{}\n{}", existing_content, new_content);
593 *self = Self::create_markdown_block(combined, language_registry, cx);
594 }
595 (ContentBlock::Image { .. }, _) => {
596 let new_content = Self::block_string_contents(&block, path_style);
597 let combined = format!("`Image`\n{}", new_content);
598 *self = Self::create_markdown_block(combined, language_registry, cx);
599 }
600 }
601 }
602
603 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
604 use base64::Engine as _;
605
606 let bytes = base64::engine::general_purpose::STANDARD
607 .decode(image_content.data.as_bytes())
608 .ok()?;
609 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
610 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
611 }
612
613 fn create_markdown_block(
614 content: String,
615 language_registry: &Arc<LanguageRegistry>,
616 cx: &mut App,
617 ) -> ContentBlock {
618 ContentBlock::Markdown {
619 markdown: cx
620 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
621 }
622 }
623
624 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
625 match block {
626 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
627 acp::ContentBlock::ResourceLink(resource_link) => {
628 Self::resource_link_md(&resource_link.uri, path_style)
629 }
630 acp::ContentBlock::Resource(acp::EmbeddedResource {
631 resource:
632 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
633 uri,
634 ..
635 }),
636 ..
637 }) => Self::resource_link_md(uri, path_style),
638 acp::ContentBlock::Image(image) => Self::image_md(image),
639 _ => String::new(),
640 }
641 }
642
643 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
644 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
645 uri.as_link().to_string()
646 } else {
647 uri.to_string()
648 }
649 }
650
651 fn image_md(_image: &acp::ImageContent) -> String {
652 "`Image`".into()
653 }
654
655 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
656 match self {
657 ContentBlock::Empty => "",
658 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
659 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
660 ContentBlock::Image { .. } => "`Image`",
661 }
662 }
663
664 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
665 match self {
666 ContentBlock::Empty => None,
667 ContentBlock::Markdown { markdown } => Some(markdown),
668 ContentBlock::ResourceLink { .. } => None,
669 ContentBlock::Image { .. } => None,
670 }
671 }
672
673 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
674 match self {
675 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
676 _ => None,
677 }
678 }
679
680 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
681 match self {
682 ContentBlock::Image { image } => Some(image),
683 _ => None,
684 }
685 }
686}
687
/// One rendered item of a [`ToolCall`]'s content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, resource link, or image content.
    ContentBlock(ContentBlock),
    /// A file diff produced by the tool.
    Diff(Entity<Diff>),
    /// A terminal registered with the thread (looked up by ACP terminal id).
    Terminal(Entity<Terminal>),
}
694
695impl ToolCallContent {
696 pub fn from_acp(
697 content: acp::ToolCallContent,
698 language_registry: Arc<LanguageRegistry>,
699 path_style: PathStyle,
700 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
701 cx: &mut App,
702 ) -> Result<Option<Self>> {
703 match content {
704 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
705 Ok(Some(Self::ContentBlock(ContentBlock::new(
706 content,
707 &language_registry,
708 path_style,
709 cx,
710 ))))
711 }
712 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
713 Diff::finalized(
714 diff.path.to_string_lossy().into_owned(),
715 diff.old_text,
716 diff.new_text,
717 language_registry,
718 cx,
719 )
720 })))),
721 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
722 .get(&terminal_id)
723 .cloned()
724 .map(|terminal| Some(Self::Terminal(terminal)))
725 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
726 _ => Ok(None),
727 }
728 }
729
730 pub fn update_from_acp(
731 &mut self,
732 new: acp::ToolCallContent,
733 language_registry: Arc<LanguageRegistry>,
734 path_style: PathStyle,
735 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
736 cx: &mut App,
737 ) -> Result<bool> {
738 let needs_update = match (&self, &new) {
739 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
740 old_diff.read(cx).needs_update(
741 new_diff.old_text.as_deref().unwrap_or(""),
742 &new_diff.new_text,
743 cx,
744 )
745 }
746 _ => true,
747 };
748
749 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
750 if needs_update {
751 *self = update;
752 }
753 Ok(true)
754 } else {
755 Ok(false)
756 }
757 }
758
759 pub fn to_markdown(&self, cx: &App) -> String {
760 match self {
761 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
762 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
763 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
764 }
765 }
766
767 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
768 match self {
769 Self::ContentBlock(content) => content.image(),
770 _ => None,
771 }
772 }
773}
774
/// An update targeting an existing tool call, identified by its `ToolCallId`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A field update received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Attach or replace a diff entity on the tool call.
    UpdateDiff(ToolCallUpdateDiff),
    /// Attach or replace a terminal entity on the tool call.
    UpdateTerminal(ToolCallUpdateTerminal),
}
781
782impl ToolCallUpdate {
783 fn id(&self) -> &acp::ToolCallId {
784 match self {
785 Self::UpdateFields(update) => &update.tool_call_id,
786 Self::UpdateDiff(diff) => &diff.id,
787 Self::UpdateTerminal(terminal) => &terminal.id,
788 }
789 }
790}
791
792impl From<acp::ToolCallUpdate> for ToolCallUpdate {
793 fn from(update: acp::ToolCallUpdate) -> Self {
794 Self::UpdateFields(update)
795 }
796}
797
798impl From<ToolCallUpdateDiff> for ToolCallUpdate {
799 fn from(diff: ToolCallUpdateDiff) -> Self {
800 Self::UpdateDiff(diff)
801 }
802}
803
/// A diff entity to attach to the tool call identified by `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
809
810impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
811 fn from(terminal: ToolCallUpdateTerminal) -> Self {
812 Self::UpdateTerminal(terminal)
813 }
814}
815
/// A terminal entity to attach to the tool call identified by `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
821
/// The agent's plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
826
/// Summary of a [`Plan`], computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
833
834impl Plan {
835 pub fn is_empty(&self) -> bool {
836 self.entries.is_empty()
837 }
838
839 pub fn stats(&self) -> PlanStats<'_> {
840 let mut stats = PlanStats {
841 in_progress_entry: None,
842 pending: 0,
843 completed: 0,
844 };
845
846 for entry in &self.entries {
847 match &entry.status {
848 acp::PlanEntryStatus::Pending => {
849 stats.pending += 1;
850 }
851 acp::PlanEntryStatus::InProgress => {
852 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
853 }
854 acp::PlanEntryStatus::Completed => {
855 stats.completed += 1;
856 }
857 _ => {}
858 }
859 }
860
861 stats
862 }
863}
864
/// One step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Entry text rendered as markdown (no language registry is attached).
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
871
872impl PlanEntry {
873 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
874 Self {
875 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
876 priority: entry.priority,
877 status: entry.status,
878 }
879 }
880}
881
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means the limit is unknown (e.g. no model is selected), and
    /// `ratio` then never warns.
    pub max_tokens: u64,
    /// Tokens used so far; `ratio` compares this against `max_tokens`.
    pub used_tokens: u64,
    pub input_tokens: u64,
    pub output_tokens: u64,
}
889
890impl TokenUsage {
891 pub fn ratio(&self) -> TokenUsageRatio {
892 #[cfg(debug_assertions)]
893 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
894 .unwrap_or("0.8".to_string())
895 .parse()
896 .unwrap();
897 #[cfg(not(debug_assertions))]
898 let warning_threshold: f32 = 0.8;
899
900 // When the maximum is unknown because there is no selected model,
901 // avoid showing the token limit warning.
902 if self.max_tokens == 0 {
903 TokenUsageRatio::Normal
904 } else if self.used_tokens >= self.max_tokens {
905 TokenUsageRatio::Exceeded
906 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
907 TokenUsageRatio::Warning
908 } else {
909 TokenUsageRatio::Normal
910 }
911 }
912}
913
/// Severity bucket returned by `TokenUsage::ratio`.
///
/// Variant order is load-bearing: the derived `Ord` ranks `Normal < Warning < Exceeded`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}
920
/// Retry progress carried by `AcpThreadEvent::Retry`.
///
/// NOTE(review): field meanings below are inferred from names; confirm at the emit site.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number (1- or 0-based — confirm).
    pub attempt: usize,
    pub max_attempts: usize,
    /// When the retry was scheduled.
    pub started_at: Instant,
    /// Delay before the retry runs.
    pub duration: Duration,
}
929
/// A single ACP session: its entries, plan, terminals, and the connection that drives it.
pub struct AcpThread {
    /// Session id of the parent thread, if one was supplied at construction.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    /// Conversation entries (presumably in display order).
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// NOTE(review): usage not visible in this excerpt; presumably snapshots of buffers
    /// shared with the agent.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task, if any; `stop_by_user` drops it.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to this thread, keyed by ACP terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals that are not registered yet; replayed when the
    /// terminal's `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit events received for terminals that are not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    // subagent cancellation fields
    /// Set by `stop_by_user`; read by `was_stopped_by_user`.
    user_stopped: Arc<std::sync::atomic::AtomicBool>,
    /// Broadcasts `true` from `stop_by_user`; listeners subscribe via `user_stop_receiver`.
    user_stop_tx: watch::Sender<bool>,
}
951
952impl From<&AcpThread> for ActionLogTelemetry {
953 fn from(value: &AcpThread) -> Self {
954 Self {
955 agent_telemetry_id: value.connection().telemetry_id(),
956 session_id: value.session_id.0.clone(),
957 }
958 }
959}
960
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Carries an entry index (presumably into the thread's `entries`).
    EntryUpdated(usize),
    /// Carries a range of entry indices (presumably into the thread's `entries`).
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    /// Carries the session id of the spawned subagent.
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is refreshed from the watch channel passed to
    /// `AcpThread::new`.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}
980
// Allows `AcpThread` entities to emit `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
982
/// Events from a terminal provider, handled by `AcpThread::on_terminal_provider_event`.
///
/// Output and exit events may arrive before `Created`. The thread buffers them and
/// replays them once the terminal is registered.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; it is registered with the thread under `terminal_id`.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Raw output bytes for a terminal.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title, applied as the terminal's breadcrumb text. It is dropped if the
    /// terminal is not registered yet.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited. The thread currently only re-notifies the terminal
    /// entity; `status` itself is not applied there.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
1005
/// Commands addressed to a terminal provider (counterpart of [`TerminalProviderEvent`]).
///
/// NOTE(review): not consumed within this excerpt; semantics below are inferred from
/// field names.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Input bytes to write to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to `cols` x `rows`.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1021
1022impl AcpThread {
1023 pub fn on_terminal_provider_event(
1024 &mut self,
1025 event: TerminalProviderEvent,
1026 cx: &mut Context<Self>,
1027 ) {
1028 match event {
1029 TerminalProviderEvent::Created {
1030 terminal_id,
1031 label,
1032 cwd,
1033 output_byte_limit,
1034 terminal,
1035 } => {
1036 let entity = self.register_terminal_created(
1037 terminal_id.clone(),
1038 label,
1039 cwd,
1040 output_byte_limit,
1041 terminal,
1042 cx,
1043 );
1044
1045 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1046 for data in chunks.drain(..) {
1047 entity.update(cx, |term, cx| {
1048 term.inner().update(cx, |inner, cx| {
1049 inner.write_output(&data, cx);
1050 })
1051 });
1052 }
1053 }
1054
1055 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1056 entity.update(cx, |_term, cx| {
1057 cx.notify();
1058 });
1059 }
1060
1061 cx.notify();
1062 }
1063 TerminalProviderEvent::Output { terminal_id, data } => {
1064 if let Some(entity) = self.terminals.get(&terminal_id) {
1065 entity.update(cx, |term, cx| {
1066 term.inner().update(cx, |inner, cx| {
1067 inner.write_output(&data, cx);
1068 })
1069 });
1070 } else {
1071 self.pending_terminal_output
1072 .entry(terminal_id)
1073 .or_default()
1074 .push(data);
1075 }
1076 }
1077 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1078 if let Some(entity) = self.terminals.get(&terminal_id) {
1079 entity.update(cx, |term, cx| {
1080 term.inner().update(cx, |inner, cx| {
1081 inner.breadcrumb_text = title;
1082 cx.emit(::terminal::Event::BreadcrumbsChanged);
1083 })
1084 });
1085 }
1086 }
1087 TerminalProviderEvent::Exit {
1088 terminal_id,
1089 status,
1090 } => {
1091 if let Some(entity) = self.terminals.get(&terminal_id) {
1092 entity.update(cx, |_term, cx| {
1093 cx.notify();
1094 });
1095 } else {
1096 self.pending_terminal_exit.insert(terminal_id, status);
1097 }
1098 }
1099 }
1100 }
1101}
1102
/// Whether a thread is idle or currently generating a response.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    Idle,
    Generating,
}
1108
/// Reasons an agent server could not be loaded; its `Display` impl gives a one-line
/// message for each case.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed version is older than the minimum supported one.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The server process exited with `status`.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by the message.
    Other(SharedString),
}
1122
1123impl Display for LoadError {
1124 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1125 match self {
1126 LoadError::Unsupported {
1127 command: path,
1128 current_version,
1129 minimum_version,
1130 } => {
1131 write!(
1132 f,
1133 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1134 )
1135 }
1136 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1137 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1138 LoadError::Other(msg) => write!(f, "{msg}"),
1139 }
1140 }
1141}
1142
// Lets `LoadError` be used as a `std::error::Error`; the message comes from its `Display`
// impl and, with the default methods, it reports no `source()`.
impl Error for LoadError {}
1144
1145impl AcpThread {
1146 pub fn new(
1147 parent_session_id: Option<acp::SessionId>,
1148 title: impl Into<SharedString>,
1149 connection: Rc<dyn AgentConnection>,
1150 project: Entity<Project>,
1151 action_log: Entity<ActionLog>,
1152 session_id: acp::SessionId,
1153 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1154 cx: &mut Context<Self>,
1155 ) -> Self {
1156 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1157 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1158 loop {
1159 let caps = prompt_capabilities_rx.recv().await?;
1160 this.update(cx, |this, cx| {
1161 this.prompt_capabilities = caps;
1162 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1163 })?;
1164 }
1165 });
1166
1167 let (user_stop_tx, _user_stop_rx) = watch::channel(false);
1168
1169 Self {
1170 parent_session_id,
1171 action_log,
1172 shared_buffers: Default::default(),
1173 entries: Default::default(),
1174 plan: Default::default(),
1175 title: title.into(),
1176 project,
1177 send_task: None,
1178 connection,
1179 session_id,
1180 token_usage: None,
1181 prompt_capabilities,
1182 _observe_prompt_capabilities: task,
1183 terminals: HashMap::default(),
1184 pending_terminal_output: HashMap::default(),
1185 pending_terminal_exit: HashMap::default(),
1186 user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1187 user_stop_tx,
1188 }
1189 }
1190
1191 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1192 self.parent_session_id.as_ref()
1193 }
1194
1195 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1196 self.prompt_capabilities.clone()
1197 }
1198
1199 /// Marks this thread as stopped by user action and signals any listeners.
1200 pub fn stop_by_user(&mut self) {
1201 self.user_stopped
1202 .store(true, std::sync::atomic::Ordering::SeqCst);
1203 self.user_stop_tx.send(true).ok();
1204 self.send_task.take();
1205 }
1206
1207 pub fn was_stopped_by_user(&self) -> bool {
1208 self.user_stopped.load(std::sync::atomic::Ordering::SeqCst)
1209 }
1210
1211 pub fn user_stop_receiver(&self) -> watch::Receiver<bool> {
1212 self.user_stop_tx.receiver()
1213 }
1214
1215 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1216 &self.connection
1217 }
1218
1219 pub fn action_log(&self) -> &Entity<ActionLog> {
1220 &self.action_log
1221 }
1222
1223 pub fn project(&self) -> &Entity<Project> {
1224 &self.project
1225 }
1226
1227 pub fn title(&self) -> SharedString {
1228 self.title.clone()
1229 }
1230
1231 pub fn entries(&self) -> &[AgentThreadEntry] {
1232 &self.entries
1233 }
1234
1235 pub fn session_id(&self) -> &acp::SessionId {
1236 &self.session_id
1237 }
1238
1239 pub fn status(&self) -> ThreadStatus {
1240 if self.send_task.is_some() {
1241 ThreadStatus::Generating
1242 } else {
1243 ThreadStatus::Idle
1244 }
1245 }
1246
1247 pub fn token_usage(&self) -> Option<&TokenUsage> {
1248 self.token_usage.as_ref()
1249 }
1250
1251 pub fn has_pending_edit_tool_calls(&self) -> bool {
1252 for entry in self.entries.iter().rev() {
1253 match entry {
1254 AgentThreadEntry::UserMessage(_) => return false,
1255 AgentThreadEntry::ToolCall(
1256 call @ ToolCall {
1257 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1258 ..
1259 },
1260 ) if call.diffs().next().is_some() => {
1261 return true;
1262 }
1263 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1264 }
1265 }
1266
1267 false
1268 }
1269
1270 pub fn has_in_progress_tool_calls(&self) -> bool {
1271 for entry in self.entries.iter().rev() {
1272 match entry {
1273 AgentThreadEntry::UserMessage(_) => return false,
1274 AgentThreadEntry::ToolCall(ToolCall {
1275 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1276 ..
1277 }) => {
1278 return true;
1279 }
1280 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1281 }
1282 }
1283
1284 false
1285 }
1286
1287 pub fn used_tools_since_last_user_message(&self) -> bool {
1288 for entry in self.entries.iter().rev() {
1289 match entry {
1290 AgentThreadEntry::UserMessage(..) => return false,
1291 AgentThreadEntry::AssistantMessage(..) => continue,
1292 AgentThreadEntry::ToolCall(..) => return true,
1293 }
1294 }
1295
1296 false
1297 }
1298
1299 pub fn handle_session_update(
1300 &mut self,
1301 update: acp::SessionUpdate,
1302 cx: &mut Context<Self>,
1303 ) -> Result<(), acp::Error> {
1304 match update {
1305 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1306 self.push_user_content_block(None, content, cx);
1307 }
1308 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1309 self.push_assistant_content_block(content, false, cx);
1310 }
1311 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1312 self.push_assistant_content_block(content, true, cx);
1313 }
1314 acp::SessionUpdate::ToolCall(tool_call) => {
1315 self.upsert_tool_call(tool_call, cx)?;
1316 }
1317 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1318 self.update_tool_call(tool_call_update, cx)?;
1319 }
1320 acp::SessionUpdate::Plan(plan) => {
1321 self.update_plan(plan, cx);
1322 }
1323 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1324 available_commands,
1325 ..
1326 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1327 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1328 current_mode_id,
1329 ..
1330 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1331 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1332 config_options,
1333 ..
1334 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1335 _ => {}
1336 }
1337 Ok(())
1338 }
1339
1340 pub fn push_user_content_block(
1341 &mut self,
1342 message_id: Option<UserMessageId>,
1343 chunk: acp::ContentBlock,
1344 cx: &mut Context<Self>,
1345 ) {
1346 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1347 }
1348
    /// Appends a user content chunk to the transcript.
    ///
    /// If the last entry is a user message with the same `indented` flag, the
    /// chunk is merged into it and `EntryUpdated` is emitted. In that case a
    /// provided `message_id` replaces the existing id; otherwise the existing id
    /// is kept. In every other case a new user message entry is pushed.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read before `last_mut` borrows `self.entries` mutably below.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Prefer the newly supplied id; fall back to the one already stored.
            *id = message_id.or(id.take());
            // Keep both the rendered content and the raw chunks in sync.
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1389
1390 pub fn push_assistant_content_block(
1391 &mut self,
1392 chunk: acp::ContentBlock,
1393 is_thought: bool,
1394 cx: &mut Context<Self>,
1395 ) {
1396 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1397 }
1398
    /// Appends an assistant content chunk to the transcript. The chunk is a
    /// message, or a thought when `is_thought` is set.
    ///
    /// If the last entry is an assistant message with the same `indented` flag,
    /// the chunk extends that message and `EntryUpdated` is emitted:
    /// - if the message's final chunk is of the same kind, the content is appended to it;
    /// - otherwise a new chunk of the requested kind is started.
    ///
    /// Otherwise a new assistant message entry is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read before `last_mut` borrows `self.entries` mutably below.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): emitted before the mutation below; presumably fine if
            // `cx.emit` only queues the event — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1449
1450 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1451 self.entries.push(entry);
1452 cx.emit(AcpThreadEvent::NewEntry);
1453 }
1454
1455 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1456 self.connection.set_title(&self.session_id, cx).is_some()
1457 }
1458
1459 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1460 if title != self.title {
1461 self.title = title.clone();
1462 cx.emit(AcpThreadEvent::TitleUpdated);
1463 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1464 return set_title.run(title, cx);
1465 }
1466 }
1467 Task::ready(Ok(()))
1468 }
1469
1470 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1471 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1472 }
1473
1474 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1475 self.token_usage = usage;
1476 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1477 }
1478
1479 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1480 cx.emit(AcpThreadEvent::Retry(status));
1481 }
1482
1483 pub fn update_tool_call(
1484 &mut self,
1485 update: impl Into<ToolCallUpdate>,
1486 cx: &mut Context<Self>,
1487 ) -> Result<()> {
1488 let update = update.into();
1489 let languages = self.project.read(cx).languages().clone();
1490 let path_style = self.project.read(cx).path_style(cx);
1491
1492 let ix = match self.index_for_tool_call(update.id()) {
1493 Some(ix) => ix,
1494 None => {
1495 // Tool call not found - create a failed tool call entry
1496 let failed_tool_call = ToolCall {
1497 id: update.id().clone(),
1498 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1499 kind: acp::ToolKind::Fetch,
1500 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1501 "Tool call not found".into(),
1502 &languages,
1503 path_style,
1504 cx,
1505 ))],
1506 status: ToolCallStatus::Failed,
1507 locations: Vec::new(),
1508 resolved_locations: Vec::new(),
1509 raw_input: None,
1510 raw_input_markdown: None,
1511 raw_output: None,
1512 tool_name: None,
1513 subagent_session_id: None,
1514 };
1515 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1516 return Ok(());
1517 }
1518 };
1519 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1520 unreachable!()
1521 };
1522
1523 match update {
1524 ToolCallUpdate::UpdateFields(update) => {
1525 let location_updated = update.fields.locations.is_some();
1526 call.update_fields(
1527 update.fields,
1528 update.meta,
1529 languages,
1530 path_style,
1531 &self.terminals,
1532 cx,
1533 )?;
1534 if location_updated {
1535 self.resolve_locations(update.tool_call_id, cx);
1536 }
1537 }
1538 ToolCallUpdate::UpdateDiff(update) => {
1539 call.content.clear();
1540 call.content.push(ToolCallContent::Diff(update.diff));
1541 }
1542 ToolCallUpdate::UpdateTerminal(update) => {
1543 call.content.clear();
1544 call.content
1545 .push(ToolCallContent::Terminal(update.terminal));
1546 }
1547 }
1548
1549 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1550
1551 Ok(())
1552 }
1553
1554 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1555 pub fn upsert_tool_call(
1556 &mut self,
1557 tool_call: acp::ToolCall,
1558 cx: &mut Context<Self>,
1559 ) -> Result<(), acp::Error> {
1560 let status = tool_call.status.into();
1561 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1562 }
1563
1564 /// Fails if id does not match an existing entry.
1565 pub fn upsert_tool_call_inner(
1566 &mut self,
1567 update: acp::ToolCallUpdate,
1568 status: ToolCallStatus,
1569 cx: &mut Context<Self>,
1570 ) -> Result<(), acp::Error> {
1571 let language_registry = self.project.read(cx).languages().clone();
1572 let path_style = self.project.read(cx).path_style(cx);
1573 let id = update.tool_call_id.clone();
1574
1575 let agent_telemetry_id = self.connection().telemetry_id();
1576 let session = self.session_id();
1577 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1578 let status = if matches!(status, ToolCallStatus::Completed) {
1579 "completed"
1580 } else {
1581 "failed"
1582 };
1583 telemetry::event!(
1584 "Agent Tool Call Completed",
1585 agent_telemetry_id,
1586 session,
1587 status
1588 );
1589 }
1590
1591 if let Some(ix) = self.index_for_tool_call(&id) {
1592 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1593 unreachable!()
1594 };
1595
1596 call.update_fields(
1597 update.fields,
1598 update.meta,
1599 language_registry,
1600 path_style,
1601 &self.terminals,
1602 cx,
1603 )?;
1604 call.status = status;
1605
1606 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1607 } else {
1608 let call = ToolCall::from_acp(
1609 update.try_into()?,
1610 status,
1611 language_registry,
1612 self.project.read(cx).path_style(cx),
1613 &self.terminals,
1614 cx,
1615 )?;
1616 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1617 };
1618
1619 self.resolve_locations(id, cx);
1620 Ok(())
1621 }
1622
1623 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1624 self.entries
1625 .iter()
1626 .enumerate()
1627 .rev()
1628 .find_map(|(index, entry)| {
1629 if let AgentThreadEntry::ToolCall(tool_call) = entry
1630 && &tool_call.id == id
1631 {
1632 Some(index)
1633 } else {
1634 None
1635 }
1636 })
1637 }
1638
1639 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1640 // The tool call we are looking for is typically the last one, or very close to the end.
1641 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1642 self.entries
1643 .iter_mut()
1644 .enumerate()
1645 .rev()
1646 .find_map(|(index, tool_call)| {
1647 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1648 && &tool_call.id == id
1649 {
1650 Some((index, tool_call))
1651 } else {
1652 None
1653 }
1654 })
1655 }
1656
1657 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1658 self.entries
1659 .iter()
1660 .enumerate()
1661 .rev()
1662 .find_map(|(index, tool_call)| {
1663 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1664 && &tool_call.id == id
1665 {
1666 Some((index, tool_call))
1667 } else {
1668 None
1669 }
1670 })
1671 }
1672
    /// Resolves the locations of tool call `id` into buffer positions.
    ///
    /// Does nothing if no such tool call exists. Otherwise the work finishes on
    /// a detached task which:
    /// - caches each resolved buffer's snapshot in `shared_buffers`;
    /// - moves the project's agent location to the last resolved location, unless
    ///   that would only move it leftwards within the same row of the same buffer;
    /// - stores the resolved locations on the tool call, emitting `EntryUpdated`
    ///   when they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Look the tool call up again: the transcript may have changed while
                // resolution was in flight.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify when something actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1729
1730 pub fn request_tool_call_authorization(
1731 &mut self,
1732 tool_call: acp::ToolCallUpdate,
1733 options: PermissionOptions,
1734 cx: &mut Context<Self>,
1735 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1736 let (tx, rx) = oneshot::channel();
1737
1738 let status = ToolCallStatus::WaitingForConfirmation {
1739 options,
1740 respond_tx: tx,
1741 };
1742
1743 self.upsert_tool_call_inner(tool_call, status, cx)?;
1744 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1745
1746 let fut = async {
1747 match rx.await {
1748 Ok(option) => acp::RequestPermissionOutcome::Selected(
1749 acp::SelectedPermissionOutcome::new(option),
1750 ),
1751 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1752 }
1753 }
1754 .boxed();
1755
1756 Ok(fut)
1757 }
1758
1759 pub fn authorize_tool_call(
1760 &mut self,
1761 id: acp::ToolCallId,
1762 option_id: acp::PermissionOptionId,
1763 option_kind: acp::PermissionOptionKind,
1764 cx: &mut Context<Self>,
1765 ) {
1766 let Some((ix, call)) = self.tool_call_mut(&id) else {
1767 return;
1768 };
1769
1770 let new_status = match option_kind {
1771 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1772 ToolCallStatus::Rejected
1773 }
1774 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1775 ToolCallStatus::InProgress
1776 }
1777 _ => ToolCallStatus::InProgress,
1778 };
1779
1780 let curr_status = mem::replace(&mut call.status, new_status);
1781
1782 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1783 respond_tx.send(option_id).log_err();
1784 } else if cfg!(debug_assertions) {
1785 panic!("tried to authorize an already authorized tool call");
1786 }
1787
1788 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1789 }
1790
1791 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1792 let mut first_tool_call = None;
1793
1794 for entry in self.entries.iter().rev() {
1795 match &entry {
1796 AgentThreadEntry::ToolCall(call) => {
1797 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1798 first_tool_call = Some(call);
1799 } else {
1800 continue;
1801 }
1802 }
1803 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1804 // Reached the beginning of the turn.
1805 // If we had pending permission requests in the previous turn, they have been cancelled.
1806 break;
1807 }
1808 }
1809 }
1810
1811 first_tool_call
1812 }
1813
1814 pub fn plan(&self) -> &Plan {
1815 &self.plan
1816 }
1817
1818 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1819 let new_entries_len = request.entries.len();
1820 let mut new_entries = request.entries.into_iter();
1821
1822 // Reuse existing markdown to prevent flickering
1823 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1824 let PlanEntry {
1825 content,
1826 priority,
1827 status,
1828 } = old;
1829 content.update(cx, |old, cx| {
1830 old.replace(new.content, cx);
1831 });
1832 *priority = new.priority;
1833 *status = new.status;
1834 }
1835 for new in new_entries {
1836 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1837 }
1838 self.plan.entries.truncate(new_entries_len);
1839
1840 cx.notify();
1841 }
1842
1843 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1844 self.plan
1845 .entries
1846 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1847 cx.notify();
1848 }
1849
1850 #[cfg(any(test, feature = "test-support"))]
1851 pub fn send_raw(
1852 &mut self,
1853 message: &str,
1854 cx: &mut Context<Self>,
1855 ) -> BoxFuture<'static, Result<()>> {
1856 self.send(vec![message.into()], cx)
1857 }
1858
1859 pub fn send(
1860 &mut self,
1861 message: Vec<acp::ContentBlock>,
1862 cx: &mut Context<Self>,
1863 ) -> BoxFuture<'static, Result<()>> {
1864 let block = ContentBlock::new_combined(
1865 message.clone(),
1866 self.project.read(cx).languages().clone(),
1867 self.project.read(cx).path_style(cx),
1868 cx,
1869 );
1870 let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
1871 let git_store = self.project.read(cx).git_store().clone();
1872
1873 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1874 Some(UserMessageId::new())
1875 } else {
1876 None
1877 };
1878
1879 self.run_turn(cx, async move |this, cx| {
1880 this.update(cx, |this, cx| {
1881 this.push_entry(
1882 AgentThreadEntry::UserMessage(UserMessage {
1883 id: message_id.clone(),
1884 content: block,
1885 chunks: message,
1886 checkpoint: None,
1887 indented: false,
1888 }),
1889 cx,
1890 );
1891 })
1892 .ok();
1893
1894 let old_checkpoint = git_store
1895 .update(cx, |git, cx| git.checkpoint(cx))
1896 .await
1897 .context("failed to get old checkpoint")
1898 .log_err();
1899 this.update(cx, |this, cx| {
1900 if let Some((_ix, message)) = this.last_user_message() {
1901 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1902 git_checkpoint,
1903 show: false,
1904 });
1905 }
1906 this.connection.prompt(message_id, request, cx)
1907 })?
1908 .await
1909 })
1910 }
1911
1912 pub fn can_retry(&self, cx: &App) -> bool {
1913 self.connection.retry(&self.session_id, cx).is_some()
1914 }
1915
1916 pub fn retry(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1917 self.run_turn(cx, async move |this, cx| {
1918 this.update(cx, |this, cx| {
1919 this.connection
1920 .retry(&this.session_id, cx)
1921 .map(|retry| retry.run(cx))
1922 })?
1923 .context("retrying a session is not supported")?
1924 .await
1925 })
1926 }
1927
    /// Runs one conversation turn driven by `f`.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight
    /// turn is canceled; `f` waits for that cancellation to finish. The new turn
    /// is stored as `send_task`.
    ///
    /// The returned future resolves after the turn finishes and the last user
    /// message's checkpoint has been refreshed. It then clears the project's
    /// agent location and:
    /// - if `f` returned an error: drops `send_task`, emits `Error`, and returns
    ///   the error;
    /// - otherwise: drops `send_task` unless the turn was cancelled, handles a
    ///   `Refusal` stop reason (see below), emits `Stopped`, and returns `Ok(())`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `tx` carries `f`'s result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the send task was dropped before finishing.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
2023
2024 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2025 let Some(send_task) = self.send_task.take() else {
2026 return Task::ready(());
2027 };
2028
2029 for entry in self.entries.iter_mut() {
2030 if let AgentThreadEntry::ToolCall(call) = entry {
2031 let cancel = matches!(
2032 call.status,
2033 ToolCallStatus::Pending
2034 | ToolCallStatus::WaitingForConfirmation { .. }
2035 | ToolCallStatus::InProgress
2036 );
2037
2038 if cancel {
2039 call.status = ToolCallStatus::Canceled;
2040 }
2041 }
2042 }
2043
2044 self.connection.cancel(&self.session_id, cx);
2045
2046 // Wait for the send task to complete
2047 cx.foreground_executor().spawn(send_task)
2048 }
2049
2050 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2051 pub fn restore_checkpoint(
2052 &mut self,
2053 id: UserMessageId,
2054 cx: &mut Context<Self>,
2055 ) -> Task<Result<()>> {
2056 let Some((_, message)) = self.user_message_mut(&id) else {
2057 return Task::ready(Err(anyhow!("message not found")));
2058 };
2059
2060 let checkpoint = message
2061 .checkpoint
2062 .as_ref()
2063 .map(|c| c.git_checkpoint.clone());
2064
2065 // Cancel any in-progress generation before restoring
2066 let cancel_task = self.cancel(cx);
2067 let rewind = self.rewind(id.clone(), cx);
2068 let git_store = self.project.read(cx).git_store().clone();
2069
2070 cx.spawn(async move |_, cx| {
2071 cancel_task.await;
2072 rewind.await?;
2073 if let Some(checkpoint) = checkpoint {
2074 git_store
2075 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2076 .await?;
2077 }
2078
2079 Ok(())
2080 })
2081 }
2082
    /// Rewinds this thread to just before the user message identified by `id`.
    ///
    /// The connection's session is truncated first. Then that message and every
    /// later entry are removed (emitting `EntriesRemoved`), and terminals
    /// referenced by the removed entries are killed and dropped. Finally, all
    /// edits in the action log are rejected; this happens even if no message
    /// with that id is found. Unlike `restore_checkpoint`, this method does not
    /// restore from git.
    ///
    /// Returns an immediately-failed task ("not supported") when the connection
    /// cannot truncate; otherwise propagates the truncate error.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx)).await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
2124
2125 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2126 let git_store = self.project.read(cx).git_store().clone();
2127
2128 let Some((_, message)) = self.last_user_message() else {
2129 return Task::ready(Ok(()));
2130 };
2131 let Some(user_message_id) = message.id.clone() else {
2132 return Task::ready(Ok(()));
2133 };
2134 let Some(checkpoint) = message.checkpoint.as_ref() else {
2135 return Task::ready(Ok(()));
2136 };
2137 let old_checkpoint = checkpoint.git_checkpoint.clone();
2138
2139 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2140 cx.spawn(async move |this, cx| {
2141 let Some(new_checkpoint) = new_checkpoint
2142 .await
2143 .context("failed to get new checkpoint")
2144 .log_err()
2145 else {
2146 return Ok(());
2147 };
2148
2149 let equal = git_store
2150 .update(cx, |git, cx| {
2151 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2152 })
2153 .await
2154 .unwrap_or(true);
2155
2156 this.update(cx, |this, cx| {
2157 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2158 if let Some(checkpoint) = message.checkpoint.as_mut() {
2159 checkpoint.show = !equal;
2160 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2161 }
2162 }
2163 })?;
2164
2165 Ok(())
2166 })
2167 }
2168
2169 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2170 self.entries
2171 .iter_mut()
2172 .enumerate()
2173 .rev()
2174 .find_map(|(ix, entry)| {
2175 if let AgentThreadEntry::UserMessage(message) = entry {
2176 Some((ix, message))
2177 } else {
2178 None
2179 }
2180 })
2181 }
2182
2183 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2184 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2185 if let AgentThreadEntry::UserMessage(message) = entry {
2186 if message.id.as_ref() == Some(id) {
2187 Some((ix, message))
2188 } else {
2189 None
2190 }
2191 } else {
2192 None
2193 }
2194 })
2195 }
2196
2197 pub fn read_text_file(
2198 &self,
2199 path: PathBuf,
2200 line: Option<u32>,
2201 limit: Option<u32>,
2202 reuse_shared_snapshot: bool,
2203 cx: &mut Context<Self>,
2204 ) -> Task<Result<String, acp::Error>> {
2205 // Args are 1-based, move to 0-based
2206 let line = line.unwrap_or_default().saturating_sub(1);
2207 let limit = limit.unwrap_or(u32::MAX);
2208 let project = self.project.clone();
2209 let action_log = self.action_log.clone();
2210 cx.spawn(async move |this, cx| {
2211 let load = project.update(cx, |project, cx| {
2212 let path = project
2213 .project_path_for_absolute_path(&path, cx)
2214 .ok_or_else(|| {
2215 acp::Error::resource_not_found(Some(path.display().to_string()))
2216 })?;
2217 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2218 })?;
2219
2220 let buffer = load.await?;
2221
2222 let snapshot = if reuse_shared_snapshot {
2223 this.read_with(cx, |this, _| {
2224 this.shared_buffers.get(&buffer.clone()).cloned()
2225 })
2226 .log_err()
2227 .flatten()
2228 } else {
2229 None
2230 };
2231
2232 let snapshot = if let Some(snapshot) = snapshot {
2233 snapshot
2234 } else {
2235 action_log.update(cx, |action_log, cx| {
2236 action_log.buffer_read(buffer.clone(), cx);
2237 });
2238
2239 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2240 this.update(cx, |this, _| {
2241 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2242 })?;
2243 snapshot
2244 };
2245
2246 let max_point = snapshot.max_point();
2247 let start_position = Point::new(line, 0);
2248
2249 if start_position > max_point {
2250 return Err(acp::Error::invalid_params().data(format!(
2251 "Attempting to read beyond the end of the file, line {}:{}",
2252 max_point.row + 1,
2253 max_point.column
2254 )));
2255 }
2256
2257 let start = snapshot.anchor_before(start_position);
2258 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2259
2260 project.update(cx, |project, cx| {
2261 project.set_agent_location(
2262 Some(AgentLocation {
2263 buffer: buffer.downgrade(),
2264 position: start,
2265 }),
2266 cx,
2267 );
2268 });
2269
2270 Ok(snapshot.text_for_range(start..end).collect::<String>())
2271 })
2272 }
2273
    /// Replaces the contents of the project file at the absolute path `path`
    /// with `content`, then saves it.
    ///
    /// Rather than overwriting the buffer wholesale, this diffs `content`
    /// against a snapshot and applies the resulting edits through anchors.
    /// The snapshot is the one cached in `shared_buffers` (stored by an
    /// earlier read) when present, otherwise the buffer's current snapshot.
    /// Because the edits are anchored in that snapshot, changes made to the
    /// buffer after it was taken are kept rather than overwritten.
    ///
    /// After editing, the buffer is formatted when its language settings have
    /// `format_on_save` enabled. Formatting errors are logged and ignored.
    ///
    /// # Errors
    ///
    /// * `"invalid path"` if `path` does not map to a project path.
    /// * Errors from opening the buffer, from updating this (possibly released)
    ///   thread entity, or from saving the buffer.
    ///
    /// NOTE(review): `shared_buffers` is not refreshed here after the write, so
    /// a subsequent write without an intervening read diffs against the older
    /// cached snapshot — confirm this is intended.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Prefer the snapshot the agent last read; fall back to the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread; ranges become anchors in
            // `snapshot` so they can be applied to the (possibly newer) buffer.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at the
            // start of the buffer when there is nothing to change.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // The action log is told about a read before the edit is applied and
            // about the edit afterwards; the order of these calls is significant.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Best effort: a formatting failure must not fail the write.
                format_task.await.log_err();

                // Report the formatter's changes to the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2370
2371 pub fn create_terminal(
2372 &self,
2373 command: String,
2374 args: Vec<String>,
2375 extra_env: Vec<acp::EnvVariable>,
2376 cwd: Option<PathBuf>,
2377 output_byte_limit: Option<u64>,
2378 cx: &mut Context<Self>,
2379 ) -> Task<Result<Entity<Terminal>>> {
2380 let env = match &cwd {
2381 Some(dir) => self.project.update(cx, |project, cx| {
2382 project.environment().update(cx, |env, cx| {
2383 env.directory_environment(dir.as_path().into(), cx)
2384 })
2385 }),
2386 None => Task::ready(None).shared(),
2387 };
2388 let env = cx.spawn(async move |_, _| {
2389 let mut env = env.await.unwrap_or_default();
2390 // Disables paging for `git` and hopefully other commands
2391 env.insert("PAGER".into(), "".into());
2392 for var in extra_env {
2393 env.insert(var.name, var.value);
2394 }
2395 env
2396 });
2397
2398 let project = self.project.clone();
2399 let language_registry = project.read(cx).languages().clone();
2400 let is_windows = project.read(cx).path_style(cx).is_windows();
2401
2402 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2403 let terminal_task = cx.spawn({
2404 let terminal_id = terminal_id.clone();
2405 async move |_this, cx| {
2406 let env = env.await;
2407 let shell = project
2408 .update(cx, |project, cx| {
2409 project
2410 .remote_client()
2411 .and_then(|r| r.read(cx).default_system_shell())
2412 })
2413 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2414 let (task_command, task_args) =
2415 ShellBuilder::new(&Shell::Program(shell), is_windows)
2416 .redirect_stdin_to_dev_null()
2417 .build(Some(command.clone()), &args);
2418 let terminal = project
2419 .update(cx, |project, cx| {
2420 project.create_terminal_task(
2421 task::SpawnInTerminal {
2422 command: Some(task_command),
2423 args: task_args,
2424 cwd: cwd.clone(),
2425 env,
2426 ..Default::default()
2427 },
2428 cx,
2429 )
2430 })
2431 .await?;
2432
2433 anyhow::Ok(cx.new(|cx| {
2434 Terminal::new(
2435 terminal_id,
2436 &format!("{} {}", command, args.join(" ")),
2437 cwd,
2438 output_byte_limit.map(|l| l as usize),
2439 terminal,
2440 language_registry,
2441 cx,
2442 )
2443 }))
2444 }
2445 });
2446
2447 cx.spawn(async move |this, cx| {
2448 let terminal = terminal_task.await?;
2449 this.update(cx, |this, _cx| {
2450 this.terminals.insert(terminal_id, terminal.clone());
2451 terminal
2452 })
2453 })
2454 }
2455
2456 pub fn kill_terminal(
2457 &mut self,
2458 terminal_id: acp::TerminalId,
2459 cx: &mut Context<Self>,
2460 ) -> Result<()> {
2461 self.terminals
2462 .get(&terminal_id)
2463 .context("Terminal not found")?
2464 .update(cx, |terminal, cx| {
2465 terminal.kill(cx);
2466 });
2467
2468 Ok(())
2469 }
2470
2471 pub fn release_terminal(
2472 &mut self,
2473 terminal_id: acp::TerminalId,
2474 cx: &mut Context<Self>,
2475 ) -> Result<()> {
2476 self.terminals
2477 .remove(&terminal_id)
2478 .context("Terminal not found")?
2479 .update(cx, |terminal, cx| {
2480 terminal.kill(cx);
2481 });
2482
2483 Ok(())
2484 }
2485
2486 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2487 self.terminals
2488 .get(&terminal_id)
2489 .context("Terminal not found")
2490 .cloned()
2491 }
2492
2493 pub fn to_markdown(&self, cx: &App) -> String {
2494 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2495 }
2496
2497 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2498 cx.emit(AcpThreadEvent::LoadError(error));
2499 }
2500
2501 pub fn register_terminal_created(
2502 &mut self,
2503 terminal_id: acp::TerminalId,
2504 command_label: String,
2505 working_dir: Option<PathBuf>,
2506 output_byte_limit: Option<u64>,
2507 terminal: Entity<::terminal::Terminal>,
2508 cx: &mut Context<Self>,
2509 ) -> Entity<Terminal> {
2510 let language_registry = self.project.read(cx).languages().clone();
2511
2512 let entity = cx.new(|cx| {
2513 Terminal::new(
2514 terminal_id.clone(),
2515 &command_label,
2516 working_dir.clone(),
2517 output_byte_limit.map(|l| l as usize),
2518 terminal,
2519 language_registry,
2520 cx,
2521 )
2522 });
2523 self.terminals.insert(terminal_id.clone(), entity.clone());
2524 entity
2525 }
2526}
2527
2528fn markdown_for_raw_output(
2529 raw_output: &serde_json::Value,
2530 language_registry: &Arc<LanguageRegistry>,
2531 cx: &mut App,
2532) -> Option<Entity<Markdown>> {
2533 match raw_output {
2534 serde_json::Value::Null => None,
2535 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2536 Markdown::new(
2537 value.to_string().into(),
2538 Some(language_registry.clone()),
2539 None,
2540 cx,
2541 )
2542 })),
2543 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2544 Markdown::new(
2545 value.to_string().into(),
2546 Some(language_registry.clone()),
2547 None,
2548 cx,
2549 )
2550 })),
2551 serde_json::Value::String(value) => Some(cx.new(|cx| {
2552 Markdown::new(
2553 value.clone().into(),
2554 Some(language_registry.clone()),
2555 None,
2556 cx,
2557 )
2558 })),
2559 value => Some(cx.new(|cx| {
2560 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2561
2562 Markdown::new(
2563 format!("```json\n{}\n```", pretty_json).into(),
2564 Some(language_registry.clone()),
2565 None,
2566 cx,
2567 )
2568 })),
2569 }
2570}
2571
2572#[cfg(test)]
2573mod tests {
2574 use super::*;
2575 use anyhow::anyhow;
2576 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2577 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2578 use indoc::indoc;
2579 use project::{FakeFs, Fs};
2580 use rand::{distr, prelude::*};
2581 use serde_json::json;
2582 use settings::SettingsStore;
2583 use smol::stream::StreamExt as _;
2584 use std::{
2585 any::Any,
2586 cell::RefCell,
2587 path::Path,
2588 rc::Rc,
2589 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2590 time::Duration,
2591 };
2592 use util::path;
2593
2594 fn init_test(cx: &mut TestAppContext) {
2595 env_logger::try_init().ok();
2596 cx.update(|cx| {
2597 let settings_store = SettingsStore::test(cx);
2598 cx.set_global(settings_store);
2599 });
2600 }
2601
2602 #[gpui::test]
2603 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2604 init_test(cx);
2605
2606 let fs = FakeFs::new(cx.executor());
2607 let project = Project::test(fs, [], cx).await;
2608 let connection = Rc::new(FakeAgentConnection::new());
2609 let thread = cx
2610 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2611 .await
2612 .unwrap();
2613
2614 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2615
2616 // Send Output BEFORE Created - should be buffered by acp_thread
2617 thread.update(cx, |thread, cx| {
2618 thread.on_terminal_provider_event(
2619 TerminalProviderEvent::Output {
2620 terminal_id: terminal_id.clone(),
2621 data: b"hello buffered".to_vec(),
2622 },
2623 cx,
2624 );
2625 });
2626
2627 // Create a display-only terminal and then send Created
2628 let lower = cx.new(|cx| {
2629 let builder = ::terminal::TerminalBuilder::new_display_only(
2630 ::terminal::terminal_settings::CursorShape::default(),
2631 ::terminal::terminal_settings::AlternateScroll::On,
2632 None,
2633 0,
2634 cx.background_executor(),
2635 PathStyle::local(),
2636 )
2637 .unwrap();
2638 builder.subscribe(cx)
2639 });
2640
2641 thread.update(cx, |thread, cx| {
2642 thread.on_terminal_provider_event(
2643 TerminalProviderEvent::Created {
2644 terminal_id: terminal_id.clone(),
2645 label: "Buffered Test".to_string(),
2646 cwd: None,
2647 output_byte_limit: None,
2648 terminal: lower.clone(),
2649 },
2650 cx,
2651 );
2652 });
2653
2654 // After Created, buffered Output should have been flushed into the renderer
2655 let content = thread.read_with(cx, |thread, cx| {
2656 let term = thread.terminal(terminal_id.clone()).unwrap();
2657 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2658 });
2659
2660 assert!(
2661 content.contains("hello buffered"),
2662 "expected buffered output to render, got: {content}"
2663 );
2664 }
2665
2666 #[gpui::test]
2667 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2668 init_test(cx);
2669
2670 let fs = FakeFs::new(cx.executor());
2671 let project = Project::test(fs, [], cx).await;
2672 let connection = Rc::new(FakeAgentConnection::new());
2673 let thread = cx
2674 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2675 .await
2676 .unwrap();
2677
2678 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2679
2680 // Send Output BEFORE Created
2681 thread.update(cx, |thread, cx| {
2682 thread.on_terminal_provider_event(
2683 TerminalProviderEvent::Output {
2684 terminal_id: terminal_id.clone(),
2685 data: b"pre-exit data".to_vec(),
2686 },
2687 cx,
2688 );
2689 });
2690
2691 // Send Exit BEFORE Created
2692 thread.update(cx, |thread, cx| {
2693 thread.on_terminal_provider_event(
2694 TerminalProviderEvent::Exit {
2695 terminal_id: terminal_id.clone(),
2696 status: acp::TerminalExitStatus::new().exit_code(0),
2697 },
2698 cx,
2699 );
2700 });
2701
2702 // Now create a display-only lower-level terminal and send Created
2703 let lower = cx.new(|cx| {
2704 let builder = ::terminal::TerminalBuilder::new_display_only(
2705 ::terminal::terminal_settings::CursorShape::default(),
2706 ::terminal::terminal_settings::AlternateScroll::On,
2707 None,
2708 0,
2709 cx.background_executor(),
2710 PathStyle::local(),
2711 )
2712 .unwrap();
2713 builder.subscribe(cx)
2714 });
2715
2716 thread.update(cx, |thread, cx| {
2717 thread.on_terminal_provider_event(
2718 TerminalProviderEvent::Created {
2719 terminal_id: terminal_id.clone(),
2720 label: "Buffered Exit Test".to_string(),
2721 cwd: None,
2722 output_byte_limit: None,
2723 terminal: lower.clone(),
2724 },
2725 cx,
2726 );
2727 });
2728
2729 // Output should be present after Created (flushed from buffer)
2730 let content = thread.read_with(cx, |thread, cx| {
2731 let term = thread.terminal(terminal_id.clone()).unwrap();
2732 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2733 });
2734
2735 assert!(
2736 content.contains("pre-exit data"),
2737 "expected pre-exit data to render, got: {content}"
2738 );
2739 }
2740
2741 /// Test that killing a terminal via Terminal::kill properly:
2742 /// 1. Causes wait_for_exit to complete (doesn't hang forever)
2743 /// 2. The underlying terminal still has the output that was written before the kill
2744 ///
2745 /// This test verifies that the fix to kill_active_task (which now also kills
2746 /// the shell process in addition to the foreground process) properly allows
2747 /// wait_for_exit to complete instead of hanging indefinitely.
2748 #[cfg(unix)]
2749 #[gpui::test]
2750 async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
2751 use std::collections::HashMap;
2752 use task::Shell;
2753 use util::shell_builder::ShellBuilder;
2754
2755 init_test(cx);
2756 cx.executor().allow_parking();
2757
2758 let fs = FakeFs::new(cx.executor());
2759 let project = Project::test(fs, [], cx).await;
2760 let connection = Rc::new(FakeAgentConnection::new());
2761 let thread = cx
2762 .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
2763 .await
2764 .unwrap();
2765
2766 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2767
2768 // Create a real PTY terminal that runs a command which prints output then sleeps
2769 // We use printf instead of echo and chain with && sleep to ensure proper execution
2770 let (completion_tx, _completion_rx) = smol::channel::unbounded();
2771 let (program, args) = ShellBuilder::new(&Shell::System, false).build(
2772 Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
2773 &[],
2774 );
2775
2776 let builder = cx
2777 .update(|cx| {
2778 ::terminal::TerminalBuilder::new(
2779 None,
2780 None,
2781 task::Shell::WithArguments {
2782 program,
2783 args,
2784 title_override: None,
2785 },
2786 HashMap::default(),
2787 ::terminal::terminal_settings::CursorShape::default(),
2788 ::terminal::terminal_settings::AlternateScroll::On,
2789 None,
2790 vec![],
2791 0,
2792 false,
2793 0,
2794 Some(completion_tx),
2795 cx,
2796 vec![],
2797 PathStyle::local(),
2798 )
2799 })
2800 .await
2801 .unwrap();
2802
2803 let lower_terminal = cx.new(|cx| builder.subscribe(cx));
2804
2805 // Create the acp_thread Terminal wrapper
2806 thread.update(cx, |thread, cx| {
2807 thread.on_terminal_provider_event(
2808 TerminalProviderEvent::Created {
2809 terminal_id: terminal_id.clone(),
2810 label: "printf output_before_kill && sleep 60".to_string(),
2811 cwd: None,
2812 output_byte_limit: None,
2813 terminal: lower_terminal.clone(),
2814 },
2815 cx,
2816 );
2817 });
2818
2819 // Wait for the printf command to execute and produce output
2820 // Use real time since parking is enabled
2821 cx.executor().timer(Duration::from_millis(500)).await;
2822
2823 // Get the acp_thread Terminal and kill it
2824 let wait_for_exit = thread.update(cx, |thread, cx| {
2825 let term = thread.terminals.get(&terminal_id).unwrap();
2826 let wait_for_exit = term.read(cx).wait_for_exit();
2827 term.update(cx, |term, cx| {
2828 term.kill(cx);
2829 });
2830 wait_for_exit
2831 });
2832
2833 // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
2834 // Before the fix to kill_active_task, this would hang forever because
2835 // only the foreground process was killed, not the shell, so the PTY
2836 // child never exited and wait_for_completed_task never completed.
2837 let exit_result = futures::select! {
2838 result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
2839 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
2840 };
2841
2842 assert!(
2843 exit_result.is_some(),
2844 "wait_for_exit should complete after kill, but it timed out. \
2845 This indicates kill_active_task is not properly killing the shell process."
2846 );
2847
2848 // Give the system a chance to process any pending updates
2849 cx.run_until_parked();
2850
2851 // Verify that the underlying terminal still has the output that was
2852 // written before the kill. This verifies that killing doesn't lose output.
2853 let inner_content = thread.read_with(cx, |thread, cx| {
2854 let term = thread.terminals.get(&terminal_id).unwrap();
2855 term.read(cx).inner().read(cx).get_content()
2856 });
2857
2858 assert!(
2859 inner_content.contains("output_before_kill"),
2860 "Underlying terminal should contain output from before kill, got: {}",
2861 inner_content
2862 );
2863 }
2864
2865 #[gpui::test]
2866 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2867 init_test(cx);
2868
2869 let fs = FakeFs::new(cx.executor());
2870 let project = Project::test(fs, [], cx).await;
2871 let connection = Rc::new(FakeAgentConnection::new());
2872 let thread = cx
2873 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2874 .await
2875 .unwrap();
2876
2877 // Test creating a new user message
2878 thread.update(cx, |thread, cx| {
2879 thread.push_user_content_block(None, "Hello, ".into(), cx);
2880 });
2881
2882 thread.update(cx, |thread, cx| {
2883 assert_eq!(thread.entries.len(), 1);
2884 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2885 assert_eq!(user_msg.id, None);
2886 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2887 } else {
2888 panic!("Expected UserMessage");
2889 }
2890 });
2891
2892 // Test appending to existing user message
2893 let message_1_id = UserMessageId::new();
2894 thread.update(cx, |thread, cx| {
2895 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2896 });
2897
2898 thread.update(cx, |thread, cx| {
2899 assert_eq!(thread.entries.len(), 1);
2900 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2901 assert_eq!(user_msg.id, Some(message_1_id));
2902 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2903 } else {
2904 panic!("Expected UserMessage");
2905 }
2906 });
2907
2908 // Test creating new user message after assistant message
2909 thread.update(cx, |thread, cx| {
2910 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2911 });
2912
2913 let message_2_id = UserMessageId::new();
2914 thread.update(cx, |thread, cx| {
2915 thread.push_user_content_block(
2916 Some(message_2_id.clone()),
2917 "New user message".into(),
2918 cx,
2919 );
2920 });
2921
2922 thread.update(cx, |thread, cx| {
2923 assert_eq!(thread.entries.len(), 3);
2924 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2925 assert_eq!(user_msg.id, Some(message_2_id));
2926 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2927 } else {
2928 panic!("Expected UserMessage at index 2");
2929 }
2930 });
2931 }
2932
2933 #[gpui::test]
2934 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2935 init_test(cx);
2936
2937 let fs = FakeFs::new(cx.executor());
2938 let project = Project::test(fs, [], cx).await;
2939 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2940 |_, thread, mut cx| {
2941 async move {
2942 thread.update(&mut cx, |thread, cx| {
2943 thread
2944 .handle_session_update(
2945 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2946 "Thinking ".into(),
2947 )),
2948 cx,
2949 )
2950 .unwrap();
2951 thread
2952 .handle_session_update(
2953 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2954 "hard!".into(),
2955 )),
2956 cx,
2957 )
2958 .unwrap();
2959 })?;
2960 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2961 }
2962 .boxed_local()
2963 },
2964 ));
2965
2966 let thread = cx
2967 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2968 .await
2969 .unwrap();
2970
2971 thread
2972 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2973 .await
2974 .unwrap();
2975
2976 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2977 assert_eq!(
2978 output,
2979 indoc! {r#"
2980 ## User
2981
2982 Hello from Zed!
2983
2984 ## Assistant
2985
2986 <thinking>
2987 Thinking hard!
2988 </thinking>
2989
2990 "#}
2991 );
2992 }
2993
2994 #[gpui::test]
2995 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2996 init_test(cx);
2997
2998 let fs = FakeFs::new(cx.executor());
2999 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
3000 .await;
3001 let project = Project::test(fs.clone(), [], cx).await;
3002 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
3003 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
3004 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
3005 move |_, thread, mut cx| {
3006 let read_file_tx = read_file_tx.clone();
3007 async move {
3008 let content = thread
3009 .update(&mut cx, |thread, cx| {
3010 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3011 })
3012 .unwrap()
3013 .await
3014 .unwrap();
3015 assert_eq!(content, "one\ntwo\nthree\n");
3016 read_file_tx.take().unwrap().send(()).unwrap();
3017 thread
3018 .update(&mut cx, |thread, cx| {
3019 thread.write_text_file(
3020 path!("/tmp/foo").into(),
3021 "one\ntwo\nthree\nfour\nfive\n".to_string(),
3022 cx,
3023 )
3024 })
3025 .unwrap()
3026 .await
3027 .unwrap();
3028 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3029 }
3030 .boxed_local()
3031 },
3032 ));
3033
3034 let (worktree, pathbuf) = project
3035 .update(cx, |project, cx| {
3036 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3037 })
3038 .await
3039 .unwrap();
3040 let buffer = project
3041 .update(cx, |project, cx| {
3042 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
3043 })
3044 .await
3045 .unwrap();
3046
3047 let thread = cx
3048 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3049 .await
3050 .unwrap();
3051
3052 let request = thread.update(cx, |thread, cx| {
3053 thread.send_raw("Extend the count in /tmp/foo", cx)
3054 });
3055 read_file_rx.await.ok();
3056 buffer.update(cx, |buffer, cx| {
3057 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
3058 });
3059 cx.run_until_parked();
3060 assert_eq!(
3061 buffer.read_with(cx, |buffer, _| buffer.text()),
3062 "zero\none\ntwo\nthree\nfour\nfive\n"
3063 );
3064 assert_eq!(
3065 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
3066 "zero\none\ntwo\nthree\nfour\nfive\n"
3067 );
3068 request.await.unwrap();
3069 }
3070
3071 #[gpui::test]
3072 async fn test_reading_from_line(cx: &mut TestAppContext) {
3073 init_test(cx);
3074
3075 let fs = FakeFs::new(cx.executor());
3076 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
3077 .await;
3078 let project = Project::test(fs.clone(), [], cx).await;
3079 project
3080 .update(cx, |project, cx| {
3081 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3082 })
3083 .await
3084 .unwrap();
3085
3086 let connection = Rc::new(FakeAgentConnection::new());
3087
3088 let thread = cx
3089 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3090 .await
3091 .unwrap();
3092
3093 // Whole file
3094 let content = thread
3095 .update(cx, |thread, cx| {
3096 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3097 })
3098 .await
3099 .unwrap();
3100
3101 assert_eq!(content, "one\ntwo\nthree\nfour\n");
3102
3103 // Only start line
3104 let content = thread
3105 .update(cx, |thread, cx| {
3106 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
3107 })
3108 .await
3109 .unwrap();
3110
3111 assert_eq!(content, "three\nfour\n");
3112
3113 // Only limit
3114 let content = thread
3115 .update(cx, |thread, cx| {
3116 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3117 })
3118 .await
3119 .unwrap();
3120
3121 assert_eq!(content, "one\ntwo\n");
3122
3123 // Range
3124 let content = thread
3125 .update(cx, |thread, cx| {
3126 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
3127 })
3128 .await
3129 .unwrap();
3130
3131 assert_eq!(content, "two\nthree\n");
3132
3133 // Invalid
3134 let err = thread
3135 .update(cx, |thread, cx| {
3136 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
3137 })
3138 .await
3139 .unwrap_err();
3140
3141 assert_eq!(
3142 err.to_string(),
3143 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
3144 );
3145 }
3146
3147 #[gpui::test]
3148 async fn test_reading_empty_file(cx: &mut TestAppContext) {
3149 init_test(cx);
3150
3151 let fs = FakeFs::new(cx.executor());
3152 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
3153 let project = Project::test(fs.clone(), [], cx).await;
3154 project
3155 .update(cx, |project, cx| {
3156 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3157 })
3158 .await
3159 .unwrap();
3160
3161 let connection = Rc::new(FakeAgentConnection::new());
3162
3163 let thread = cx
3164 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3165 .await
3166 .unwrap();
3167
3168 // Whole file
3169 let content = thread
3170 .update(cx, |thread, cx| {
3171 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3172 })
3173 .await
3174 .unwrap();
3175
3176 assert_eq!(content, "");
3177
3178 // Only start line
3179 let content = thread
3180 .update(cx, |thread, cx| {
3181 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
3182 })
3183 .await
3184 .unwrap();
3185
3186 assert_eq!(content, "");
3187
3188 // Only limit
3189 let content = thread
3190 .update(cx, |thread, cx| {
3191 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3192 })
3193 .await
3194 .unwrap();
3195
3196 assert_eq!(content, "");
3197
3198 // Range
3199 let content = thread
3200 .update(cx, |thread, cx| {
3201 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3202 })
3203 .await
3204 .unwrap();
3205
3206 assert_eq!(content, "");
3207
3208 // Invalid
3209 let err = thread
3210 .update(cx, |thread, cx| {
3211 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3212 })
3213 .await
3214 .unwrap_err();
3215
3216 assert_eq!(
3217 err.to_string(),
3218 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3219 );
3220 }
3221 #[gpui::test]
3222 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3223 init_test(cx);
3224
3225 let fs = FakeFs::new(cx.executor());
3226 fs.insert_tree(path!("/tmp"), json!({})).await;
3227 let project = Project::test(fs.clone(), [], cx).await;
3228 project
3229 .update(cx, |project, cx| {
3230 project.find_or_create_worktree(path!("/tmp"), true, cx)
3231 })
3232 .await
3233 .unwrap();
3234
3235 let connection = Rc::new(FakeAgentConnection::new());
3236
3237 let thread = cx
3238 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3239 .await
3240 .unwrap();
3241
3242 // Out of project file
3243 let err = thread
3244 .update(cx, |thread, cx| {
3245 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3246 })
3247 .await
3248 .unwrap_err();
3249
3250 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3251 }
3252
3253 #[gpui::test]
3254 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3255 init_test(cx);
3256
3257 let fs = FakeFs::new(cx.executor());
3258 let project = Project::test(fs, [], cx).await;
3259 let id = acp::ToolCallId::new("test");
3260
3261 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3262 let id = id.clone();
3263 move |_, thread, mut cx| {
3264 let id = id.clone();
3265 async move {
3266 thread
3267 .update(&mut cx, |thread, cx| {
3268 thread.handle_session_update(
3269 acp::SessionUpdate::ToolCall(
3270 acp::ToolCall::new(id.clone(), "Label")
3271 .kind(acp::ToolKind::Fetch)
3272 .status(acp::ToolCallStatus::InProgress),
3273 ),
3274 cx,
3275 )
3276 })
3277 .unwrap()
3278 .unwrap();
3279 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3280 }
3281 .boxed_local()
3282 }
3283 }));
3284
3285 let thread = cx
3286 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3287 .await
3288 .unwrap();
3289
3290 let request = thread.update(cx, |thread, cx| {
3291 thread.send_raw("Fetch https://example.com", cx)
3292 });
3293
3294 run_until_first_tool_call(&thread, cx).await;
3295
3296 thread.read_with(cx, |thread, _| {
3297 assert!(matches!(
3298 thread.entries[1],
3299 AgentThreadEntry::ToolCall(ToolCall {
3300 status: ToolCallStatus::InProgress,
3301 ..
3302 })
3303 ));
3304 });
3305
3306 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3307
3308 thread.read_with(cx, |thread, _| {
3309 assert!(matches!(
3310 &thread.entries[1],
3311 AgentThreadEntry::ToolCall(ToolCall {
3312 status: ToolCallStatus::Canceled,
3313 ..
3314 })
3315 ));
3316 });
3317
3318 thread
3319 .update(cx, |thread, cx| {
3320 thread.handle_session_update(
3321 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3322 id,
3323 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3324 )),
3325 cx,
3326 )
3327 })
3328 .unwrap();
3329
3330 request.await.unwrap();
3331
3332 thread.read_with(cx, |thread, _| {
3333 assert!(matches!(
3334 thread.entries[1],
3335 AgentThreadEntry::ToolCall(ToolCall {
3336 status: ToolCallStatus::Completed,
3337 ..
3338 })
3339 ));
3340 });
3341 }
3342
3343 #[gpui::test]
3344 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3345 init_test(cx);
3346 let fs = FakeFs::new(cx.background_executor.clone());
3347 fs.insert_tree(path!("/test"), json!({})).await;
3348 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3349
3350 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3351 move |_, thread, mut cx| {
3352 async move {
3353 thread
3354 .update(&mut cx, |thread, cx| {
3355 thread.handle_session_update(
3356 acp::SessionUpdate::ToolCall(
3357 acp::ToolCall::new("test", "Label")
3358 .kind(acp::ToolKind::Edit)
3359 .status(acp::ToolCallStatus::Completed)
3360 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3361 "/test/test.txt",
3362 "foo",
3363 ))]),
3364 ),
3365 cx,
3366 )
3367 })
3368 .unwrap()
3369 .unwrap();
3370 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3371 }
3372 .boxed_local()
3373 }
3374 }));
3375
3376 let thread = cx
3377 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3378 .await
3379 .unwrap();
3380
3381 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3382 .await
3383 .unwrap();
3384
3385 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3386 }
3387
    /// Exercises checkpoint bookkeeping across several turns:
    /// - a user message is labeled "(checkpoint)" in the markdown when the turn that
    ///   followed it changed files on disk;
    /// - a turn that changes nothing leaves its user message without a checkpoint;
    /// - `restore_checkpoint` truncates the thread at the chosen user message and
    ///   reverts the files created from that point on.
    ///
    /// NOTE(review): `iterations = 10` re-runs the test several times, presumably with
    /// different scheduling seeds; confirm against the `gpui::test` macro.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // The `.git` directory is presumably what lets the git store take checkpoints.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each turn writes a new empty file
        // `/test/file-N` (N from `next_filename`). Every turn replies with the prompt
        // text uppercased.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // NOTE(review): this path is built without `path!`, unlike the
                        // expectations below; confirm it matches on Windows.
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn changes the file system, so its user message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn also changes files and is checkpointed as well.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" user message (entries: Lorem, LOREM, ipsum, ...),
        // so only the first exchange and `file-0` should survive.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3565
3566 #[gpui::test]
3567 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3568 use std::sync::atomic::AtomicUsize;
3569 init_test(cx);
3570
3571 let fs = FakeFs::new(cx.executor());
3572 let project = Project::test(fs, None, cx).await;
3573
3574 // Create a connection that simulates refusal after tool result
3575 let prompt_count = Arc::new(AtomicUsize::new(0));
3576 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3577 let prompt_count = prompt_count.clone();
3578 move |_request, thread, mut cx| {
3579 let count = prompt_count.fetch_add(1, SeqCst);
3580 async move {
3581 if count == 0 {
3582 // First prompt: Generate a tool call with result
3583 thread.update(&mut cx, |thread, cx| {
3584 thread
3585 .handle_session_update(
3586 acp::SessionUpdate::ToolCall(
3587 acp::ToolCall::new("tool1", "Test Tool")
3588 .kind(acp::ToolKind::Fetch)
3589 .status(acp::ToolCallStatus::Completed)
3590 .raw_input(serde_json::json!({"query": "test"}))
3591 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3592 ),
3593 cx,
3594 )
3595 .unwrap();
3596 })?;
3597
3598 // Now return refusal because of the tool result
3599 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3600 } else {
3601 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3602 }
3603 }
3604 .boxed_local()
3605 }
3606 }));
3607
3608 let thread = cx
3609 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3610 .await
3611 .unwrap();
3612
3613 // Track if we see a Refusal event
3614 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3615 let saw_refusal_event_captured = saw_refusal_event.clone();
3616 thread.update(cx, |_thread, cx| {
3617 cx.subscribe(
3618 &thread,
3619 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3620 if matches!(event, AcpThreadEvent::Refusal) {
3621 *saw_refusal_event_captured.lock().unwrap() = true;
3622 }
3623 },
3624 )
3625 .detach();
3626 });
3627
3628 // Send a user message - this will trigger tool call and then refusal
3629 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3630 cx.background_executor.spawn(send_task).detach();
3631 cx.run_until_parked();
3632
3633 // Verify that:
3634 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3635 // 2. The user message was NOT truncated
3636 assert!(
3637 *saw_refusal_event.lock().unwrap(),
3638 "Refusal event should be emitted for tool result refusals"
3639 );
3640
3641 thread.read_with(cx, |thread, _| {
3642 let entries = thread.entries();
3643 assert!(entries.len() >= 2, "Should have user message and tool call");
3644
3645 // Verify user message is still there
3646 assert!(
3647 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3648 "User message should not be truncated"
3649 );
3650
3651 // Verify tool call is there with result
3652 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3653 assert!(
3654 tool_call.raw_output.is_some(),
3655 "Tool call should have output"
3656 );
3657 } else {
3658 panic!("Expected tool call at index 1");
3659 }
3660 });
3661 }
3662
3663 #[gpui::test]
3664 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3665 init_test(cx);
3666
3667 let fs = FakeFs::new(cx.executor());
3668 let project = Project::test(fs, None, cx).await;
3669
3670 let refuse_next = Arc::new(AtomicBool::new(false));
3671 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3672 let refuse_next = refuse_next.clone();
3673 move |_request, _thread, _cx| {
3674 if refuse_next.load(SeqCst) {
3675 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3676 .boxed_local()
3677 } else {
3678 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3679 .boxed_local()
3680 }
3681 }
3682 }));
3683
3684 let thread = cx
3685 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3686 .await
3687 .unwrap();
3688
3689 // Track if we see a Refusal event
3690 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3691 let saw_refusal_event_captured = saw_refusal_event.clone();
3692 thread.update(cx, |_thread, cx| {
3693 cx.subscribe(
3694 &thread,
3695 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3696 if matches!(event, AcpThreadEvent::Refusal) {
3697 *saw_refusal_event_captured.lock().unwrap() = true;
3698 }
3699 },
3700 )
3701 .detach();
3702 });
3703
3704 // Send a message that will be refused
3705 refuse_next.store(true, SeqCst);
3706 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3707 .await
3708 .unwrap();
3709
3710 // Verify that a Refusal event WAS emitted for user prompt refusal
3711 assert!(
3712 *saw_refusal_event.lock().unwrap(),
3713 "Refusal event should be emitted for user prompt refusals"
3714 );
3715
3716 // Verify the message was truncated (user prompt refusal)
3717 thread.read_with(cx, |thread, cx| {
3718 assert_eq!(thread.to_markdown(cx), "");
3719 });
3720 }
3721
    /// A refused prompt is removed from the thread. After a successful "hello" turn,
    /// a refused "world" turn leaves only the "hello" exchange in the markdown.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While `refuse_next` is false the fake agent replies with the prompt text
        // uppercased. Once it is set, every prompt is refused without any output.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        // The cwd passed here differs from the project root ("/"); the fake
        // connection ignores its `_cwd` argument.
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
3803
3804 async fn run_until_first_tool_call(
3805 thread: &Entity<AcpThread>,
3806 cx: &mut TestAppContext,
3807 ) -> usize {
3808 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3809
3810 let subscription = cx.update(|cx| {
3811 cx.subscribe(thread, move |thread, _, cx| {
3812 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3813 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3814 return tx.try_send(ix).unwrap();
3815 }
3816 }
3817 })
3818 });
3819
3820 select! {
3821 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3822 panic!("Timeout waiting for tool call")
3823 }
3824 ix = rx.next().fuse() => {
3825 drop(subscription);
3826 ix.unwrap()
3827 }
3828 }
3829 }
3830
3831 #[derive(Clone, Default)]
3832 struct FakeAgentConnection {
3833 auth_methods: Vec<acp::AuthMethod>,
3834 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3835 on_user_message: Option<
3836 Rc<
3837 dyn Fn(
3838 acp::PromptRequest,
3839 WeakEntity<AcpThread>,
3840 AsyncApp,
3841 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3842 + 'static,
3843 >,
3844 >,
3845 }
3846
3847 impl FakeAgentConnection {
3848 fn new() -> Self {
3849 Self {
3850 auth_methods: Vec::new(),
3851 on_user_message: None,
3852 sessions: Arc::default(),
3853 }
3854 }
3855
3856 #[expect(unused)]
3857 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3858 self.auth_methods = auth_methods;
3859 self
3860 }
3861
3862 fn on_user_message(
3863 mut self,
3864 handler: impl Fn(
3865 acp::PromptRequest,
3866 WeakEntity<AcpThread>,
3867 AsyncApp,
3868 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3869 + 'static,
3870 ) -> Self {
3871 self.on_user_message.replace(Rc::new(handler));
3872 self
3873 }
3874 }
3875
3876 impl AgentConnection for FakeAgentConnection {
3877 fn telemetry_id(&self) -> SharedString {
3878 "fake".into()
3879 }
3880
3881 fn auth_methods(&self) -> &[acp::AuthMethod] {
3882 &self.auth_methods
3883 }
3884
3885 fn new_session(
3886 self: Rc<Self>,
3887 project: Entity<Project>,
3888 _cwd: &Path,
3889 cx: &mut App,
3890 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3891 let session_id = acp::SessionId::new(
3892 rand::rng()
3893 .sample_iter(&distr::Alphanumeric)
3894 .take(7)
3895 .map(char::from)
3896 .collect::<String>(),
3897 );
3898 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3899 let thread = cx.new(|cx| {
3900 AcpThread::new(
3901 None,
3902 "Test",
3903 self.clone(),
3904 project,
3905 action_log,
3906 session_id.clone(),
3907 watch::Receiver::constant(
3908 acp::PromptCapabilities::new()
3909 .image(true)
3910 .audio(true)
3911 .embedded_context(true),
3912 ),
3913 cx,
3914 )
3915 });
3916 self.sessions.lock().insert(session_id, thread.downgrade());
3917 Task::ready(Ok(thread))
3918 }
3919
3920 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3921 if self.auth_methods().iter().any(|m| m.id == method) {
3922 Task::ready(Ok(()))
3923 } else {
3924 Task::ready(Err(anyhow!("Invalid Auth Method")))
3925 }
3926 }
3927
3928 fn prompt(
3929 &self,
3930 _id: Option<UserMessageId>,
3931 params: acp::PromptRequest,
3932 cx: &mut App,
3933 ) -> Task<gpui::Result<acp::PromptResponse>> {
3934 let sessions = self.sessions.lock();
3935 let thread = sessions.get(¶ms.session_id).unwrap();
3936 if let Some(handler) = &self.on_user_message {
3937 let handler = handler.clone();
3938 let thread = thread.clone();
3939 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3940 } else {
3941 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3942 }
3943 }
3944
3945 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3946 let sessions = self.sessions.lock();
3947 let thread = sessions.get(session_id).unwrap().clone();
3948
3949 cx.spawn(async move |cx| {
3950 thread
3951 .update(cx, |thread, cx| thread.cancel(cx))
3952 .unwrap()
3953 .await
3954 })
3955 .detach();
3956 }
3957
3958 fn truncate(
3959 &self,
3960 session_id: &acp::SessionId,
3961 _cx: &App,
3962 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3963 Some(Rc::new(FakeAgentSessionEditor {
3964 _session_id: session_id.clone(),
3965 }))
3966 }
3967
3968 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3969 self
3970 }
3971 }
3972
    /// Truncation handle returned by `FakeAgentConnection::truncate`.
    /// It stores the session id without using it.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// No-op: immediately reports success for any message id.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3982
3983 #[gpui::test]
3984 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3985 init_test(cx);
3986
3987 let fs = FakeFs::new(cx.executor());
3988 let project = Project::test(fs, [], cx).await;
3989 let connection = Rc::new(FakeAgentConnection::new());
3990 let thread = cx
3991 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3992 .await
3993 .unwrap();
3994
3995 // Try to update a tool call that doesn't exist
3996 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3997 thread.update(cx, |thread, cx| {
3998 let result = thread.handle_session_update(
3999 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4000 nonexistent_id.clone(),
4001 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4002 )),
4003 cx,
4004 );
4005
4006 // The update should succeed (not return an error)
4007 assert!(result.is_ok());
4008
4009 // There should now be exactly one entry in the thread
4010 assert_eq!(thread.entries.len(), 1);
4011
4012 // The entry should be a failed tool call
4013 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4014 assert_eq!(tool_call.id, nonexistent_id);
4015 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4016 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4017
4018 // Check that the content contains the error message
4019 assert_eq!(tool_call.content.len(), 1);
4020 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4021 match content_block {
4022 ContentBlock::Markdown { markdown } => {
4023 let markdown_text = markdown.read(cx).source();
4024 assert!(markdown_text.contains("Tool call not found"));
4025 }
4026 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4027 ContentBlock::ResourceLink { .. } => {
4028 panic!("Expected markdown content, got resource link")
4029 }
4030 ContentBlock::Image { .. } => {
4031 panic!("Expected markdown content, got image")
4032 }
4033 }
4034 } else {
4035 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4036 }
4037 } else {
4038 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4039 }
4040 });
4041 }
4042
    /// Tests that restoring a checkpoint removes a terminal that is still running,
    /// keeps terminals that have already exited, truncates the thread at the restored
    /// message, and leaves no `send_task` behind.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// NOTE(review): both sends below are awaited before the restore, so no generation
    /// is actually in flight here. The `send_task` assertion only checks the idle case.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that run to completion (Created -> Output -> Exit).
        // NOTE(review): like the running terminal below, these are registered after the
        // second message was sent. What distinguishes them in this test is that they have
        // exited, not that they predate the checkpoint.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a third terminal that never receives an Exit event, so it stays running.
        // This simulates the AI agent starting a long-running terminal command.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry (index 2) that references this terminal.
        // This represents the agent requesting a terminal command.
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (via the cancel() call)
        // 2. Remove the terminal that is still running
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (cancel() clears the send_task)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 exited terminals still exist. (The assertion messages call them
        // "from before checkpoint", but see the NOTE above about their actual timing.)
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the still-running terminal was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify total count is 2 (the two exited terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4344
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // `handler_done` is set synchronously when the handler is invoked (before its
        // future is polled), which lets the test detect that the send reached the agent.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Run the send as a spawned task so the test can interleave work with it.
        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until the handler has been invoked, then a few more to let
        // update_last_checkpoint start.
        // NOTE(review): the fixed count of 5 extra ticks depends on the current scheduling
        // inside `send_raw`; if that changes, this test may stop hitting the race.
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message with no checkpoint while the first send's checkpoint
        // update is presumably still pending.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4407}