1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Meta key under which an ACP `ToolCall` carries the tool's programmatic name.
///
/// ACP's `ToolCall` has no dedicated name field, so the name travels in `meta`.
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Meta key under which an ACP `ToolCall` carries the session id of a spawned subagent.
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
67#[derive(Debug)]
68pub struct UserMessage {
69 pub id: Option<UserMessageId>,
70 pub content: ContentBlock,
71 pub chunks: Vec<acp::ContentBlock>,
72 pub checkpoint: Option<Checkpoint>,
73 pub indented: bool,
74}
75
76#[derive(Debug)]
77pub struct Checkpoint {
78 git_checkpoint: GitStoreCheckpoint,
79 pub show: bool,
80}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
101#[derive(Debug, PartialEq)]
102pub struct AssistantMessage {
103 pub chunks: Vec<AssistantMessageChunk>,
104 pub indented: bool,
105}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
119#[derive(Debug, PartialEq)]
120pub enum AssistantMessageChunk {
121 Message { block: ContentBlock },
122 Thought { block: ContentBlock },
123}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
147#[derive(Debug)]
148pub enum AgentThreadEntry {
149 UserMessage(UserMessage),
150 AssistantMessage(AssistantMessage),
151 ToolCall(ToolCall),
152}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
212#[derive(Debug)]
213pub struct ToolCall {
214 pub id: acp::ToolCallId,
215 pub label: Entity<Markdown>,
216 pub kind: acp::ToolKind,
217 pub content: Vec<ToolCallContent>,
218 pub status: ToolCallStatus,
219 pub locations: Vec<acp::ToolCallLocation>,
220 pub resolved_locations: Vec<Option<AgentLocation>>,
221 pub raw_input: Option<serde_json::Value>,
222 pub raw_input_markdown: Option<Entity<Markdown>>,
223 pub raw_output: Option<serde_json::Value>,
224 pub tool_name: Option<SharedString>,
225 pub subagent_session_id: Option<acp::SessionId>,
226}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 self.label.update(cx, |label, cx| {
318 if self.kind == acp::ToolKind::Execute {
319 label.replace(title, cx);
320 } else if let Some((first_line, _)) = title.split_once("\n") {
321 label.replace(first_line.to_owned() + "…", cx);
322 } else {
323 label.replace(title, cx);
324 }
325 });
326 }
327
328 if let Some(content) = content {
329 let mut new_content_len = content.len();
330 let mut content = content.into_iter();
331
332 // Reuse existing content if we can
333 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
334 let valid_content =
335 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
336 if !valid_content {
337 new_content_len -= 1;
338 }
339 }
340 for new in content {
341 if let Some(new) = ToolCallContent::from_acp(
342 new,
343 language_registry.clone(),
344 path_style,
345 terminals,
346 cx,
347 )? {
348 self.content.push(new);
349 } else {
350 new_content_len -= 1;
351 }
352 }
353 self.content.truncate(new_content_len);
354 }
355
356 if let Some(locations) = locations {
357 self.locations = locations;
358 }
359
360 if let Some(raw_input) = raw_input {
361 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
362 self.raw_input = Some(raw_input);
363 }
364
365 if let Some(raw_output) = raw_output {
366 if self.content.is_empty()
367 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
368 {
369 self.content
370 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
371 markdown,
372 }));
373 }
374 self.raw_output = Some(raw_output);
375 }
376 Ok(())
377 }
378
379 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
380 self.content.iter().filter_map(|content| match content {
381 ToolCallContent::Diff(diff) => Some(diff),
382 ToolCallContent::ContentBlock(_) => None,
383 ToolCallContent::Terminal(_) => None,
384 })
385 }
386
387 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
388 self.content.iter().filter_map(|content| match content {
389 ToolCallContent::Terminal(terminal) => Some(terminal),
390 ToolCallContent::ContentBlock(_) => None,
391 ToolCallContent::Diff(_) => None,
392 })
393 }
394
395 pub fn is_subagent(&self) -> bool {
396 self.tool_name.as_ref().is_some_and(|s| s == "subagent")
397 || self.subagent_session_id.is_some()
398 }
399
400 pub fn to_markdown(&self, cx: &App) -> String {
401 let mut markdown = format!(
402 "**Tool Call: {}**\nStatus: {}\n\n",
403 self.label.read(cx).source(),
404 self.status
405 );
406 for content in &self.content {
407 markdown.push_str(content.to_markdown(cx).as_str());
408 markdown.push_str("\n\n");
409 }
410 markdown
411 }
412
413 async fn resolve_location(
414 location: acp::ToolCallLocation,
415 project: WeakEntity<Project>,
416 cx: &mut AsyncApp,
417 ) -> Option<ResolvedLocation> {
418 let buffer = project
419 .update(cx, |project, cx| {
420 project
421 .project_path_for_absolute_path(&location.path, cx)
422 .map(|path| project.open_buffer(path, cx))
423 })
424 .ok()??;
425 let buffer = buffer.await.log_err()?;
426 let position = buffer.update(cx, |buffer, _| {
427 let snapshot = buffer.snapshot();
428 if let Some(row) = location.line {
429 let column = snapshot.indent_size_for_line(row).len;
430 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
431 snapshot.anchor_before(point)
432 } else {
433 Anchor::min_for_buffer(snapshot.remote_id())
434 }
435 });
436
437 Some(ResolvedLocation { buffer, position })
438 }
439
440 fn resolve_locations(
441 &self,
442 project: Entity<Project>,
443 cx: &mut App,
444 ) -> Task<Vec<Option<ResolvedLocation>>> {
445 let locations = self.locations.clone();
446 project.update(cx, |_, cx| {
447 cx.spawn(async move |project, cx| {
448 let mut new_locations = Vec::new();
449 for location in locations {
450 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
451 }
452 new_locations
453 })
454 })
455 }
456}
457
458// Separate so we can hold a strong reference to the buffer
459// for saving on the thread
460#[derive(Clone, Debug, PartialEq, Eq)]
461struct ResolvedLocation {
462 buffer: Entity<Buffer>,
463 position: Anchor,
464}
465
466impl From<&ResolvedLocation> for AgentLocation {
467 fn from(value: &ResolvedLocation) -> Self {
468 Self {
469 buffer: value.buffer.downgrade(),
470 position: value.position,
471 }
472 }
473}
474
475#[derive(Debug)]
476pub enum ToolCallStatus {
477 /// The tool call hasn't started running yet, but we start showing it to
478 /// the user.
479 Pending,
480 /// The tool call is waiting for confirmation from the user.
481 WaitingForConfirmation {
482 options: PermissionOptions,
483 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
484 },
485 /// The tool call is currently running.
486 InProgress,
487 /// The tool call completed successfully.
488 Completed,
489 /// The tool call failed.
490 Failed,
491 /// The user rejected the tool call.
492 Rejected,
493 /// The user canceled generation so the tool call was canceled.
494 Canceled,
495}
496
497impl From<acp::ToolCallStatus> for ToolCallStatus {
498 fn from(status: acp::ToolCallStatus) -> Self {
499 match status {
500 acp::ToolCallStatus::Pending => Self::Pending,
501 acp::ToolCallStatus::InProgress => Self::InProgress,
502 acp::ToolCallStatus::Completed => Self::Completed,
503 acp::ToolCallStatus::Failed => Self::Failed,
504 _ => Self::Pending,
505 }
506 }
507}
508
509impl Display for ToolCallStatus {
510 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
511 write!(
512 f,
513 "{}",
514 match self {
515 ToolCallStatus::Pending => "Pending",
516 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
517 ToolCallStatus::InProgress => "In Progress",
518 ToolCallStatus::Completed => "Completed",
519 ToolCallStatus::Failed => "Failed",
520 ToolCallStatus::Rejected => "Rejected",
521 ToolCallStatus::Canceled => "Canceled",
522 }
523 )
524 }
525}
526
527#[derive(Debug, PartialEq, Clone)]
528pub enum ContentBlock {
529 Empty,
530 Markdown { markdown: Entity<Markdown> },
531 ResourceLink { resource_link: acp::ResourceLink },
532 Image { image: Arc<gpui::Image> },
533}
534
535impl ContentBlock {
536 pub fn new(
537 block: acp::ContentBlock,
538 language_registry: &Arc<LanguageRegistry>,
539 path_style: PathStyle,
540 cx: &mut App,
541 ) -> Self {
542 let mut this = Self::Empty;
543 this.append(block, language_registry, path_style, cx);
544 this
545 }
546
547 pub fn new_combined(
548 blocks: impl IntoIterator<Item = acp::ContentBlock>,
549 language_registry: Arc<LanguageRegistry>,
550 path_style: PathStyle,
551 cx: &mut App,
552 ) -> Self {
553 let mut this = Self::Empty;
554 for block in blocks {
555 this.append(block, &language_registry, path_style, cx);
556 }
557 this
558 }
559
560 pub fn append(
561 &mut self,
562 block: acp::ContentBlock,
563 language_registry: &Arc<LanguageRegistry>,
564 path_style: PathStyle,
565 cx: &mut App,
566 ) {
567 match (&mut *self, &block) {
568 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
569 *self = ContentBlock::ResourceLink {
570 resource_link: resource_link.clone(),
571 };
572 }
573 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
574 if let Some(image) = Self::decode_image(image_content) {
575 *self = ContentBlock::Image { image };
576 } else {
577 let new_content = Self::image_md(image_content);
578 *self = Self::create_markdown_block(new_content, language_registry, cx);
579 }
580 }
581 (ContentBlock::Empty, _) => {
582 let new_content = Self::block_string_contents(&block, path_style);
583 *self = Self::create_markdown_block(new_content, language_registry, cx);
584 }
585 (ContentBlock::Markdown { markdown }, _) => {
586 let new_content = Self::block_string_contents(&block, path_style);
587 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
588 }
589 (ContentBlock::ResourceLink { resource_link }, _) => {
590 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
591 let new_content = Self::block_string_contents(&block, path_style);
592 let combined = format!("{}\n{}", existing_content, new_content);
593 *self = Self::create_markdown_block(combined, language_registry, cx);
594 }
595 (ContentBlock::Image { .. }, _) => {
596 let new_content = Self::block_string_contents(&block, path_style);
597 let combined = format!("`Image`\n{}", new_content);
598 *self = Self::create_markdown_block(combined, language_registry, cx);
599 }
600 }
601 }
602
603 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
604 use base64::Engine as _;
605
606 let bytes = base64::engine::general_purpose::STANDARD
607 .decode(image_content.data.as_bytes())
608 .ok()?;
609 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
610 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
611 }
612
613 fn create_markdown_block(
614 content: String,
615 language_registry: &Arc<LanguageRegistry>,
616 cx: &mut App,
617 ) -> ContentBlock {
618 ContentBlock::Markdown {
619 markdown: cx
620 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
621 }
622 }
623
624 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
625 match block {
626 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
627 acp::ContentBlock::ResourceLink(resource_link) => {
628 Self::resource_link_md(&resource_link.uri, path_style)
629 }
630 acp::ContentBlock::Resource(acp::EmbeddedResource {
631 resource:
632 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
633 uri,
634 ..
635 }),
636 ..
637 }) => Self::resource_link_md(uri, path_style),
638 acp::ContentBlock::Image(image) => Self::image_md(image),
639 _ => String::new(),
640 }
641 }
642
643 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
644 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
645 uri.as_link().to_string()
646 } else {
647 uri.to_string()
648 }
649 }
650
651 fn image_md(_image: &acp::ImageContent) -> String {
652 "`Image`".into()
653 }
654
655 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
656 match self {
657 ContentBlock::Empty => "",
658 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
659 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
660 ContentBlock::Image { .. } => "`Image`",
661 }
662 }
663
664 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
665 match self {
666 ContentBlock::Empty => None,
667 ContentBlock::Markdown { markdown } => Some(markdown),
668 ContentBlock::ResourceLink { .. } => None,
669 ContentBlock::Image { .. } => None,
670 }
671 }
672
673 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
674 match self {
675 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
676 _ => None,
677 }
678 }
679
680 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
681 match self {
682 ContentBlock::Image { image } => Some(image),
683 _ => None,
684 }
685 }
686}
687
688#[derive(Debug)]
689pub enum ToolCallContent {
690 ContentBlock(ContentBlock),
691 Diff(Entity<Diff>),
692 Terminal(Entity<Terminal>),
693}
694
695impl ToolCallContent {
696 pub fn from_acp(
697 content: acp::ToolCallContent,
698 language_registry: Arc<LanguageRegistry>,
699 path_style: PathStyle,
700 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
701 cx: &mut App,
702 ) -> Result<Option<Self>> {
703 match content {
704 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
705 Ok(Some(Self::ContentBlock(ContentBlock::new(
706 content,
707 &language_registry,
708 path_style,
709 cx,
710 ))))
711 }
712 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
713 Diff::finalized(
714 diff.path.to_string_lossy().into_owned(),
715 diff.old_text,
716 diff.new_text,
717 language_registry,
718 cx,
719 )
720 })))),
721 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
722 .get(&terminal_id)
723 .cloned()
724 .map(|terminal| Some(Self::Terminal(terminal)))
725 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
726 _ => Ok(None),
727 }
728 }
729
730 pub fn update_from_acp(
731 &mut self,
732 new: acp::ToolCallContent,
733 language_registry: Arc<LanguageRegistry>,
734 path_style: PathStyle,
735 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
736 cx: &mut App,
737 ) -> Result<bool> {
738 let needs_update = match (&self, &new) {
739 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
740 old_diff.read(cx).needs_update(
741 new_diff.old_text.as_deref().unwrap_or(""),
742 &new_diff.new_text,
743 cx,
744 )
745 }
746 _ => true,
747 };
748
749 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
750 if needs_update {
751 *self = update;
752 }
753 Ok(true)
754 } else {
755 Ok(false)
756 }
757 }
758
759 pub fn to_markdown(&self, cx: &App) -> String {
760 match self {
761 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
762 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
763 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
764 }
765 }
766
767 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
768 match self {
769 Self::ContentBlock(content) => content.image(),
770 _ => None,
771 }
772 }
773}
774
775#[derive(Debug, PartialEq)]
776pub enum ToolCallUpdate {
777 UpdateFields(acp::ToolCallUpdate),
778 UpdateDiff(ToolCallUpdateDiff),
779 UpdateTerminal(ToolCallUpdateTerminal),
780}
781
782impl ToolCallUpdate {
783 fn id(&self) -> &acp::ToolCallId {
784 match self {
785 Self::UpdateFields(update) => &update.tool_call_id,
786 Self::UpdateDiff(diff) => &diff.id,
787 Self::UpdateTerminal(terminal) => &terminal.id,
788 }
789 }
790}
791
792impl From<acp::ToolCallUpdate> for ToolCallUpdate {
793 fn from(update: acp::ToolCallUpdate) -> Self {
794 Self::UpdateFields(update)
795 }
796}
797
798impl From<ToolCallUpdateDiff> for ToolCallUpdate {
799 fn from(diff: ToolCallUpdateDiff) -> Self {
800 Self::UpdateDiff(diff)
801 }
802}
803
804#[derive(Debug, PartialEq)]
805pub struct ToolCallUpdateDiff {
806 pub id: acp::ToolCallId,
807 pub diff: Entity<Diff>,
808}
809
810impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
811 fn from(terminal: ToolCallUpdateTerminal) -> Self {
812 Self::UpdateTerminal(terminal)
813 }
814}
815
816#[derive(Debug, PartialEq)]
817pub struct ToolCallUpdateTerminal {
818 pub id: acp::ToolCallId,
819 pub terminal: Entity<Terminal>,
820}
821
822#[derive(Debug, Default)]
823pub struct Plan {
824 pub entries: Vec<PlanEntry>,
825}
826
827#[derive(Debug)]
828pub struct PlanStats<'a> {
829 pub in_progress_entry: Option<&'a PlanEntry>,
830 pub pending: u32,
831 pub completed: u32,
832}
833
834impl Plan {
835 pub fn is_empty(&self) -> bool {
836 self.entries.is_empty()
837 }
838
839 pub fn stats(&self) -> PlanStats<'_> {
840 let mut stats = PlanStats {
841 in_progress_entry: None,
842 pending: 0,
843 completed: 0,
844 };
845
846 for entry in &self.entries {
847 match &entry.status {
848 acp::PlanEntryStatus::Pending => {
849 stats.pending += 1;
850 }
851 acp::PlanEntryStatus::InProgress => {
852 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
853 }
854 acp::PlanEntryStatus::Completed => {
855 stats.completed += 1;
856 }
857 _ => {}
858 }
859 }
860
861 stats
862 }
863}
864
865#[derive(Debug)]
866pub struct PlanEntry {
867 pub content: Entity<Markdown>,
868 pub priority: acp::PlanEntryPriority,
869 pub status: acp::PlanEntryStatus,
870}
871
872impl PlanEntry {
873 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
874 Self {
875 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
876 priority: entry.priority,
877 status: entry.status,
878 }
879 }
880}
881
882#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
883pub struct TokenUsage {
884 pub max_tokens: u64,
885 pub used_tokens: u64,
886 pub input_tokens: u64,
887 pub output_tokens: u64,
888 pub max_output_tokens: Option<u64>,
889}
890
891impl TokenUsage {
892 pub fn ratio(&self) -> TokenUsageRatio {
893 #[cfg(debug_assertions)]
894 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
895 .unwrap_or("0.8".to_string())
896 .parse()
897 .unwrap();
898 #[cfg(not(debug_assertions))]
899 let warning_threshold: f32 = 0.8;
900
901 // When the maximum is unknown because there is no selected model,
902 // avoid showing the token limit warning.
903 if self.max_tokens == 0 {
904 TokenUsageRatio::Normal
905 } else if self.used_tokens >= self.max_tokens {
906 TokenUsageRatio::Exceeded
907 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
908 TokenUsageRatio::Warning
909 } else {
910 TokenUsageRatio::Normal
911 }
912 }
913}
914
915#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
916pub enum TokenUsageRatio {
917 Normal,
918 Warning,
919 Exceeded,
920}
921
922#[derive(Debug, Clone)]
923pub struct RetryStatus {
924 pub last_error: SharedString,
925 pub attempt: usize,
926 pub max_attempts: usize,
927 pub started_at: Instant,
928 pub duration: Duration,
929}
930
931pub struct AcpThread {
932 parent_session_id: Option<acp::SessionId>,
933 title: SharedString,
934 entries: Vec<AgentThreadEntry>,
935 plan: Plan,
936 project: Entity<Project>,
937 action_log: Entity<ActionLog>,
938 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
939 send_task: Option<Task<()>>,
940 connection: Rc<dyn AgentConnection>,
941 session_id: acp::SessionId,
942 token_usage: Option<TokenUsage>,
943 prompt_capabilities: acp::PromptCapabilities,
944 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
945 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
946 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
947 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
948 // subagent cancellation fields
949 user_stopped: Arc<std::sync::atomic::AtomicBool>,
950 user_stop_tx: watch::Sender<bool>,
951}
952
953impl From<&AcpThread> for ActionLogTelemetry {
954 fn from(value: &AcpThread) -> Self {
955 Self {
956 agent_telemetry_id: value.connection().telemetry_id(),
957 session_id: value.session_id.0.clone(),
958 }
959 }
960}
961
/// Events emitted by [`AcpThread`] to its observers.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Carries an entry index (presumably into `entries` — confirm at emit sites).
    EntryUpdated(usize),
    /// Carries a range of entry indices (presumably into `entries`).
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is replaced by the observer task
    /// started in `AcpThread::new`.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Terminal lifecycle events consumed by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; the thread registers it and replays any output
    /// or exit that arrived for `terminal_id` beforehand.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes; buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title; applied as the inner terminal's breadcrumb text (ignored
    /// for unknown ids).
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; recorded if the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

/// Commands addressed to a terminal by id (presumably sent to the terminal
/// provider; no use is visible here — confirm).
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    Close {
        terminal_id: acp::TerminalId,
    },
}
1022
1023impl AcpThread {
1024 pub fn on_terminal_provider_event(
1025 &mut self,
1026 event: TerminalProviderEvent,
1027 cx: &mut Context<Self>,
1028 ) {
1029 match event {
1030 TerminalProviderEvent::Created {
1031 terminal_id,
1032 label,
1033 cwd,
1034 output_byte_limit,
1035 terminal,
1036 } => {
1037 let entity = self.register_terminal_created(
1038 terminal_id.clone(),
1039 label,
1040 cwd,
1041 output_byte_limit,
1042 terminal,
1043 cx,
1044 );
1045
1046 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1047 for data in chunks.drain(..) {
1048 entity.update(cx, |term, cx| {
1049 term.inner().update(cx, |inner, cx| {
1050 inner.write_output(&data, cx);
1051 })
1052 });
1053 }
1054 }
1055
1056 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1057 entity.update(cx, |_term, cx| {
1058 cx.notify();
1059 });
1060 }
1061
1062 cx.notify();
1063 }
1064 TerminalProviderEvent::Output { terminal_id, data } => {
1065 if let Some(entity) = self.terminals.get(&terminal_id) {
1066 entity.update(cx, |term, cx| {
1067 term.inner().update(cx, |inner, cx| {
1068 inner.write_output(&data, cx);
1069 })
1070 });
1071 } else {
1072 self.pending_terminal_output
1073 .entry(terminal_id)
1074 .or_default()
1075 .push(data);
1076 }
1077 }
1078 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1079 if let Some(entity) = self.terminals.get(&terminal_id) {
1080 entity.update(cx, |term, cx| {
1081 term.inner().update(cx, |inner, cx| {
1082 inner.breadcrumb_text = title;
1083 cx.emit(::terminal::Event::BreadcrumbsChanged);
1084 })
1085 });
1086 }
1087 }
1088 TerminalProviderEvent::Exit {
1089 terminal_id,
1090 status,
1091 } => {
1092 if let Some(entity) = self.terminals.get(&terminal_id) {
1093 entity.update(cx, |_term, cx| {
1094 cx.notify();
1095 });
1096 } else {
1097 self.pending_terminal_exit.insert(terminal_id, status);
1098 }
1099 }
1100 }
1101 }
1102}
1103
1104#[derive(PartialEq, Eq, Debug)]
1105pub enum ThreadStatus {
1106 Idle,
1107 Generating,
1108}
1109
1110#[derive(Debug, Clone)]
1111pub enum LoadError {
1112 Unsupported {
1113 command: SharedString,
1114 current_version: SharedString,
1115 minimum_version: SharedString,
1116 },
1117 FailedToInstall(SharedString),
1118 Exited {
1119 status: ExitStatus,
1120 },
1121 Other(SharedString),
1122}
1123
1124impl Display for LoadError {
1125 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1126 match self {
1127 LoadError::Unsupported {
1128 command: path,
1129 current_version,
1130 minimum_version,
1131 } => {
1132 write!(
1133 f,
1134 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1135 )
1136 }
1137 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1138 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1139 LoadError::Other(msg) => write!(f, "{msg}"),
1140 }
1141 }
1142}
1143
1144impl Error for LoadError {}
1145
1146impl AcpThread {
1147 pub fn new(
1148 parent_session_id: Option<acp::SessionId>,
1149 title: impl Into<SharedString>,
1150 connection: Rc<dyn AgentConnection>,
1151 project: Entity<Project>,
1152 action_log: Entity<ActionLog>,
1153 session_id: acp::SessionId,
1154 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1155 cx: &mut Context<Self>,
1156 ) -> Self {
1157 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1158 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1159 loop {
1160 let caps = prompt_capabilities_rx.recv().await?;
1161 this.update(cx, |this, cx| {
1162 this.prompt_capabilities = caps;
1163 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1164 })?;
1165 }
1166 });
1167
1168 let (user_stop_tx, _user_stop_rx) = watch::channel(false);
1169
1170 Self {
1171 parent_session_id,
1172 action_log,
1173 shared_buffers: Default::default(),
1174 entries: Default::default(),
1175 plan: Default::default(),
1176 title: title.into(),
1177 project,
1178 send_task: None,
1179 connection,
1180 session_id,
1181 token_usage: None,
1182 prompt_capabilities,
1183 _observe_prompt_capabilities: task,
1184 terminals: HashMap::default(),
1185 pending_terminal_output: HashMap::default(),
1186 pending_terminal_exit: HashMap::default(),
1187 user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1188 user_stop_tx,
1189 }
1190 }
1191
1192 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1193 self.parent_session_id.as_ref()
1194 }
1195
1196 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1197 self.prompt_capabilities.clone()
1198 }
1199
1200 /// Marks this thread as stopped by user action and signals any listeners.
1201 pub fn stop_by_user(&mut self) {
1202 self.user_stopped
1203 .store(true, std::sync::atomic::Ordering::SeqCst);
1204 self.user_stop_tx.send(true).ok();
1205 self.send_task.take();
1206 }
1207
1208 pub fn was_stopped_by_user(&self) -> bool {
1209 self.user_stopped.load(std::sync::atomic::Ordering::SeqCst)
1210 }
1211
1212 pub fn user_stop_receiver(&self) -> watch::Receiver<bool> {
1213 self.user_stop_tx.receiver()
1214 }
1215
1216 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1217 &self.connection
1218 }
1219
1220 pub fn action_log(&self) -> &Entity<ActionLog> {
1221 &self.action_log
1222 }
1223
1224 pub fn project(&self) -> &Entity<Project> {
1225 &self.project
1226 }
1227
1228 pub fn title(&self) -> SharedString {
1229 self.title.clone()
1230 }
1231
1232 pub fn entries(&self) -> &[AgentThreadEntry] {
1233 &self.entries
1234 }
1235
1236 pub fn session_id(&self) -> &acp::SessionId {
1237 &self.session_id
1238 }
1239
1240 pub fn status(&self) -> ThreadStatus {
1241 if self.send_task.is_some() {
1242 ThreadStatus::Generating
1243 } else {
1244 ThreadStatus::Idle
1245 }
1246 }
1247
1248 pub fn token_usage(&self) -> Option<&TokenUsage> {
1249 self.token_usage.as_ref()
1250 }
1251
1252 pub fn has_pending_edit_tool_calls(&self) -> bool {
1253 for entry in self.entries.iter().rev() {
1254 match entry {
1255 AgentThreadEntry::UserMessage(_) => return false,
1256 AgentThreadEntry::ToolCall(
1257 call @ ToolCall {
1258 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1259 ..
1260 },
1261 ) if call.diffs().next().is_some() => {
1262 return true;
1263 }
1264 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1265 }
1266 }
1267
1268 false
1269 }
1270
1271 pub fn has_in_progress_tool_calls(&self) -> bool {
1272 for entry in self.entries.iter().rev() {
1273 match entry {
1274 AgentThreadEntry::UserMessage(_) => return false,
1275 AgentThreadEntry::ToolCall(ToolCall {
1276 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1277 ..
1278 }) => {
1279 return true;
1280 }
1281 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1282 }
1283 }
1284
1285 false
1286 }
1287
1288 pub fn used_tools_since_last_user_message(&self) -> bool {
1289 for entry in self.entries.iter().rev() {
1290 match entry {
1291 AgentThreadEntry::UserMessage(..) => return false,
1292 AgentThreadEntry::AssistantMessage(..) => continue,
1293 AgentThreadEntry::ToolCall(..) => return true,
1294 }
1295 }
1296
1297 false
1298 }
1299
    /// Applies one `SessionUpdate` notification from the agent to this thread.
    ///
    /// User message, agent message, and agent thought chunks are appended to the
    /// trailing user/assistant entry. Tool calls are upserted and tool-call
    /// updates applied. Plans replace the current plan. Available-command,
    /// current-mode, and config-option updates are re-emitted as
    /// `AcpThreadEvent`s. Any other update kind is ignored.
    ///
    /// # Errors
    /// Propagates failures from `upsert_tool_call` and `update_tool_call`.
    pub fn handle_session_update(
        &mut self,
        update: acp::SessionUpdate,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        match update {
            acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
                // `None` keeps whatever id the trailing user message already has.
                self.push_user_content_block(None, content, cx);
            }
            acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
                self.push_assistant_content_block(content, false, cx);
            }
            acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
                self.push_assistant_content_block(content, true, cx);
            }
            acp::SessionUpdate::ToolCall(tool_call) => {
                self.upsert_tool_call(tool_call, cx)?;
            }
            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
                self.update_tool_call(tool_call_update, cx)?;
            }
            acp::SessionUpdate::Plan(plan) => {
                self.update_plan(plan, cx);
            }
            acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
                available_commands,
                ..
            }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
            acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
                current_mode_id,
                ..
            }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
            acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
                config_options,
                ..
            }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
            // Update kinds this thread does not track are silently dropped.
            _ => {}
        }
        Ok(())
    }
1340
1341 pub fn push_user_content_block(
1342 &mut self,
1343 message_id: Option<UserMessageId>,
1344 chunk: acp::ContentBlock,
1345 cx: &mut Context<Self>,
1346 ) {
1347 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1348 }
1349
    /// Appends `chunk` to the trailing user message, or starts a new one.
    ///
    /// The chunk is merged into the last entry only when that entry is a
    /// `UserMessage` with the same `indented` flag. In that case the rendered
    /// content and raw `chunks` both grow and `EntryUpdated` is emitted.
    /// Otherwise a new `UserMessage` (without a checkpoint) is pushed.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `last_mut` below borrows `self.entries` mutably.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // A provided id wins; otherwise keep the id the message already had.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1390
1391 pub fn push_assistant_content_block(
1392 &mut self,
1393 chunk: acp::ContentBlock,
1394 is_thought: bool,
1395 cx: &mut Context<Self>,
1396 ) {
1397 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1398 }
1399
    /// Appends `chunk` to the trailing assistant message, or starts a new one.
    ///
    /// The chunk joins the last entry only when it is an `AssistantMessage` with
    /// the same `indented` flag. Within that message it extends the last chunk
    /// if that chunk has the same kind (thought vs. message). Otherwise it is
    /// added as a new chunk. If the last entry does not qualify, a new
    /// `AssistantMessage` entry is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `last_mut` below borrows `self.entries` mutably.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): emitted before the mutation below; presumably fine because
            // emitted events are delivered after this update returns — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the last chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1450
1451 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1452 self.entries.push(entry);
1453 cx.emit(AcpThreadEvent::NewEntry);
1454 }
1455
1456 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1457 self.connection.set_title(&self.session_id, cx).is_some()
1458 }
1459
1460 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1461 if title != self.title {
1462 self.title = title.clone();
1463 cx.emit(AcpThreadEvent::TitleUpdated);
1464 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1465 return set_title.run(title, cx);
1466 }
1467 }
1468 Task::ready(Ok(()))
1469 }
1470
1471 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1472 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1473 }
1474
1475 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1476 self.token_usage = usage;
1477 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1478 }
1479
1480 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1481 cx.emit(AcpThreadEvent::Retry(status));
1482 }
1483
    /// Applies `update` to the existing tool call with the same id.
    ///
    /// Field updates are merged into the call, and its locations are re-resolved
    /// when `locations` is among the updated fields. Diff and terminal updates
    /// replace the call's entire content with that single item. `EntryUpdated`
    /// is emitted afterwards.
    ///
    /// If no tool call has that id, a placeholder entry labelled "Tool call not
    /// found" with status `Failed` is appended instead, and `Ok(())` is returned.
    ///
    /// # Errors
    /// Propagates errors from `ToolCall::update_fields`.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    // NOTE(review): `Fetch` looks like an arbitrary placeholder kind for
                    // this synthetic entry — confirm it renders sensibly.
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                    tool_name: None,
                    subagent_session_id: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(
                    update.fields,
                    update.meta,
                    languages,
                    path_style,
                    &self.terminals,
                    cx,
                )?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1554
1555 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1556 pub fn upsert_tool_call(
1557 &mut self,
1558 tool_call: acp::ToolCall,
1559 cx: &mut Context<Self>,
1560 ) -> Result<(), acp::Error> {
1561 let status = tool_call.status.into();
1562 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1563 }
1564
1565 /// Fails if id does not match an existing entry.
1566 pub fn upsert_tool_call_inner(
1567 &mut self,
1568 update: acp::ToolCallUpdate,
1569 status: ToolCallStatus,
1570 cx: &mut Context<Self>,
1571 ) -> Result<(), acp::Error> {
1572 let language_registry = self.project.read(cx).languages().clone();
1573 let path_style = self.project.read(cx).path_style(cx);
1574 let id = update.tool_call_id.clone();
1575
1576 let agent_telemetry_id = self.connection().telemetry_id();
1577 let session = self.session_id();
1578 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1579 let status = if matches!(status, ToolCallStatus::Completed) {
1580 "completed"
1581 } else {
1582 "failed"
1583 };
1584 telemetry::event!(
1585 "Agent Tool Call Completed",
1586 agent_telemetry_id,
1587 session,
1588 status
1589 );
1590 }
1591
1592 if let Some(ix) = self.index_for_tool_call(&id) {
1593 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1594 unreachable!()
1595 };
1596
1597 call.update_fields(
1598 update.fields,
1599 update.meta,
1600 language_registry,
1601 path_style,
1602 &self.terminals,
1603 cx,
1604 )?;
1605 call.status = status;
1606
1607 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1608 } else {
1609 let call = ToolCall::from_acp(
1610 update.try_into()?,
1611 status,
1612 language_registry,
1613 self.project.read(cx).path_style(cx),
1614 &self.terminals,
1615 cx,
1616 )?;
1617 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1618 };
1619
1620 self.resolve_locations(id, cx);
1621 Ok(())
1622 }
1623
1624 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1625 self.entries
1626 .iter()
1627 .enumerate()
1628 .rev()
1629 .find_map(|(index, entry)| {
1630 if let AgentThreadEntry::ToolCall(tool_call) = entry
1631 && &tool_call.id == id
1632 {
1633 Some(index)
1634 } else {
1635 None
1636 }
1637 })
1638 }
1639
1640 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1641 // The tool call we are looking for is typically the last one, or very close to the end.
1642 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1643 self.entries
1644 .iter_mut()
1645 .enumerate()
1646 .rev()
1647 .find_map(|(index, tool_call)| {
1648 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1649 && &tool_call.id == id
1650 {
1651 Some((index, tool_call))
1652 } else {
1653 None
1654 }
1655 })
1656 }
1657
1658 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1659 self.entries
1660 .iter()
1661 .enumerate()
1662 .rev()
1663 .find_map(|(index, tool_call)| {
1664 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1665 && &tool_call.id == id
1666 {
1667 Some((index, tool_call))
1668 } else {
1669 None
1670 }
1671 })
1672 }
1673
    /// Resolves the locations of the tool call with `id` in a detached task.
    ///
    /// Does nothing if no such tool call exists. Once resolution finishes, the
    /// task does the following:
    /// - caches each resolved buffer's current snapshot in `shared_buffers`;
    /// - moves the project's agent location to the last resolved location,
    ///   unless that would only move it leftwards within the same row of the
    ///   same buffer;
    /// - replaces the call's `resolved_locations`, emitting `EntryUpdated` only
    ///   if they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Re-lookup: the entry may have been removed while resolution was running.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1730
1731 pub fn request_tool_call_authorization(
1732 &mut self,
1733 tool_call: acp::ToolCallUpdate,
1734 options: PermissionOptions,
1735 cx: &mut Context<Self>,
1736 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1737 let (tx, rx) = oneshot::channel();
1738
1739 let status = ToolCallStatus::WaitingForConfirmation {
1740 options,
1741 respond_tx: tx,
1742 };
1743
1744 self.upsert_tool_call_inner(tool_call, status, cx)?;
1745 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1746
1747 let fut = async {
1748 match rx.await {
1749 Ok(option) => acp::RequestPermissionOutcome::Selected(
1750 acp::SelectedPermissionOutcome::new(option),
1751 ),
1752 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1753 }
1754 }
1755 .boxed();
1756
1757 Ok(fut)
1758 }
1759
1760 pub fn authorize_tool_call(
1761 &mut self,
1762 id: acp::ToolCallId,
1763 option_id: acp::PermissionOptionId,
1764 option_kind: acp::PermissionOptionKind,
1765 cx: &mut Context<Self>,
1766 ) {
1767 let Some((ix, call)) = self.tool_call_mut(&id) else {
1768 return;
1769 };
1770
1771 let new_status = match option_kind {
1772 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1773 ToolCallStatus::Rejected
1774 }
1775 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1776 ToolCallStatus::InProgress
1777 }
1778 _ => ToolCallStatus::InProgress,
1779 };
1780
1781 let curr_status = mem::replace(&mut call.status, new_status);
1782
1783 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1784 respond_tx.send(option_id).log_err();
1785 } else if cfg!(debug_assertions) {
1786 panic!("tried to authorize an already authorized tool call");
1787 }
1788
1789 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1790 }
1791
1792 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1793 let mut first_tool_call = None;
1794
1795 for entry in self.entries.iter().rev() {
1796 match &entry {
1797 AgentThreadEntry::ToolCall(call) => {
1798 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1799 first_tool_call = Some(call);
1800 } else {
1801 continue;
1802 }
1803 }
1804 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1805 // Reached the beginning of the turn.
1806 // If we had pending permission requests in the previous turn, they have been cancelled.
1807 break;
1808 }
1809 }
1810 }
1811
1812 first_tool_call
1813 }
1814
1815 pub fn plan(&self) -> &Plan {
1816 &self.plan
1817 }
1818
1819 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1820 let new_entries_len = request.entries.len();
1821 let mut new_entries = request.entries.into_iter();
1822
1823 // Reuse existing markdown to prevent flickering
1824 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1825 let PlanEntry {
1826 content,
1827 priority,
1828 status,
1829 } = old;
1830 content.update(cx, |old, cx| {
1831 old.replace(new.content, cx);
1832 });
1833 *priority = new.priority;
1834 *status = new.status;
1835 }
1836 for new in new_entries {
1837 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1838 }
1839 self.plan.entries.truncate(new_entries_len);
1840
1841 cx.notify();
1842 }
1843
1844 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1845 self.plan
1846 .entries
1847 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1848 cx.notify();
1849 }
1850
1851 #[cfg(any(test, feature = "test-support"))]
1852 pub fn send_raw(
1853 &mut self,
1854 message: &str,
1855 cx: &mut Context<Self>,
1856 ) -> BoxFuture<'static, Result<()>> {
1857 self.send(vec![message.into()], cx)
1858 }
1859
    /// Starts a new turn that sends `message` to the agent as the user's prompt.
    ///
    /// Inside `run_turn` (which first cancels any turn in flight) this:
    /// 1. pushes a `UserMessage` entry;
    /// 2. takes a git checkpoint and attaches it to the last user message with
    ///    `show: false` (a checkpoint failure is logged and ignored);
    /// 3. forwards the prompt via `connection.prompt`.
    ///
    /// The message only gets a `UserMessageId` when the connection supports
    /// truncation.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Ids are what `rewind`/`restore_checkpoint` use to locate a message, and
        // rewinding needs truncation support, so only assign one when it exists.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` sees the tree actually changed.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1912
1913 pub fn can_retry(&self, cx: &App) -> bool {
1914 self.connection.retry(&self.session_id, cx).is_some()
1915 }
1916
    /// Re-runs the session's last prompt through the connection's retry handler.
    ///
    /// Runs inside `run_turn`, so any in-flight turn is cancelled first. The
    /// returned future fails with "retrying a session is not supported" if the
    /// connection offers no retry handler.
    pub fn retry(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .retry(&this.session_id, cx)
                    .map(|retry| retry.run(cx))
            })?
            .context("retrying a session is not supported")?
            .await
        })
    }
1928
    /// Runs one agent turn produced by `f` and reports its outcome.
    ///
    /// Completed plan entries are cleared and any in-flight turn is cancelled;
    /// `f` only starts once that cancellation has finished. `f` runs inside the
    /// new `send_task`, and its result is forwarded through a oneshot channel to
    /// the returned future. That future then:
    /// 1. refreshes the last user message's checkpoint visibility (an error
    ///    here is returned immediately);
    /// 2. clears the project's agent location;
    /// 3. on an error or a `MaxTokens` stop, drops `send_task`, emits
    ///    `AcpThreadEvent::Error`, and returns an error;
    /// 4. otherwise drops `send_task` unless the turn was cancelled, handles a
    ///    `Refusal` stop (see below), emits `Stopped`, and returns `Ok(())`.
    ///
    /// On refusal, if a completed tool call with raw output follows the last user
    /// message, the refusal is attributed to tool output and entries are kept.
    /// Otherwise everything from the last user message onward is removed.
    /// `Refusal` is emitted in both cases.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn finish cancelling before this one starts.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means `send_task` was dropped before `f` finished.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Error in run turn: {:?}", e);
                        Err(e)
                    }
                    Ok(Ok(r)) if r.stop_reason == acp::StopReason::MaxTokens => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
                        Err(anyhow!("Max tokens reached"))
                    }
                    // Also reached with `Err(Canceled)` from `rx`. NOTE(review): in that case
                    // `canceled` is false, so whatever `send_task` is stored now (possibly a
                    // newer turn's) gets taken below — confirm that is intended.
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
2031
2032 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2033 let Some(send_task) = self.send_task.take() else {
2034 return Task::ready(());
2035 };
2036
2037 for entry in self.entries.iter_mut() {
2038 if let AgentThreadEntry::ToolCall(call) = entry {
2039 let cancel = matches!(
2040 call.status,
2041 ToolCallStatus::Pending
2042 | ToolCallStatus::WaitingForConfirmation { .. }
2043 | ToolCallStatus::InProgress
2044 );
2045
2046 if cancel {
2047 call.status = ToolCallStatus::Canceled;
2048 }
2049 }
2050 }
2051
2052 self.connection.cancel(&self.session_id, cx);
2053
2054 // Wait for the send task to complete
2055 cx.foreground_executor().spawn(send_task)
2056 }
2057
2058 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2059 pub fn restore_checkpoint(
2060 &mut self,
2061 id: UserMessageId,
2062 cx: &mut Context<Self>,
2063 ) -> Task<Result<()>> {
2064 let Some((_, message)) = self.user_message_mut(&id) else {
2065 return Task::ready(Err(anyhow!("message not found")));
2066 };
2067
2068 let checkpoint = message
2069 .checkpoint
2070 .as_ref()
2071 .map(|c| c.git_checkpoint.clone());
2072
2073 // Cancel any in-progress generation before restoring
2074 let cancel_task = self.cancel(cx);
2075 let rewind = self.rewind(id.clone(), cx);
2076 let git_store = self.project.read(cx).git_store().clone();
2077
2078 cx.spawn(async move |_, cx| {
2079 cancel_task.await;
2080 rewind.await?;
2081 if let Some(checkpoint) = checkpoint {
2082 git_store
2083 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2084 .await?;
2085 }
2086
2087 Ok(())
2088 })
2089 }
2090
    /// Rewinds this thread to before the user message with `id`, removing that
    /// message and all subsequent entries, then rejects all edits tracked by the
    /// action log.
    /// Unlike `restore_checkpoint`, this method does not restore from git.
    ///
    /// The connection is asked to truncate first; local entries are removed only
    /// once that succeeds. Terminals referenced by removed entries are killed and
    /// dropped from `terminals`. If `id` is no longer present after truncation,
    /// no entries are removed, but edits are still rejected.
    ///
    /// # Errors
    /// Fails with "not supported" if the connection cannot truncate. Also fails
    /// if the truncation itself or the thread update fails.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx)).await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    // NOTE(review): presumably `id()` yields a plain `TerminalId`, in which case
                    // `.into()` just wraps it in `Some` and every id is kept — confirm.
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
2132
2133 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2134 let git_store = self.project.read(cx).git_store().clone();
2135
2136 let Some((_, message)) = self.last_user_message() else {
2137 return Task::ready(Ok(()));
2138 };
2139 let Some(user_message_id) = message.id.clone() else {
2140 return Task::ready(Ok(()));
2141 };
2142 let Some(checkpoint) = message.checkpoint.as_ref() else {
2143 return Task::ready(Ok(()));
2144 };
2145 let old_checkpoint = checkpoint.git_checkpoint.clone();
2146
2147 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2148 cx.spawn(async move |this, cx| {
2149 let Some(new_checkpoint) = new_checkpoint
2150 .await
2151 .context("failed to get new checkpoint")
2152 .log_err()
2153 else {
2154 return Ok(());
2155 };
2156
2157 let equal = git_store
2158 .update(cx, |git, cx| {
2159 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2160 })
2161 .await
2162 .unwrap_or(true);
2163
2164 this.update(cx, |this, cx| {
2165 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2166 if let Some(checkpoint) = message.checkpoint.as_mut() {
2167 checkpoint.show = !equal;
2168 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2169 }
2170 }
2171 })?;
2172
2173 Ok(())
2174 })
2175 }
2176
2177 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2178 self.entries
2179 .iter_mut()
2180 .enumerate()
2181 .rev()
2182 .find_map(|(ix, entry)| {
2183 if let AgentThreadEntry::UserMessage(message) = entry {
2184 Some((ix, message))
2185 } else {
2186 None
2187 }
2188 })
2189 }
2190
2191 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2192 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2193 if let AgentThreadEntry::UserMessage(message) = entry {
2194 if message.id.as_ref() == Some(id) {
2195 Some((ix, message))
2196 } else {
2197 None
2198 }
2199 } else {
2200 None
2201 }
2202 })
2203 }
2204
2205 pub fn read_text_file(
2206 &self,
2207 path: PathBuf,
2208 line: Option<u32>,
2209 limit: Option<u32>,
2210 reuse_shared_snapshot: bool,
2211 cx: &mut Context<Self>,
2212 ) -> Task<Result<String, acp::Error>> {
2213 // Args are 1-based, move to 0-based
2214 let line = line.unwrap_or_default().saturating_sub(1);
2215 let limit = limit.unwrap_or(u32::MAX);
2216 let project = self.project.clone();
2217 let action_log = self.action_log.clone();
2218 cx.spawn(async move |this, cx| {
2219 let load = project.update(cx, |project, cx| {
2220 let path = project
2221 .project_path_for_absolute_path(&path, cx)
2222 .ok_or_else(|| {
2223 acp::Error::resource_not_found(Some(path.display().to_string()))
2224 })?;
2225 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2226 })?;
2227
2228 let buffer = load.await?;
2229
2230 let snapshot = if reuse_shared_snapshot {
2231 this.read_with(cx, |this, _| {
2232 this.shared_buffers.get(&buffer.clone()).cloned()
2233 })
2234 .log_err()
2235 .flatten()
2236 } else {
2237 None
2238 };
2239
2240 let snapshot = if let Some(snapshot) = snapshot {
2241 snapshot
2242 } else {
2243 action_log.update(cx, |action_log, cx| {
2244 action_log.buffer_read(buffer.clone(), cx);
2245 });
2246
2247 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2248 this.update(cx, |this, _| {
2249 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2250 })?;
2251 snapshot
2252 };
2253
2254 let max_point = snapshot.max_point();
2255 let start_position = Point::new(line, 0);
2256
2257 if start_position > max_point {
2258 return Err(acp::Error::invalid_params().data(format!(
2259 "Attempting to read beyond the end of the file, line {}:{}",
2260 max_point.row + 1,
2261 max_point.column
2262 )));
2263 }
2264
2265 let start = snapshot.anchor_before(start_position);
2266 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2267
2268 project.update(cx, |project, cx| {
2269 project.set_agent_location(
2270 Some(AgentLocation {
2271 buffer: buffer.downgrade(),
2272 position: start,
2273 }),
2274 cx,
2275 );
2276 });
2277
2278 Ok(snapshot.text_for_range(start..end).collect::<String>())
2279 })
2280 }
2281
2282 pub fn write_text_file(
2283 &self,
2284 path: PathBuf,
2285 content: String,
2286 cx: &mut Context<Self>,
2287 ) -> Task<Result<()>> {
2288 let project = self.project.clone();
2289 let action_log = self.action_log.clone();
2290 cx.spawn(async move |this, cx| {
2291 let load = project.update(cx, |project, cx| {
2292 let path = project
2293 .project_path_for_absolute_path(&path, cx)
2294 .context("invalid path")?;
2295 anyhow::Ok(project.open_buffer(path, cx))
2296 });
2297 let buffer = load?.await?;
2298 let snapshot = this.update(cx, |this, cx| {
2299 this.shared_buffers
2300 .get(&buffer)
2301 .cloned()
2302 .unwrap_or_else(|| buffer.read(cx).snapshot())
2303 })?;
2304 let edits = cx
2305 .background_executor()
2306 .spawn(async move {
2307 let old_text = snapshot.text();
2308 text_diff(old_text.as_str(), &content)
2309 .into_iter()
2310 .map(|(range, replacement)| {
2311 (
2312 snapshot.anchor_after(range.start)
2313 ..snapshot.anchor_before(range.end),
2314 replacement,
2315 )
2316 })
2317 .collect::<Vec<_>>()
2318 })
2319 .await;
2320
2321 project.update(cx, |project, cx| {
2322 project.set_agent_location(
2323 Some(AgentLocation {
2324 buffer: buffer.downgrade(),
2325 position: edits
2326 .last()
2327 .map(|(range, _)| range.end)
2328 .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
2329 }),
2330 cx,
2331 );
2332 });
2333
2334 let format_on_save = cx.update(|cx| {
2335 action_log.update(cx, |action_log, cx| {
2336 action_log.buffer_read(buffer.clone(), cx);
2337 });
2338
2339 let format_on_save = buffer.update(cx, |buffer, cx| {
2340 buffer.edit(edits, None, cx);
2341
2342 let settings = language::language_settings::language_settings(
2343 buffer.language().map(|l| l.name()),
2344 buffer.file(),
2345 cx,
2346 );
2347
2348 settings.format_on_save != FormatOnSave::Off
2349 });
2350 action_log.update(cx, |action_log, cx| {
2351 action_log.buffer_edited(buffer.clone(), cx);
2352 });
2353 format_on_save
2354 });
2355
2356 if format_on_save {
2357 let format_task = project.update(cx, |project, cx| {
2358 project.format(
2359 HashSet::from_iter([buffer.clone()]),
2360 LspFormatTarget::Buffers,
2361 false,
2362 FormatTrigger::Save,
2363 cx,
2364 )
2365 });
2366 format_task.await.log_err();
2367
2368 action_log.update(cx, |action_log, cx| {
2369 action_log.buffer_edited(buffer.clone(), cx);
2370 });
2371 }
2372
2373 project
2374 .update(cx, |project, cx| project.save_buffer(buffer, cx))
2375 .await
2376 })
2377 }
2378
2379 pub fn create_terminal(
2380 &self,
2381 command: String,
2382 args: Vec<String>,
2383 extra_env: Vec<acp::EnvVariable>,
2384 cwd: Option<PathBuf>,
2385 output_byte_limit: Option<u64>,
2386 cx: &mut Context<Self>,
2387 ) -> Task<Result<Entity<Terminal>>> {
2388 let env = match &cwd {
2389 Some(dir) => self.project.update(cx, |project, cx| {
2390 project.environment().update(cx, |env, cx| {
2391 env.directory_environment(dir.as_path().into(), cx)
2392 })
2393 }),
2394 None => Task::ready(None).shared(),
2395 };
2396 let env = cx.spawn(async move |_, _| {
2397 let mut env = env.await.unwrap_or_default();
2398 // Disables paging for `git` and hopefully other commands
2399 env.insert("PAGER".into(), "".into());
2400 for var in extra_env {
2401 env.insert(var.name, var.value);
2402 }
2403 env
2404 });
2405
2406 let project = self.project.clone();
2407 let language_registry = project.read(cx).languages().clone();
2408 let is_windows = project.read(cx).path_style(cx).is_windows();
2409
2410 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2411 let terminal_task = cx.spawn({
2412 let terminal_id = terminal_id.clone();
2413 async move |_this, cx| {
2414 let env = env.await;
2415 let shell = project
2416 .update(cx, |project, cx| {
2417 project
2418 .remote_client()
2419 .and_then(|r| r.read(cx).default_system_shell())
2420 })
2421 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2422 let (task_command, task_args) =
2423 ShellBuilder::new(&Shell::Program(shell), is_windows)
2424 .redirect_stdin_to_dev_null()
2425 .build(Some(command.clone()), &args);
2426 let terminal = project
2427 .update(cx, |project, cx| {
2428 project.create_terminal_task(
2429 task::SpawnInTerminal {
2430 command: Some(task_command),
2431 args: task_args,
2432 cwd: cwd.clone(),
2433 env,
2434 ..Default::default()
2435 },
2436 cx,
2437 )
2438 })
2439 .await?;
2440
2441 anyhow::Ok(cx.new(|cx| {
2442 Terminal::new(
2443 terminal_id,
2444 &format!("{} {}", command, args.join(" ")),
2445 cwd,
2446 output_byte_limit.map(|l| l as usize),
2447 terminal,
2448 language_registry,
2449 cx,
2450 )
2451 }))
2452 }
2453 });
2454
2455 cx.spawn(async move |this, cx| {
2456 let terminal = terminal_task.await?;
2457 this.update(cx, |this, _cx| {
2458 this.terminals.insert(terminal_id, terminal.clone());
2459 terminal
2460 })
2461 })
2462 }
2463
2464 pub fn kill_terminal(
2465 &mut self,
2466 terminal_id: acp::TerminalId,
2467 cx: &mut Context<Self>,
2468 ) -> Result<()> {
2469 self.terminals
2470 .get(&terminal_id)
2471 .context("Terminal not found")?
2472 .update(cx, |terminal, cx| {
2473 terminal.kill(cx);
2474 });
2475
2476 Ok(())
2477 }
2478
2479 pub fn release_terminal(
2480 &mut self,
2481 terminal_id: acp::TerminalId,
2482 cx: &mut Context<Self>,
2483 ) -> Result<()> {
2484 self.terminals
2485 .remove(&terminal_id)
2486 .context("Terminal not found")?
2487 .update(cx, |terminal, cx| {
2488 terminal.kill(cx);
2489 });
2490
2491 Ok(())
2492 }
2493
2494 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2495 self.terminals
2496 .get(&terminal_id)
2497 .context("Terminal not found")
2498 .cloned()
2499 }
2500
2501 pub fn to_markdown(&self, cx: &App) -> String {
2502 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2503 }
2504
2505 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2506 cx.emit(AcpThreadEvent::LoadError(error));
2507 }
2508
2509 pub fn register_terminal_created(
2510 &mut self,
2511 terminal_id: acp::TerminalId,
2512 command_label: String,
2513 working_dir: Option<PathBuf>,
2514 output_byte_limit: Option<u64>,
2515 terminal: Entity<::terminal::Terminal>,
2516 cx: &mut Context<Self>,
2517 ) -> Entity<Terminal> {
2518 let language_registry = self.project.read(cx).languages().clone();
2519
2520 let entity = cx.new(|cx| {
2521 Terminal::new(
2522 terminal_id.clone(),
2523 &command_label,
2524 working_dir.clone(),
2525 output_byte_limit.map(|l| l as usize),
2526 terminal,
2527 language_registry,
2528 cx,
2529 )
2530 });
2531 self.terminals.insert(terminal_id.clone(), entity.clone());
2532 entity
2533 }
2534}
2535
2536fn markdown_for_raw_output(
2537 raw_output: &serde_json::Value,
2538 language_registry: &Arc<LanguageRegistry>,
2539 cx: &mut App,
2540) -> Option<Entity<Markdown>> {
2541 match raw_output {
2542 serde_json::Value::Null => None,
2543 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2544 Markdown::new(
2545 value.to_string().into(),
2546 Some(language_registry.clone()),
2547 None,
2548 cx,
2549 )
2550 })),
2551 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2552 Markdown::new(
2553 value.to_string().into(),
2554 Some(language_registry.clone()),
2555 None,
2556 cx,
2557 )
2558 })),
2559 serde_json::Value::String(value) => Some(cx.new(|cx| {
2560 Markdown::new(
2561 value.clone().into(),
2562 Some(language_registry.clone()),
2563 None,
2564 cx,
2565 )
2566 })),
2567 value => Some(cx.new(|cx| {
2568 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2569
2570 Markdown::new(
2571 format!("```json\n{}\n```", pretty_json).into(),
2572 Some(language_registry.clone()),
2573 None,
2574 cx,
2575 )
2576 })),
2577 }
2578}
2579
2580#[cfg(test)]
2581mod tests {
2582 use super::*;
2583 use anyhow::anyhow;
2584 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2585 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2586 use indoc::indoc;
2587 use project::{FakeFs, Fs};
2588 use rand::{distr, prelude::*};
2589 use serde_json::json;
2590 use settings::SettingsStore;
2591 use smol::stream::StreamExt as _;
2592 use std::{
2593 any::Any,
2594 cell::RefCell,
2595 path::Path,
2596 rc::Rc,
2597 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2598 time::Duration,
2599 };
2600 use util::path;
2601
2602 fn init_test(cx: &mut TestAppContext) {
2603 env_logger::try_init().ok();
2604 cx.update(|cx| {
2605 let settings_store = SettingsStore::test(cx);
2606 cx.set_global(settings_store);
2607 });
2608 }
2609
    /// Terminal `Output` that arrives before the terminal's `Created` event must be held
    /// by the thread and written into the terminal once `Created` is handled.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2673
    /// Both `Output` and `Exit` may arrive before `Created`. The buffered output must
    /// still be rendered once the terminal is created.
    ///
    /// NOTE(review): only the buffered output is asserted; whether the buffered exit
    /// status is applied is not checked here.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2748
    /// Test that killing a terminal via Terminal::kill properly:
    /// 1. Causes wait_for_exit to complete (doesn't hang forever)
    /// 2. The underlying terminal still has the output that was written before the kill
    ///
    /// This test verifies that the fix to kill_active_task (which now also kills
    /// the shell process in addition to the foreground process) properly allows
    /// wait_for_exit to complete instead of hanging indefinitely.
    ///
    /// It spawns a real PTY process and waits on real timers (parking is allowed), which is
    /// why it is restricted to Unix.
    #[cfg(unix)]
    #[gpui::test]
    async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
        use std::collections::HashMap;
        use task::Shell;
        use util::shell_builder::ShellBuilder;

        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Create a real PTY terminal that runs a command which prints output then sleeps
        // We use printf instead of echo and chain with && sleep to ensure proper execution
        let (completion_tx, _completion_rx) = smol::channel::unbounded();
        let (program, args) = ShellBuilder::new(&Shell::System, false).build(
            Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
            &[],
        );

        let builder = cx
            .update(|cx| {
                ::terminal::TerminalBuilder::new(
                    None,
                    None,
                    task::Shell::WithArguments {
                        program,
                        args,
                        title_override: None,
                    },
                    HashMap::default(),
                    ::terminal::terminal_settings::CursorShape::default(),
                    ::terminal::terminal_settings::AlternateScroll::On,
                    None,
                    vec![],
                    0,
                    false,
                    0,
                    Some(completion_tx),
                    cx,
                    vec![],
                    PathStyle::local(),
                )
            })
            .await
            .unwrap();

        let lower_terminal = cx.new(|cx| builder.subscribe(cx));

        // Create the acp_thread Terminal wrapper
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "printf output_before_kill && sleep 60".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower_terminal.clone(),
                },
                cx,
            );
        });

        // Wait for the printf command to execute and produce output
        // Use real time since parking is enabled
        cx.executor().timer(Duration::from_millis(500)).await;

        // Get the acp_thread Terminal and kill it
        let wait_for_exit = thread.update(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            let wait_for_exit = term.read(cx).wait_for_exit();
            term.update(cx, |term, cx| {
                term.kill(cx);
            });
            wait_for_exit
        });

        // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
        // Before the fix to kill_active_task, this would hang forever because
        // only the foreground process was killed, not the shell, so the PTY
        // child never exited and wait_for_completed_task never completed.
        let exit_result = futures::select! {
            result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
            _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
        };

        assert!(
            exit_result.is_some(),
            "wait_for_exit should complete after kill, but it timed out. \
            This indicates kill_active_task is not properly killing the shell process."
        );

        // Give the system a chance to process any pending updates
        cx.run_until_parked();

        // Verify that the underlying terminal still has the output that was
        // written before the kill. This verifies that killing doesn't lose output.
        let inner_content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            term.read(cx).inner().read(cx).get_content()
        });

        assert!(
            inner_content.contains("output_before_kill"),
            "Underlying terminal should contain output from before kill, got: {}",
            inner_content
        );
    }
2872
    /// `push_user_content_block` behaves as follows:
    /// - with no entries, it creates a user message;
    /// - when the last entry is a user message, it appends to it, and the message takes
    ///   the supplied id;
    /// - after an assistant message, it starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2940
    /// Consecutive `AgentThoughtChunk` updates within one turn are merged into a single
    /// `<thinking>` block in the markdown transcript.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks and then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
3001
    /// An agent write that is based on an earlier read must not clobber an edit the user
    /// made after that read.
    ///
    /// The fake agent reads `/tmp/foo`, signals over `read_file_tx`, and then writes
    /// extended content. Between the signal and the write, the test inserts `zero\n` at the
    /// start of the buffer. Both the user's line and the agent's appended lines must end up
    /// in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Let the test perform the "user" edit before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
3078
    /// Exercises the 1-based `line` and `limit` windowing of `read_text_file` on a
    /// four-line file. A start line past the end fails with an `invalid_params` error
    /// that reports the file's last line and column.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
3154
    /// Same windowing checks as `test_reading_from_line`, run against an empty file.
    /// Line 1 is still a valid start and yields `""`. Starting at line 5 fails, and the
    /// error reports position `1:0`.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    /// Reading a path outside the project's worktree (`/tmp`) fails with
    /// `ErrorCode::ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3260
    /// A tool call that was moved to `Canceled` by `cancel` can still be marked
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent starts an in-progress fetch tool call with id `test` and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3350
    /// An edit tool call that is reported as already `Completed` (with a diff) must not be
    /// counted by `has_pending_edit_tool_calls` once the turn has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3395
    /// Tests that user messages record a git checkpoint only when the agent's turn
    /// changed the worktree, and that restoring a checkpoint truncates the thread
    /// at that message and reverts the file changes made from then on.
    ///
    /// In the expected markdown, "## User (checkpoint)" marks a user message that
    /// kept its checkpoint; a plain "## User" heading marks one that did not.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // An empty `.git` directory so the project has a repository to checkpoint.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // `simulate_changes` controls whether a prompt touches the filesystem;
        // `next_filename` numbers the files written, starting at 0.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Write a new empty file so the worktree differs from the
                    // checkpoint taken for this prompt.
                    // NOTE(review): this path is built without `path!()`, unlike the
                    // `fs.files()` assertions below — confirm it matches on Windows.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply by echoing the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First prompt writes file-0, so its checkpoint is kept.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second prompt writes file-1, so its checkpoint is kept as well.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" user message (entries alternate user/assistant),
        // so restoring it drops everything from there on and removes file-1.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3573
3574 #[gpui::test]
3575 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3576 use std::sync::atomic::AtomicUsize;
3577 init_test(cx);
3578
3579 let fs = FakeFs::new(cx.executor());
3580 let project = Project::test(fs, None, cx).await;
3581
3582 // Create a connection that simulates refusal after tool result
3583 let prompt_count = Arc::new(AtomicUsize::new(0));
3584 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3585 let prompt_count = prompt_count.clone();
3586 move |_request, thread, mut cx| {
3587 let count = prompt_count.fetch_add(1, SeqCst);
3588 async move {
3589 if count == 0 {
3590 // First prompt: Generate a tool call with result
3591 thread.update(&mut cx, |thread, cx| {
3592 thread
3593 .handle_session_update(
3594 acp::SessionUpdate::ToolCall(
3595 acp::ToolCall::new("tool1", "Test Tool")
3596 .kind(acp::ToolKind::Fetch)
3597 .status(acp::ToolCallStatus::Completed)
3598 .raw_input(serde_json::json!({"query": "test"}))
3599 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3600 ),
3601 cx,
3602 )
3603 .unwrap();
3604 })?;
3605
3606 // Now return refusal because of the tool result
3607 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3608 } else {
3609 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3610 }
3611 }
3612 .boxed_local()
3613 }
3614 }));
3615
3616 let thread = cx
3617 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3618 .await
3619 .unwrap();
3620
3621 // Track if we see a Refusal event
3622 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3623 let saw_refusal_event_captured = saw_refusal_event.clone();
3624 thread.update(cx, |_thread, cx| {
3625 cx.subscribe(
3626 &thread,
3627 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3628 if matches!(event, AcpThreadEvent::Refusal) {
3629 *saw_refusal_event_captured.lock().unwrap() = true;
3630 }
3631 },
3632 )
3633 .detach();
3634 });
3635
3636 // Send a user message - this will trigger tool call and then refusal
3637 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3638 cx.background_executor.spawn(send_task).detach();
3639 cx.run_until_parked();
3640
3641 // Verify that:
3642 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3643 // 2. The user message was NOT truncated
3644 assert!(
3645 *saw_refusal_event.lock().unwrap(),
3646 "Refusal event should be emitted for tool result refusals"
3647 );
3648
3649 thread.read_with(cx, |thread, _| {
3650 let entries = thread.entries();
3651 assert!(entries.len() >= 2, "Should have user message and tool call");
3652
3653 // Verify user message is still there
3654 assert!(
3655 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3656 "User message should not be truncated"
3657 );
3658
3659 // Verify tool call is there with result
3660 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3661 assert!(
3662 tool_call.raw_output.is_some(),
3663 "Tool call should have output"
3664 );
3665 } else {
3666 panic!("Expected tool call at index 1");
3667 }
3668 });
3669 }
3670
3671 #[gpui::test]
3672 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3673 init_test(cx);
3674
3675 let fs = FakeFs::new(cx.executor());
3676 let project = Project::test(fs, None, cx).await;
3677
3678 let refuse_next = Arc::new(AtomicBool::new(false));
3679 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3680 let refuse_next = refuse_next.clone();
3681 move |_request, _thread, _cx| {
3682 if refuse_next.load(SeqCst) {
3683 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3684 .boxed_local()
3685 } else {
3686 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3687 .boxed_local()
3688 }
3689 }
3690 }));
3691
3692 let thread = cx
3693 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3694 .await
3695 .unwrap();
3696
3697 // Track if we see a Refusal event
3698 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3699 let saw_refusal_event_captured = saw_refusal_event.clone();
3700 thread.update(cx, |_thread, cx| {
3701 cx.subscribe(
3702 &thread,
3703 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3704 if matches!(event, AcpThreadEvent::Refusal) {
3705 *saw_refusal_event_captured.lock().unwrap() = true;
3706 }
3707 },
3708 )
3709 .detach();
3710 });
3711
3712 // Send a message that will be refused
3713 refuse_next.store(true, SeqCst);
3714 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3715 .await
3716 .unwrap();
3717
3718 // Verify that a Refusal event WAS emitted for user prompt refusal
3719 assert!(
3720 *saw_refusal_event.lock().unwrap(),
3721 "Refusal event should be emitted for user prompt refusals"
3722 );
3723
3724 // Verify the message was truncated (user prompt refusal)
3725 thread.read_with(cx, |thread, cx| {
3726 assert_eq!(thread.to_markdown(cx), "");
3727 });
3728 }
3729
3730 #[gpui::test]
3731 async fn test_refusal(cx: &mut TestAppContext) {
3732 init_test(cx);
3733 let fs = FakeFs::new(cx.background_executor.clone());
3734 fs.insert_tree(path!("/"), json!({})).await;
3735 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3736
3737 let refuse_next = Arc::new(AtomicBool::new(false));
3738 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3739 let refuse_next = refuse_next.clone();
3740 move |request, thread, mut cx| {
3741 let refuse_next = refuse_next.clone();
3742 async move {
3743 if refuse_next.load(SeqCst) {
3744 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3745 }
3746
3747 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3748 panic!("expected text content block");
3749 };
3750 thread.update(&mut cx, |thread, cx| {
3751 thread
3752 .handle_session_update(
3753 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3754 content.text.to_uppercase().into(),
3755 )),
3756 cx,
3757 )
3758 .unwrap();
3759 })?;
3760 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3761 }
3762 .boxed_local()
3763 }
3764 }));
3765 let thread = cx
3766 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3767 .await
3768 .unwrap();
3769
3770 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3771 .await
3772 .unwrap();
3773 thread.read_with(cx, |thread, cx| {
3774 assert_eq!(
3775 thread.to_markdown(cx),
3776 indoc! {"
3777 ## User
3778
3779 hello
3780
3781 ## Assistant
3782
3783 HELLO
3784
3785 "}
3786 );
3787 });
3788
3789 // Simulate refusing the second message. The message should be truncated
3790 // when a user prompt is refused.
3791 refuse_next.store(true, SeqCst);
3792 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3793 .await
3794 .unwrap();
3795 thread.read_with(cx, |thread, cx| {
3796 assert_eq!(
3797 thread.to_markdown(cx),
3798 indoc! {"
3799 ## User
3800
3801 hello
3802
3803 ## Assistant
3804
3805 HELLO
3806
3807 "}
3808 );
3809 });
3810 }
3811
3812 async fn run_until_first_tool_call(
3813 thread: &Entity<AcpThread>,
3814 cx: &mut TestAppContext,
3815 ) -> usize {
3816 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3817
3818 let subscription = cx.update(|cx| {
3819 cx.subscribe(thread, move |thread, _, cx| {
3820 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3821 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3822 return tx.try_send(ix).unwrap();
3823 }
3824 }
3825 })
3826 });
3827
3828 select! {
3829 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3830 panic!("Timeout waiting for tool call")
3831 }
3832 ix = rx.next().fuse() => {
3833 drop(subscription);
3834 ix.unwrap()
3835 }
3836 }
3837 }
3838
3839 #[derive(Clone, Default)]
3840 struct FakeAgentConnection {
3841 auth_methods: Vec<acp::AuthMethod>,
3842 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3843 on_user_message: Option<
3844 Rc<
3845 dyn Fn(
3846 acp::PromptRequest,
3847 WeakEntity<AcpThread>,
3848 AsyncApp,
3849 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3850 + 'static,
3851 >,
3852 >,
3853 }
3854
3855 impl FakeAgentConnection {
3856 fn new() -> Self {
3857 Self {
3858 auth_methods: Vec::new(),
3859 on_user_message: None,
3860 sessions: Arc::default(),
3861 }
3862 }
3863
3864 #[expect(unused)]
3865 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3866 self.auth_methods = auth_methods;
3867 self
3868 }
3869
3870 fn on_user_message(
3871 mut self,
3872 handler: impl Fn(
3873 acp::PromptRequest,
3874 WeakEntity<AcpThread>,
3875 AsyncApp,
3876 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3877 + 'static,
3878 ) -> Self {
3879 self.on_user_message.replace(Rc::new(handler));
3880 self
3881 }
3882 }
3883
3884 impl AgentConnection for FakeAgentConnection {
3885 fn telemetry_id(&self) -> SharedString {
3886 "fake".into()
3887 }
3888
3889 fn auth_methods(&self) -> &[acp::AuthMethod] {
3890 &self.auth_methods
3891 }
3892
3893 fn new_session(
3894 self: Rc<Self>,
3895 project: Entity<Project>,
3896 _cwd: &Path,
3897 cx: &mut App,
3898 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3899 let session_id = acp::SessionId::new(
3900 rand::rng()
3901 .sample_iter(&distr::Alphanumeric)
3902 .take(7)
3903 .map(char::from)
3904 .collect::<String>(),
3905 );
3906 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3907 let thread = cx.new(|cx| {
3908 AcpThread::new(
3909 None,
3910 "Test",
3911 self.clone(),
3912 project,
3913 action_log,
3914 session_id.clone(),
3915 watch::Receiver::constant(
3916 acp::PromptCapabilities::new()
3917 .image(true)
3918 .audio(true)
3919 .embedded_context(true),
3920 ),
3921 cx,
3922 )
3923 });
3924 self.sessions.lock().insert(session_id, thread.downgrade());
3925 Task::ready(Ok(thread))
3926 }
3927
3928 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3929 if self.auth_methods().iter().any(|m| m.id == method) {
3930 Task::ready(Ok(()))
3931 } else {
3932 Task::ready(Err(anyhow!("Invalid Auth Method")))
3933 }
3934 }
3935
3936 fn prompt(
3937 &self,
3938 _id: Option<UserMessageId>,
3939 params: acp::PromptRequest,
3940 cx: &mut App,
3941 ) -> Task<gpui::Result<acp::PromptResponse>> {
3942 let sessions = self.sessions.lock();
3943 let thread = sessions.get(¶ms.session_id).unwrap();
3944 if let Some(handler) = &self.on_user_message {
3945 let handler = handler.clone();
3946 let thread = thread.clone();
3947 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3948 } else {
3949 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3950 }
3951 }
3952
3953 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3954 let sessions = self.sessions.lock();
3955 let thread = sessions.get(session_id).unwrap().clone();
3956
3957 cx.spawn(async move |cx| {
3958 thread
3959 .update(cx, |thread, cx| thread.cancel(cx))
3960 .unwrap()
3961 .await
3962 })
3963 .detach();
3964 }
3965
3966 fn truncate(
3967 &self,
3968 session_id: &acp::SessionId,
3969 _cx: &App,
3970 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3971 Some(Rc::new(FakeAgentSessionEditor {
3972 _session_id: session_id.clone(),
3973 }))
3974 }
3975
3976 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3977 self
3978 }
3979 }
3980
    /// Truncation handle handed out by the fake connection's `truncate`.
    ///
    /// It stores the id of the session it was created for (the field is not read
    /// anywhere in this code) and accepts every truncation request as a no-op.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// Always succeeds immediately without modifying anything.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3990
3991 #[gpui::test]
3992 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3993 init_test(cx);
3994
3995 let fs = FakeFs::new(cx.executor());
3996 let project = Project::test(fs, [], cx).await;
3997 let connection = Rc::new(FakeAgentConnection::new());
3998 let thread = cx
3999 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4000 .await
4001 .unwrap();
4002
4003 // Try to update a tool call that doesn't exist
4004 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
4005 thread.update(cx, |thread, cx| {
4006 let result = thread.handle_session_update(
4007 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4008 nonexistent_id.clone(),
4009 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4010 )),
4011 cx,
4012 );
4013
4014 // The update should succeed (not return an error)
4015 assert!(result.is_ok());
4016
4017 // There should now be exactly one entry in the thread
4018 assert_eq!(thread.entries.len(), 1);
4019
4020 // The entry should be a failed tool call
4021 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4022 assert_eq!(tool_call.id, nonexistent_id);
4023 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4024 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4025
4026 // Check that the content contains the error message
4027 assert_eq!(tool_call.content.len(), 1);
4028 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4029 match content_block {
4030 ContentBlock::Markdown { markdown } => {
4031 let markdown_text = markdown.read(cx).source();
4032 assert!(markdown_text.contains("Tool call not found"));
4033 }
4034 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4035 ContentBlock::ResourceLink { .. } => {
4036 panic!("Expected markdown content, got resource link")
4037 }
4038 ContentBlock::Image { .. } => {
4039 panic!("Expected markdown content, got image")
4040 }
4041 }
4042 } else {
4043 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4044 }
4045 } else {
4046 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4047 }
4048 });
4049 }
4050
4051 /// Tests that restoring a checkpoint properly cleans up terminals that were
4052 /// created after that checkpoint, and cancels any in-progress generation.
4053 ///
4054 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
4055 /// that were started after that checkpoint should be terminated, and any in-progress
4056 /// AI generation should be canceled.
4057 #[gpui::test]
4058 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
4059 init_test(cx);
4060
4061 let fs = FakeFs::new(cx.executor());
4062 let project = Project::test(fs, [], cx).await;
4063 let connection = Rc::new(FakeAgentConnection::new());
4064 let thread = cx
4065 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4066 .await
4067 .unwrap();
4068
4069 // Send first user message to create a checkpoint
4070 cx.update(|cx| {
4071 thread.update(cx, |thread, cx| {
4072 thread.send(vec!["first message".into()], cx)
4073 })
4074 })
4075 .await
4076 .unwrap();
4077
4078 // Send second message (creates another checkpoint) - we'll restore to this one
4079 cx.update(|cx| {
4080 thread.update(cx, |thread, cx| {
4081 thread.send(vec!["second message".into()], cx)
4082 })
4083 })
4084 .await
4085 .unwrap();
4086
4087 // Create 2 terminals BEFORE the checkpoint that have completed running
4088 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4089 let mock_terminal_1 = cx.new(|cx| {
4090 let builder = ::terminal::TerminalBuilder::new_display_only(
4091 ::terminal::terminal_settings::CursorShape::default(),
4092 ::terminal::terminal_settings::AlternateScroll::On,
4093 None,
4094 0,
4095 cx.background_executor(),
4096 PathStyle::local(),
4097 )
4098 .unwrap();
4099 builder.subscribe(cx)
4100 });
4101
4102 thread.update(cx, |thread, cx| {
4103 thread.on_terminal_provider_event(
4104 TerminalProviderEvent::Created {
4105 terminal_id: terminal_id_1.clone(),
4106 label: "echo 'first'".to_string(),
4107 cwd: Some(PathBuf::from("/test")),
4108 output_byte_limit: None,
4109 terminal: mock_terminal_1.clone(),
4110 },
4111 cx,
4112 );
4113 });
4114
4115 thread.update(cx, |thread, cx| {
4116 thread.on_terminal_provider_event(
4117 TerminalProviderEvent::Output {
4118 terminal_id: terminal_id_1.clone(),
4119 data: b"first\n".to_vec(),
4120 },
4121 cx,
4122 );
4123 });
4124
4125 thread.update(cx, |thread, cx| {
4126 thread.on_terminal_provider_event(
4127 TerminalProviderEvent::Exit {
4128 terminal_id: terminal_id_1.clone(),
4129 status: acp::TerminalExitStatus::new().exit_code(0),
4130 },
4131 cx,
4132 );
4133 });
4134
4135 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4136 let mock_terminal_2 = cx.new(|cx| {
4137 let builder = ::terminal::TerminalBuilder::new_display_only(
4138 ::terminal::terminal_settings::CursorShape::default(),
4139 ::terminal::terminal_settings::AlternateScroll::On,
4140 None,
4141 0,
4142 cx.background_executor(),
4143 PathStyle::local(),
4144 )
4145 .unwrap();
4146 builder.subscribe(cx)
4147 });
4148
4149 thread.update(cx, |thread, cx| {
4150 thread.on_terminal_provider_event(
4151 TerminalProviderEvent::Created {
4152 terminal_id: terminal_id_2.clone(),
4153 label: "echo 'second'".to_string(),
4154 cwd: Some(PathBuf::from("/test")),
4155 output_byte_limit: None,
4156 terminal: mock_terminal_2.clone(),
4157 },
4158 cx,
4159 );
4160 });
4161
4162 thread.update(cx, |thread, cx| {
4163 thread.on_terminal_provider_event(
4164 TerminalProviderEvent::Output {
4165 terminal_id: terminal_id_2.clone(),
4166 data: b"second\n".to_vec(),
4167 },
4168 cx,
4169 );
4170 });
4171
4172 thread.update(cx, |thread, cx| {
4173 thread.on_terminal_provider_event(
4174 TerminalProviderEvent::Exit {
4175 terminal_id: terminal_id_2.clone(),
4176 status: acp::TerminalExitStatus::new().exit_code(0),
4177 },
4178 cx,
4179 );
4180 });
4181
4182 // Get the second message ID to restore to
4183 let second_message_id = thread.read_with(cx, |thread, _| {
4184 // At this point we have:
4185 // - Index 0: First user message (with checkpoint)
4186 // - Index 1: Second user message (with checkpoint)
4187 // No assistant responses because FakeAgentConnection just returns EndTurn
4188 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
4189 panic!("expected user message at index 1");
4190 };
4191 message.id.clone().unwrap()
4192 });
4193
4194 // Create a terminal AFTER the checkpoint we'll restore to.
4195 // This simulates the AI agent starting a long-running terminal command.
4196 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4197 let mock_terminal = cx.new(|cx| {
4198 let builder = ::terminal::TerminalBuilder::new_display_only(
4199 ::terminal::terminal_settings::CursorShape::default(),
4200 ::terminal::terminal_settings::AlternateScroll::On,
4201 None,
4202 0,
4203 cx.background_executor(),
4204 PathStyle::local(),
4205 )
4206 .unwrap();
4207 builder.subscribe(cx)
4208 });
4209
4210 // Register the terminal as created
4211 thread.update(cx, |thread, cx| {
4212 thread.on_terminal_provider_event(
4213 TerminalProviderEvent::Created {
4214 terminal_id: terminal_id.clone(),
4215 label: "sleep 1000".to_string(),
4216 cwd: Some(PathBuf::from("/test")),
4217 output_byte_limit: None,
4218 terminal: mock_terminal.clone(),
4219 },
4220 cx,
4221 );
4222 });
4223
4224 // Simulate the terminal producing output (still running)
4225 thread.update(cx, |thread, cx| {
4226 thread.on_terminal_provider_event(
4227 TerminalProviderEvent::Output {
4228 terminal_id: terminal_id.clone(),
4229 data: b"terminal is running...\n".to_vec(),
4230 },
4231 cx,
4232 );
4233 });
4234
4235 // Create a tool call entry that references this terminal
4236 // This represents the agent requesting a terminal command
4237 thread.update(cx, |thread, cx| {
4238 thread
4239 .handle_session_update(
4240 acp::SessionUpdate::ToolCall(
4241 acp::ToolCall::new("terminal-tool-1", "Running command")
4242 .kind(acp::ToolKind::Execute)
4243 .status(acp::ToolCallStatus::InProgress)
4244 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
4245 terminal_id.clone(),
4246 ))])
4247 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
4248 ),
4249 cx,
4250 )
4251 .unwrap();
4252 });
4253
4254 // Verify terminal exists and is in the thread
4255 let terminal_exists_before =
4256 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4257 assert!(
4258 terminal_exists_before,
4259 "Terminal should exist before checkpoint restore"
4260 );
4261
4262 // Verify the terminal's underlying task is still running (not completed)
4263 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
4264 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
4265 terminal_entity.read_with(cx, |term, _cx| {
4266 term.output().is_none() // output is None means it's still running
4267 })
4268 });
4269 assert!(
4270 terminal_running_before,
4271 "Terminal should be running before checkpoint restore"
4272 );
4273
4274 // Verify we have the expected entries before restore
4275 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4276 assert!(
4277 entry_count_before > 1,
4278 "Should have multiple entries before restore"
4279 );
4280
4281 // Restore the checkpoint to the second message.
4282 // This should:
4283 // 1. Cancel any in-progress generation (via the cancel() call)
4284 // 2. Remove the terminal that was created after that point
4285 thread
4286 .update(cx, |thread, cx| {
4287 thread.restore_checkpoint(second_message_id, cx)
4288 })
4289 .await
4290 .unwrap();
4291
4292 // Verify that no send_task is in progress after restore
4293 // (cancel() clears the send_task)
4294 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4295 assert!(
4296 !has_send_task_after,
4297 "Should not have a send_task after restore (cancel should have cleared it)"
4298 );
4299
4300 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4301 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4302 assert_eq!(
4303 entry_count, 1,
4304 "Should have 1 entry after restore (only the first user message)"
4305 );
4306
4307 // Verify the 2 completed terminals from before the checkpoint still exist
4308 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4309 thread.terminals.contains_key(&terminal_id_1)
4310 });
4311 assert!(
4312 terminal_1_exists,
4313 "Terminal 1 (from before checkpoint) should still exist"
4314 );
4315
4316 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4317 thread.terminals.contains_key(&terminal_id_2)
4318 });
4319 assert!(
4320 terminal_2_exists,
4321 "Terminal 2 (from before checkpoint) should still exist"
4322 );
4323
4324 // Verify they're still in completed state
4325 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4326 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4327 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4328 });
4329 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4330
4331 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4332 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4333 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4334 });
4335 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4336
4337 // Verify the running terminal (created after checkpoint) was removed
4338 let terminal_3_exists =
4339 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4340 assert!(
4341 !terminal_3_exists,
4342 "Terminal 3 (created after checkpoint) should have been removed"
4343 );
4344
4345 // Verify total count is 2 (the two from before the checkpoint)
4346 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4347 assert_eq!(
4348 terminal_count, 2,
4349 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4350 );
4351 }
4352
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // A `.git` directory plus one file so the project has a repository to checkpoint.
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // The flag is set synchronously when the handler is *called*, before the
        // returned future (which immediately ends the turn) is polled.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Drive the send manually so the test can interleave with it.
        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until the handler has been invoked, then a few more ticks to let
        // update_last_checkpoint start.
        // NOTE(review): the tick count (5) appears tuned to land inside the async
        // checkpoint comparison; the `while` loop would spin forever if the handler
        // were never invoked — confirm both stay valid if scheduling changes.
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message without a checkpoint while the comparison is pending.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4415}