1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7
/// Key used in ACP ToolCall meta to store the tool's programmatic name.
/// This is a workaround since ACP's ToolCall doesn't have a dedicated name field.
///
/// Read by [`tool_name_from_meta`] and written by [`meta_with_tool_name`].
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Key used in ACP ToolCall meta to store the session id when a subagent is spawned.
///
/// Read by [`subagent_session_id_from_meta`].
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
14
15/// Helper to extract tool name from ACP meta
16pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
17 meta.as_ref()
18 .and_then(|m| m.get(TOOL_NAME_META_KEY))
19 .and_then(|v| v.as_str())
20 .map(|s| SharedString::from(s.to_owned()))
21}
22
23/// Helper to extract subagent session id from ACP meta
24pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
25 meta.as_ref()
26 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
27 .and_then(|v| v.as_str())
28 .map(|s| acp::SessionId::from(s.to_string()))
29}
30
31/// Helper to create meta with tool name
32pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
33 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
34}
35use collections::HashSet;
36pub use connection::*;
37pub use diff::*;
38use language::language_settings::FormatOnSave;
39pub use mention::*;
40use project::lsp_store::{FormatTrigger, LspFormatTarget};
41use serde::{Deserialize, Serialize};
42use serde_json::to_string_pretty;
43use settings::Settings as _;
44use task::{Shell, ShellBuilder};
45pub use terminal::*;
46
47use action_log::{ActionLog, ActionLogTelemetry};
48use agent_client_protocol::{self as acp};
49use anyhow::{Context as _, Result, anyhow};
50use futures::{FutureExt, channel::oneshot, future::BoxFuture};
51use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
52use itertools::Itertools;
53use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
54use markdown::Markdown;
55use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
56use std::collections::HashMap;
57use std::error::Error;
58use std::fmt::{Formatter, Write};
59use std::ops::Range;
60use std::process::ExitStatus;
61use std::rc::Rc;
62use std::time::{Duration, Instant};
63use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
64use text::Bias;
65use ui::App;
66use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
67use uuid::Uuid;
68
69#[derive(Debug)]
70pub struct UserMessage {
71 pub id: Option<UserMessageId>,
72 pub content: ContentBlock,
73 pub chunks: Vec<acp::ContentBlock>,
74 pub checkpoint: Option<Checkpoint>,
75 pub indented: bool,
76}
77
78#[derive(Debug)]
79pub struct Checkpoint {
80 git_checkpoint: GitStoreCheckpoint,
81 pub show: bool,
82}
83
84impl UserMessage {
85 fn to_markdown(&self, cx: &App) -> String {
86 let mut markdown = String::new();
87 if self
88 .checkpoint
89 .as_ref()
90 .is_some_and(|checkpoint| checkpoint.show)
91 {
92 writeln!(markdown, "## User (checkpoint)").unwrap();
93 } else {
94 writeln!(markdown, "## User").unwrap();
95 }
96 writeln!(markdown).unwrap();
97 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
98 writeln!(markdown).unwrap();
99 markdown
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub struct AssistantMessage {
105 pub chunks: Vec<AssistantMessageChunk>,
106 pub indented: bool,
107}
108
109impl AssistantMessage {
110 pub fn to_markdown(&self, cx: &App) -> String {
111 format!(
112 "## Assistant\n\n{}\n\n",
113 self.chunks
114 .iter()
115 .map(|chunk| chunk.to_markdown(cx))
116 .join("\n\n")
117 )
118 }
119}
120
121#[derive(Debug, PartialEq)]
122pub enum AssistantMessageChunk {
123 Message { block: ContentBlock },
124 Thought { block: ContentBlock },
125}
126
127impl AssistantMessageChunk {
128 pub fn from_str(
129 chunk: &str,
130 language_registry: &Arc<LanguageRegistry>,
131 path_style: PathStyle,
132 cx: &mut App,
133 ) -> Self {
134 Self::Message {
135 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
136 }
137 }
138
139 fn to_markdown(&self, cx: &App) -> String {
140 match self {
141 Self::Message { block } => block.to_markdown(cx).to_string(),
142 Self::Thought { block } => {
143 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
144 }
145 }
146 }
147}
148
149#[derive(Debug)]
150pub enum AgentThreadEntry {
151 UserMessage(UserMessage),
152 AssistantMessage(AssistantMessage),
153 ToolCall(ToolCall),
154}
155
156impl AgentThreadEntry {
157 pub fn is_indented(&self) -> bool {
158 match self {
159 Self::UserMessage(message) => message.indented,
160 Self::AssistantMessage(message) => message.indented,
161 Self::ToolCall(_) => false,
162 }
163 }
164
165 pub fn to_markdown(&self, cx: &App) -> String {
166 match self {
167 Self::UserMessage(message) => message.to_markdown(cx),
168 Self::AssistantMessage(message) => message.to_markdown(cx),
169 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
170 }
171 }
172
173 pub fn user_message(&self) -> Option<&UserMessage> {
174 if let AgentThreadEntry::UserMessage(message) = self {
175 Some(message)
176 } else {
177 None
178 }
179 }
180
181 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
182 if let AgentThreadEntry::ToolCall(call) = self {
183 itertools::Either::Left(call.diffs())
184 } else {
185 itertools::Either::Right(std::iter::empty())
186 }
187 }
188
189 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
190 if let AgentThreadEntry::ToolCall(call) = self {
191 itertools::Either::Left(call.terminals())
192 } else {
193 itertools::Either::Right(std::iter::empty())
194 }
195 }
196
197 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
198 if let AgentThreadEntry::ToolCall(ToolCall {
199 locations,
200 resolved_locations,
201 ..
202 }) = self
203 {
204 Some((
205 locations.get(ix)?.clone(),
206 resolved_locations.get(ix)?.clone()?,
207 ))
208 } else {
209 None
210 }
211 }
212}
213
214#[derive(Debug)]
215pub struct ToolCall {
216 pub id: acp::ToolCallId,
217 pub label: Entity<Markdown>,
218 pub kind: acp::ToolKind,
219 pub content: Vec<ToolCallContent>,
220 pub status: ToolCallStatus,
221 pub locations: Vec<acp::ToolCallLocation>,
222 pub resolved_locations: Vec<Option<AgentLocation>>,
223 pub raw_input: Option<serde_json::Value>,
224 pub raw_input_markdown: Option<Entity<Markdown>>,
225 pub raw_output: Option<serde_json::Value>,
226 pub tool_name: Option<SharedString>,
227 pub subagent_session_id: Option<acp::SessionId>,
228}
229
230impl ToolCall {
231 fn from_acp(
232 tool_call: acp::ToolCall,
233 status: ToolCallStatus,
234 language_registry: Arc<LanguageRegistry>,
235 path_style: PathStyle,
236 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
237 cx: &mut App,
238 ) -> Result<Self> {
239 let title = if tool_call.kind == acp::ToolKind::Execute {
240 tool_call.title
241 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
242 first_line.to_owned() + "…"
243 } else {
244 tool_call.title
245 };
246 let mut content = Vec::with_capacity(tool_call.content.len());
247 for item in tool_call.content {
248 if let Some(item) = ToolCallContent::from_acp(
249 item,
250 language_registry.clone(),
251 path_style,
252 terminals,
253 cx,
254 )? {
255 content.push(item);
256 }
257 }
258
259 let raw_input_markdown = tool_call
260 .raw_input
261 .as_ref()
262 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
263
264 let tool_name = tool_name_from_meta(&tool_call.meta);
265
266 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
267
268 let result = Self {
269 id: tool_call.tool_call_id,
270 label: cx
271 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
272 kind: tool_call.kind,
273 content,
274 locations: tool_call.locations,
275 resolved_locations: Vec::default(),
276 status,
277 raw_input: tool_call.raw_input,
278 raw_input_markdown,
279 raw_output: tool_call.raw_output,
280 tool_name,
281 subagent_session_id: subagent_session,
282 };
283 Ok(result)
284 }
285
286 fn update_fields(
287 &mut self,
288 fields: acp::ToolCallUpdateFields,
289 meta: Option<acp::Meta>,
290 language_registry: Arc<LanguageRegistry>,
291 path_style: PathStyle,
292 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
293 cx: &mut App,
294 ) -> Result<()> {
295 let acp::ToolCallUpdateFields {
296 kind,
297 status,
298 title,
299 content,
300 locations,
301 raw_input,
302 raw_output,
303 ..
304 } = fields;
305
306 if let Some(kind) = kind {
307 self.kind = kind;
308 }
309
310 if let Some(status) = status {
311 self.status = status.into();
312 }
313
314 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
315 self.subagent_session_id = Some(subagent_session_id);
316 }
317
318 if let Some(title) = title {
319 self.label.update(cx, |label, cx| {
320 if self.kind == acp::ToolKind::Execute {
321 label.replace(title, cx);
322 } else if let Some((first_line, _)) = title.split_once("\n") {
323 label.replace(first_line.to_owned() + "…", cx);
324 } else {
325 label.replace(title, cx);
326 }
327 });
328 }
329
330 if let Some(content) = content {
331 let mut new_content_len = content.len();
332 let mut content = content.into_iter();
333
334 // Reuse existing content if we can
335 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
336 let valid_content =
337 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
338 if !valid_content {
339 new_content_len -= 1;
340 }
341 }
342 for new in content {
343 if let Some(new) = ToolCallContent::from_acp(
344 new,
345 language_registry.clone(),
346 path_style,
347 terminals,
348 cx,
349 )? {
350 self.content.push(new);
351 } else {
352 new_content_len -= 1;
353 }
354 }
355 self.content.truncate(new_content_len);
356 }
357
358 if let Some(locations) = locations {
359 self.locations = locations;
360 }
361
362 if let Some(raw_input) = raw_input {
363 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
364 self.raw_input = Some(raw_input);
365 }
366
367 if let Some(raw_output) = raw_output {
368 if self.content.is_empty()
369 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
370 {
371 self.content
372 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
373 markdown,
374 }));
375 }
376 self.raw_output = Some(raw_output);
377 }
378 Ok(())
379 }
380
381 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
382 self.content.iter().filter_map(|content| match content {
383 ToolCallContent::Diff(diff) => Some(diff),
384 ToolCallContent::ContentBlock(_) => None,
385 ToolCallContent::Terminal(_) => None,
386 })
387 }
388
389 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
390 self.content.iter().filter_map(|content| match content {
391 ToolCallContent::Terminal(terminal) => Some(terminal),
392 ToolCallContent::ContentBlock(_) => None,
393 ToolCallContent::Diff(_) => None,
394 })
395 }
396
397 pub fn is_subagent(&self) -> bool {
398 self.tool_name.as_ref().is_some_and(|s| s == "subagent")
399 || self.subagent_session_id.is_some()
400 }
401
402 pub fn to_markdown(&self, cx: &App) -> String {
403 let mut markdown = format!(
404 "**Tool Call: {}**\nStatus: {}\n\n",
405 self.label.read(cx).source(),
406 self.status
407 );
408 for content in &self.content {
409 markdown.push_str(content.to_markdown(cx).as_str());
410 markdown.push_str("\n\n");
411 }
412 markdown
413 }
414
415 async fn resolve_location(
416 location: acp::ToolCallLocation,
417 project: WeakEntity<Project>,
418 cx: &mut AsyncApp,
419 ) -> Option<ResolvedLocation> {
420 let buffer = project
421 .update(cx, |project, cx| {
422 project
423 .project_path_for_absolute_path(&location.path, cx)
424 .map(|path| project.open_buffer(path, cx))
425 })
426 .ok()??;
427 let buffer = buffer.await.log_err()?;
428 let position = buffer.update(cx, |buffer, _| {
429 let snapshot = buffer.snapshot();
430 if let Some(row) = location.line {
431 let column = snapshot.indent_size_for_line(row).len;
432 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
433 snapshot.anchor_before(point)
434 } else {
435 Anchor::min_for_buffer(snapshot.remote_id())
436 }
437 });
438
439 Some(ResolvedLocation { buffer, position })
440 }
441
442 fn resolve_locations(
443 &self,
444 project: Entity<Project>,
445 cx: &mut App,
446 ) -> Task<Vec<Option<ResolvedLocation>>> {
447 let locations = self.locations.clone();
448 project.update(cx, |_, cx| {
449 cx.spawn(async move |project, cx| {
450 let mut new_locations = Vec::new();
451 for location in locations {
452 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
453 }
454 new_locations
455 })
456 })
457 }
458}
459
460// Separate so we can hold a strong reference to the buffer
461// for saving on the thread
462#[derive(Clone, Debug, PartialEq, Eq)]
463struct ResolvedLocation {
464 buffer: Entity<Buffer>,
465 position: Anchor,
466}
467
468impl From<&ResolvedLocation> for AgentLocation {
469 fn from(value: &ResolvedLocation) -> Self {
470 Self {
471 buffer: value.buffer.downgrade(),
472 position: value.position,
473 }
474 }
475}
476
477#[derive(Debug)]
478pub enum ToolCallStatus {
479 /// The tool call hasn't started running yet, but we start showing it to
480 /// the user.
481 Pending,
482 /// The tool call is waiting for confirmation from the user.
483 WaitingForConfirmation {
484 options: PermissionOptions,
485 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
486 },
487 /// The tool call is currently running.
488 InProgress,
489 /// The tool call completed successfully.
490 Completed,
491 /// The tool call failed.
492 Failed,
493 /// The user rejected the tool call.
494 Rejected,
495 /// The user canceled generation so the tool call was canceled.
496 Canceled,
497}
498
499impl From<acp::ToolCallStatus> for ToolCallStatus {
500 fn from(status: acp::ToolCallStatus) -> Self {
501 match status {
502 acp::ToolCallStatus::Pending => Self::Pending,
503 acp::ToolCallStatus::InProgress => Self::InProgress,
504 acp::ToolCallStatus::Completed => Self::Completed,
505 acp::ToolCallStatus::Failed => Self::Failed,
506 _ => Self::Pending,
507 }
508 }
509}
510
511impl Display for ToolCallStatus {
512 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
513 write!(
514 f,
515 "{}",
516 match self {
517 ToolCallStatus::Pending => "Pending",
518 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
519 ToolCallStatus::InProgress => "In Progress",
520 ToolCallStatus::Completed => "Completed",
521 ToolCallStatus::Failed => "Failed",
522 ToolCallStatus::Rejected => "Rejected",
523 ToolCallStatus::Canceled => "Canceled",
524 }
525 )
526 }
527}
528
529#[derive(Debug, PartialEq, Clone)]
530pub enum ContentBlock {
531 Empty,
532 Markdown { markdown: Entity<Markdown> },
533 ResourceLink { resource_link: acp::ResourceLink },
534 Image { image: Arc<gpui::Image> },
535}
536
537impl ContentBlock {
538 pub fn new(
539 block: acp::ContentBlock,
540 language_registry: &Arc<LanguageRegistry>,
541 path_style: PathStyle,
542 cx: &mut App,
543 ) -> Self {
544 let mut this = Self::Empty;
545 this.append(block, language_registry, path_style, cx);
546 this
547 }
548
549 pub fn new_combined(
550 blocks: impl IntoIterator<Item = acp::ContentBlock>,
551 language_registry: Arc<LanguageRegistry>,
552 path_style: PathStyle,
553 cx: &mut App,
554 ) -> Self {
555 let mut this = Self::Empty;
556 for block in blocks {
557 this.append(block, &language_registry, path_style, cx);
558 }
559 this
560 }
561
562 pub fn append(
563 &mut self,
564 block: acp::ContentBlock,
565 language_registry: &Arc<LanguageRegistry>,
566 path_style: PathStyle,
567 cx: &mut App,
568 ) {
569 match (&mut *self, &block) {
570 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
571 *self = ContentBlock::ResourceLink {
572 resource_link: resource_link.clone(),
573 };
574 }
575 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
576 if let Some(image) = Self::decode_image(image_content) {
577 *self = ContentBlock::Image { image };
578 } else {
579 let new_content = Self::image_md(image_content);
580 *self = Self::create_markdown_block(new_content, language_registry, cx);
581 }
582 }
583 (ContentBlock::Empty, _) => {
584 let new_content = Self::block_string_contents(&block, path_style);
585 *self = Self::create_markdown_block(new_content, language_registry, cx);
586 }
587 (ContentBlock::Markdown { markdown }, _) => {
588 let new_content = Self::block_string_contents(&block, path_style);
589 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
590 }
591 (ContentBlock::ResourceLink { resource_link }, _) => {
592 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
593 let new_content = Self::block_string_contents(&block, path_style);
594 let combined = format!("{}\n{}", existing_content, new_content);
595 *self = Self::create_markdown_block(combined, language_registry, cx);
596 }
597 (ContentBlock::Image { .. }, _) => {
598 let new_content = Self::block_string_contents(&block, path_style);
599 let combined = format!("`Image`\n{}", new_content);
600 *self = Self::create_markdown_block(combined, language_registry, cx);
601 }
602 }
603 }
604
605 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
606 use base64::Engine as _;
607
608 let bytes = base64::engine::general_purpose::STANDARD
609 .decode(image_content.data.as_bytes())
610 .ok()?;
611 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
612 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
613 }
614
615 fn create_markdown_block(
616 content: String,
617 language_registry: &Arc<LanguageRegistry>,
618 cx: &mut App,
619 ) -> ContentBlock {
620 ContentBlock::Markdown {
621 markdown: cx
622 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
623 }
624 }
625
626 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
627 match block {
628 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
629 acp::ContentBlock::ResourceLink(resource_link) => {
630 Self::resource_link_md(&resource_link.uri, path_style)
631 }
632 acp::ContentBlock::Resource(acp::EmbeddedResource {
633 resource:
634 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
635 uri,
636 ..
637 }),
638 ..
639 }) => Self::resource_link_md(uri, path_style),
640 acp::ContentBlock::Image(image) => Self::image_md(image),
641 _ => String::new(),
642 }
643 }
644
645 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
646 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
647 uri.as_link().to_string()
648 } else {
649 uri.to_string()
650 }
651 }
652
653 fn image_md(_image: &acp::ImageContent) -> String {
654 "`Image`".into()
655 }
656
657 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
658 match self {
659 ContentBlock::Empty => "",
660 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
661 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
662 ContentBlock::Image { .. } => "`Image`",
663 }
664 }
665
666 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
667 match self {
668 ContentBlock::Empty => None,
669 ContentBlock::Markdown { markdown } => Some(markdown),
670 ContentBlock::ResourceLink { .. } => None,
671 ContentBlock::Image { .. } => None,
672 }
673 }
674
675 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
676 match self {
677 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
678 _ => None,
679 }
680 }
681
682 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
683 match self {
684 ContentBlock::Image { image } => Some(image),
685 _ => None,
686 }
687 }
688}
689
690#[derive(Debug)]
691pub enum ToolCallContent {
692 ContentBlock(ContentBlock),
693 Diff(Entity<Diff>),
694 Terminal(Entity<Terminal>),
695}
696
697impl ToolCallContent {
698 pub fn from_acp(
699 content: acp::ToolCallContent,
700 language_registry: Arc<LanguageRegistry>,
701 path_style: PathStyle,
702 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
703 cx: &mut App,
704 ) -> Result<Option<Self>> {
705 match content {
706 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
707 Ok(Some(Self::ContentBlock(ContentBlock::new(
708 content,
709 &language_registry,
710 path_style,
711 cx,
712 ))))
713 }
714 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
715 Diff::finalized(
716 diff.path.to_string_lossy().into_owned(),
717 diff.old_text,
718 diff.new_text,
719 language_registry,
720 cx,
721 )
722 })))),
723 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
724 .get(&terminal_id)
725 .cloned()
726 .map(|terminal| Some(Self::Terminal(terminal)))
727 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
728 _ => Ok(None),
729 }
730 }
731
732 pub fn update_from_acp(
733 &mut self,
734 new: acp::ToolCallContent,
735 language_registry: Arc<LanguageRegistry>,
736 path_style: PathStyle,
737 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
738 cx: &mut App,
739 ) -> Result<bool> {
740 let needs_update = match (&self, &new) {
741 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
742 old_diff.read(cx).needs_update(
743 new_diff.old_text.as_deref().unwrap_or(""),
744 &new_diff.new_text,
745 cx,
746 )
747 }
748 _ => true,
749 };
750
751 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
752 if needs_update {
753 *self = update;
754 }
755 Ok(true)
756 } else {
757 Ok(false)
758 }
759 }
760
761 pub fn to_markdown(&self, cx: &App) -> String {
762 match self {
763 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
764 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
765 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
766 }
767 }
768
769 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
770 match self {
771 Self::ContentBlock(content) => content.image(),
772 _ => None,
773 }
774 }
775}
776
777#[derive(Debug, PartialEq)]
778pub enum ToolCallUpdate {
779 UpdateFields(acp::ToolCallUpdate),
780 UpdateDiff(ToolCallUpdateDiff),
781 UpdateTerminal(ToolCallUpdateTerminal),
782}
783
784impl ToolCallUpdate {
785 fn id(&self) -> &acp::ToolCallId {
786 match self {
787 Self::UpdateFields(update) => &update.tool_call_id,
788 Self::UpdateDiff(diff) => &diff.id,
789 Self::UpdateTerminal(terminal) => &terminal.id,
790 }
791 }
792}
793
794impl From<acp::ToolCallUpdate> for ToolCallUpdate {
795 fn from(update: acp::ToolCallUpdate) -> Self {
796 Self::UpdateFields(update)
797 }
798}
799
800impl From<ToolCallUpdateDiff> for ToolCallUpdate {
801 fn from(diff: ToolCallUpdateDiff) -> Self {
802 Self::UpdateDiff(diff)
803 }
804}
805
806#[derive(Debug, PartialEq)]
807pub struct ToolCallUpdateDiff {
808 pub id: acp::ToolCallId,
809 pub diff: Entity<Diff>,
810}
811
812impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
813 fn from(terminal: ToolCallUpdateTerminal) -> Self {
814 Self::UpdateTerminal(terminal)
815 }
816}
817
818#[derive(Debug, PartialEq)]
819pub struct ToolCallUpdateTerminal {
820 pub id: acp::ToolCallId,
821 pub terminal: Entity<Terminal>,
822}
823
824#[derive(Debug, Default)]
825pub struct Plan {
826 pub entries: Vec<PlanEntry>,
827}
828
829#[derive(Debug)]
830pub struct PlanStats<'a> {
831 pub in_progress_entry: Option<&'a PlanEntry>,
832 pub pending: u32,
833 pub completed: u32,
834}
835
836impl Plan {
837 pub fn is_empty(&self) -> bool {
838 self.entries.is_empty()
839 }
840
841 pub fn stats(&self) -> PlanStats<'_> {
842 let mut stats = PlanStats {
843 in_progress_entry: None,
844 pending: 0,
845 completed: 0,
846 };
847
848 for entry in &self.entries {
849 match &entry.status {
850 acp::PlanEntryStatus::Pending => {
851 stats.pending += 1;
852 }
853 acp::PlanEntryStatus::InProgress => {
854 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
855 }
856 acp::PlanEntryStatus::Completed => {
857 stats.completed += 1;
858 }
859 _ => {}
860 }
861 }
862
863 stats
864 }
865}
866
867#[derive(Debug)]
868pub struct PlanEntry {
869 pub content: Entity<Markdown>,
870 pub priority: acp::PlanEntryPriority,
871 pub status: acp::PlanEntryStatus,
872}
873
874impl PlanEntry {
875 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
876 Self {
877 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
878 priority: entry.priority,
879 status: entry.status,
880 }
881 }
882}
883
/// Token usage of a thread relative to its context limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` means the limit is unknown (see `ratio`).
    pub max_tokens: u64,
    /// Tokens used so far, compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    pub input_tokens: u64,
    pub output_tokens: u64,
}
891
892impl TokenUsage {
893 pub fn ratio(&self) -> TokenUsageRatio {
894 #[cfg(debug_assertions)]
895 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
896 .unwrap_or("0.8".to_string())
897 .parse()
898 .unwrap();
899 #[cfg(not(debug_assertions))]
900 let warning_threshold: f32 = 0.8;
901
902 // When the maximum is unknown because there is no selected model,
903 // avoid showing the token limit warning.
904 if self.max_tokens == 0 {
905 TokenUsageRatio::Normal
906 } else if self.used_tokens >= self.max_tokens {
907 TokenUsageRatio::Exceeded
908 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
909 TokenUsageRatio::Warning
910 } else {
911 TokenUsageRatio::Normal
912 }
913 }
914}
915
/// Coarse classification of token usage returned by `TokenUsage::ratio`.
/// Variants are declared in increasing severity, which the derived `Ord` follows.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}
922
/// Retry progress, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the most recent error.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    /// NOTE(review): presumably the delay before the next attempt — confirm with the emitter.
    pub duration: Duration,
}
931
/// A single agent session: its entries, plan, terminals, and the connection
/// used to talk to the agent.
pub struct AcpThread {
    /// Session id of the parent thread, when this thread was created with one.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): usage is outside this chunk.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `stop_by_user` drops it.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Registered terminals keyed by ACP terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids not registered yet; replayed on `Created`.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    // Subagent cancellation: `user_stopped` records that the user stopped the
    // thread, and `user_stop_tx` broadcasts it to `user_stop_receiver` holders.
    user_stopped: Arc<std::sync::atomic::AtomicBool>,
    user_stop_tx: watch::Sender<bool>,
}
953
954impl From<&AcpThread> for ActionLogTelemetry {
955 fn from(value: &AcpThread) -> Self {
956 Self {
957 agent_telemetry_id: value.connection().telemetry_id(),
958 session_id: value.session_id.0.clone(),
959 }
960 }
961}
962
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted whenever new prompt capabilities arrive (see `AcpThread::new`).
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

// Lets `AcpThread` emit `AcpThreadEvent`s through its entity context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
984
/// Events from a terminal provider, handled by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; buffered output/exit for its id is replayed.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Raw output bytes for a terminal.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title, stored as the terminal's breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

/// Commands addressed to a terminal provider. Consumers are outside this chunk.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    Close {
        terminal_id: acp::TerminalId,
    },
}
1023
1024impl AcpThread {
1025 pub fn on_terminal_provider_event(
1026 &mut self,
1027 event: TerminalProviderEvent,
1028 cx: &mut Context<Self>,
1029 ) {
1030 match event {
1031 TerminalProviderEvent::Created {
1032 terminal_id,
1033 label,
1034 cwd,
1035 output_byte_limit,
1036 terminal,
1037 } => {
1038 let entity = self.register_terminal_created(
1039 terminal_id.clone(),
1040 label,
1041 cwd,
1042 output_byte_limit,
1043 terminal,
1044 cx,
1045 );
1046
1047 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1048 for data in chunks.drain(..) {
1049 entity.update(cx, |term, cx| {
1050 term.inner().update(cx, |inner, cx| {
1051 inner.write_output(&data, cx);
1052 })
1053 });
1054 }
1055 }
1056
1057 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1058 entity.update(cx, |_term, cx| {
1059 cx.notify();
1060 });
1061 }
1062
1063 cx.notify();
1064 }
1065 TerminalProviderEvent::Output { terminal_id, data } => {
1066 if let Some(entity) = self.terminals.get(&terminal_id) {
1067 entity.update(cx, |term, cx| {
1068 term.inner().update(cx, |inner, cx| {
1069 inner.write_output(&data, cx);
1070 })
1071 });
1072 } else {
1073 self.pending_terminal_output
1074 .entry(terminal_id)
1075 .or_default()
1076 .push(data);
1077 }
1078 }
1079 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1080 if let Some(entity) = self.terminals.get(&terminal_id) {
1081 entity.update(cx, |term, cx| {
1082 term.inner().update(cx, |inner, cx| {
1083 inner.breadcrumb_text = title;
1084 cx.emit(::terminal::Event::BreadcrumbsChanged);
1085 })
1086 });
1087 }
1088 }
1089 TerminalProviderEvent::Exit {
1090 terminal_id,
1091 status,
1092 } => {
1093 if let Some(entity) = self.terminals.get(&terminal_id) {
1094 entity.update(cx, |_term, cx| {
1095 cx.notify();
1096 });
1097 } else {
1098 self.pending_terminal_exit.insert(terminal_id, status);
1099 }
1100 }
1101 }
1102 }
1103}
1104
/// Whether the thread is currently generating a response.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    Idle,
    Generating,
}
1110
1111#[derive(Debug, Clone)]
1112pub enum LoadError {
1113 Unsupported {
1114 command: SharedString,
1115 current_version: SharedString,
1116 minimum_version: SharedString,
1117 },
1118 FailedToInstall(SharedString),
1119 Exited {
1120 status: ExitStatus,
1121 },
1122 Other(SharedString),
1123}
1124
1125impl Display for LoadError {
1126 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1127 match self {
1128 LoadError::Unsupported {
1129 command: path,
1130 current_version,
1131 minimum_version,
1132 } => {
1133 write!(
1134 f,
1135 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1136 )
1137 }
1138 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1139 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1140 LoadError::Other(msg) => write!(f, "{msg}"),
1141 }
1142 }
1143}
1144
1145impl Error for LoadError {}
1146
1147impl AcpThread {
1148 pub fn new(
1149 parent_session_id: Option<acp::SessionId>,
1150 title: impl Into<SharedString>,
1151 connection: Rc<dyn AgentConnection>,
1152 project: Entity<Project>,
1153 action_log: Entity<ActionLog>,
1154 session_id: acp::SessionId,
1155 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1156 cx: &mut Context<Self>,
1157 ) -> Self {
1158 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1159 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1160 loop {
1161 let caps = prompt_capabilities_rx.recv().await?;
1162 this.update(cx, |this, cx| {
1163 this.prompt_capabilities = caps;
1164 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1165 })?;
1166 }
1167 });
1168
1169 let (user_stop_tx, _user_stop_rx) = watch::channel(false);
1170
1171 Self {
1172 parent_session_id,
1173 action_log,
1174 shared_buffers: Default::default(),
1175 entries: Default::default(),
1176 plan: Default::default(),
1177 title: title.into(),
1178 project,
1179 send_task: None,
1180 connection,
1181 session_id,
1182 token_usage: None,
1183 prompt_capabilities,
1184 _observe_prompt_capabilities: task,
1185 terminals: HashMap::default(),
1186 pending_terminal_output: HashMap::default(),
1187 pending_terminal_exit: HashMap::default(),
1188 user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1189 user_stop_tx,
1190 }
1191 }
1192
    /// Returns the parent session id this thread was created with, if any
    /// (presumably set for subagent threads — confirm with callers of `new`).
    pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
        self.parent_session_id.as_ref()
    }
1196
    /// Returns a copy of the latest prompt capabilities received from the agent.
    ///
    /// The value is kept current by the observer task started in `new`.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities.clone()
    }
1200
    /// Marks this thread as stopped by user action and signals any listeners.
    ///
    /// Sets the flag read by `was_stopped_by_user` and publishes `true` on the user-stop
    /// watch channel, ignoring send errors. It then drops the in-flight turn task, if any.
    /// Unlike `cancel`, it neither asks the connection to cancel nor marks tool calls as
    /// canceled.
    pub fn stop_by_user(&mut self) {
        self.user_stopped
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.user_stop_tx.send(true).ok();
        self.send_task.take();
    }
1208
    /// Returns `true` once `stop_by_user` has marked this thread as stopped.
    pub fn was_stopped_by_user(&self) -> bool {
        self.user_stopped.load(std::sync::atomic::Ordering::SeqCst)
    }
1212
    /// Returns a new receiver for the user-stop signal, which becomes `true` after `stop_by_user`.
    pub fn user_stop_receiver(&self) -> watch::Receiver<bool> {
        self.user_stop_tx.receiver()
    }
1216
    /// Returns the agent connection this thread talks to.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
1220
    /// Returns the action log that records buffer reads and edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1224
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1228
    /// Returns the thread's current title (a cheap clone of the shared string).
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
1232
    /// Returns the transcript entries in chronological order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
1236
    /// Returns the ACP session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
1240
1241 pub fn status(&self) -> ThreadStatus {
1242 if self.send_task.is_some() {
1243 ThreadStatus::Generating
1244 } else {
1245 ThreadStatus::Idle
1246 }
1247 }
1248
    /// Returns the most recent token usage set via `update_token_usage`, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
1252
1253 pub fn has_pending_edit_tool_calls(&self) -> bool {
1254 for entry in self.entries.iter().rev() {
1255 match entry {
1256 AgentThreadEntry::UserMessage(_) => return false,
1257 AgentThreadEntry::ToolCall(
1258 call @ ToolCall {
1259 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1260 ..
1261 },
1262 ) if call.diffs().next().is_some() => {
1263 return true;
1264 }
1265 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1266 }
1267 }
1268
1269 false
1270 }
1271
1272 pub fn has_in_progress_tool_calls(&self) -> bool {
1273 for entry in self.entries.iter().rev() {
1274 match entry {
1275 AgentThreadEntry::UserMessage(_) => return false,
1276 AgentThreadEntry::ToolCall(ToolCall {
1277 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1278 ..
1279 }) => {
1280 return true;
1281 }
1282 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1283 }
1284 }
1285
1286 false
1287 }
1288
1289 pub fn used_tools_since_last_user_message(&self) -> bool {
1290 for entry in self.entries.iter().rev() {
1291 match entry {
1292 AgentThreadEntry::UserMessage(..) => return false,
1293 AgentThreadEntry::AssistantMessage(..) => continue,
1294 AgentThreadEntry::ToolCall(..) => return true,
1295 }
1296 }
1297
1298 false
1299 }
1300
1301 pub fn handle_session_update(
1302 &mut self,
1303 update: acp::SessionUpdate,
1304 cx: &mut Context<Self>,
1305 ) -> Result<(), acp::Error> {
1306 match update {
1307 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1308 self.push_user_content_block(None, content, cx);
1309 }
1310 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1311 self.push_assistant_content_block(content, false, cx);
1312 }
1313 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1314 self.push_assistant_content_block(content, true, cx);
1315 }
1316 acp::SessionUpdate::ToolCall(tool_call) => {
1317 self.upsert_tool_call(tool_call, cx)?;
1318 }
1319 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1320 self.update_tool_call(tool_call_update, cx)?;
1321 }
1322 acp::SessionUpdate::Plan(plan) => {
1323 self.update_plan(plan, cx);
1324 }
1325 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1326 available_commands,
1327 ..
1328 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1329 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1330 current_mode_id,
1331 ..
1332 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1333 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1334 config_options,
1335 ..
1336 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1337 _ => {}
1338 }
1339 Ok(())
1340 }
1341
    /// Appends a non-indented user content chunk.
    ///
    /// Equivalent to `push_user_content_block_with_indent(message_id, chunk, false, cx)`.
    pub fn push_user_content_block(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        cx: &mut Context<Self>,
    ) {
        self.push_user_content_block_with_indent(message_id, chunk, false, cx)
    }
1350
1351 pub fn push_user_content_block_with_indent(
1352 &mut self,
1353 message_id: Option<UserMessageId>,
1354 chunk: acp::ContentBlock,
1355 indented: bool,
1356 cx: &mut Context<Self>,
1357 ) {
1358 let language_registry = self.project.read(cx).languages().clone();
1359 let path_style = self.project.read(cx).path_style(cx);
1360 let entries_len = self.entries.len();
1361
1362 if let Some(last_entry) = self.entries.last_mut()
1363 && let AgentThreadEntry::UserMessage(UserMessage {
1364 id,
1365 content,
1366 chunks,
1367 indented: existing_indented,
1368 ..
1369 }) = last_entry
1370 && *existing_indented == indented
1371 {
1372 *id = message_id.or(id.take());
1373 content.append(chunk.clone(), &language_registry, path_style, cx);
1374 chunks.push(chunk);
1375 let idx = entries_len - 1;
1376 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1377 } else {
1378 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1379 self.push_entry(
1380 AgentThreadEntry::UserMessage(UserMessage {
1381 id: message_id,
1382 content,
1383 chunks: vec![chunk],
1384 checkpoint: None,
1385 indented,
1386 }),
1387 cx,
1388 );
1389 }
1390 }
1391
    /// Appends a non-indented assistant chunk (`is_thought` selects thought vs. message).
    ///
    /// Equivalent to `push_assistant_content_block_with_indent(chunk, is_thought, false, cx)`.
    pub fn push_assistant_content_block(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        cx: &mut Context<Self>,
    ) {
        self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
    }
1400
1401 pub fn push_assistant_content_block_with_indent(
1402 &mut self,
1403 chunk: acp::ContentBlock,
1404 is_thought: bool,
1405 indented: bool,
1406 cx: &mut Context<Self>,
1407 ) {
1408 let language_registry = self.project.read(cx).languages().clone();
1409 let path_style = self.project.read(cx).path_style(cx);
1410 let entries_len = self.entries.len();
1411 if let Some(last_entry) = self.entries.last_mut()
1412 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1413 chunks,
1414 indented: existing_indented,
1415 }) = last_entry
1416 && *existing_indented == indented
1417 {
1418 let idx = entries_len - 1;
1419 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1420 match (chunks.last_mut(), is_thought) {
1421 (Some(AssistantMessageChunk::Message { block }), false)
1422 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1423 block.append(chunk, &language_registry, path_style, cx)
1424 }
1425 _ => {
1426 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1427 if is_thought {
1428 chunks.push(AssistantMessageChunk::Thought { block })
1429 } else {
1430 chunks.push(AssistantMessageChunk::Message { block })
1431 }
1432 }
1433 }
1434 } else {
1435 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1436 let chunk = if is_thought {
1437 AssistantMessageChunk::Thought { block }
1438 } else {
1439 AssistantMessageChunk::Message { block }
1440 };
1441
1442 self.push_entry(
1443 AgentThreadEntry::AssistantMessage(AssistantMessage {
1444 chunks: vec![chunk],
1445 indented,
1446 }),
1447 cx,
1448 );
1449 }
1450 }
1451
    /// Appends `entry` to the transcript and emits `NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
1456
    /// Returns whether the connection supports renaming this session.
    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
        self.connection.set_title(&self.session_id, cx).is_some()
    }
1460
1461 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1462 if title != self.title {
1463 self.title = title.clone();
1464 cx.emit(AcpThreadEvent::TitleUpdated);
1465 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1466 return set_title.run(title, cx);
1467 }
1468 }
1469 Task::ready(Ok(()))
1470 }
1471
    /// Emits `SubagentSpawned` for the child `session_id`; the thread's state is unchanged.
    pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
    }
1475
    /// Replaces the stored token usage (`None` clears it) and emits `TokenUsageUpdated`.
    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
        self.token_usage = usage;
        cx.emit(AcpThreadEvent::TokenUsageUpdated);
    }
1480
    /// Forwards a retry status as a `Retry` event; nothing is stored on the thread.
    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::Retry(status));
    }
1484
    /// Applies a tool-call update to the matching entry and emits `EntryUpdated` for it.
    ///
    /// - `UpdateFields` merges the fields into the call and re-resolves locations when the
    ///   update carries new ones.
    /// - `UpdateDiff` / `UpdateTerminal` replace the call's content with the given diff or
    ///   terminal.
    ///
    /// If no entry has the update's id, a placeholder tool call is pushed instead and `Ok(())`
    /// is returned. The placeholder has status `Failed` and the text "Tool call not found".
    ///
    /// # Errors
    /// Propagates errors from applying the fields.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                    tool_name: None,
                    subagent_session_id: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only ever returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(
                    update.fields,
                    update.meta,
                    languages,
                    path_style,
                    &self.terminals,
                    cx,
                )?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1555
    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
    ///
    /// The entry's status is taken from `tool_call.status`; see `upsert_tool_call_inner`.
    pub fn upsert_tool_call(
        &mut self,
        tool_call: acp::ToolCall,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        // Read the status before `tool_call` is consumed by the conversion below.
        let status = tool_call.status.into();
        self.upsert_tool_call_inner(tool_call.into(), status, cx)
    }
1565
1566 /// Fails if id does not match an existing entry.
1567 pub fn upsert_tool_call_inner(
1568 &mut self,
1569 update: acp::ToolCallUpdate,
1570 status: ToolCallStatus,
1571 cx: &mut Context<Self>,
1572 ) -> Result<(), acp::Error> {
1573 let language_registry = self.project.read(cx).languages().clone();
1574 let path_style = self.project.read(cx).path_style(cx);
1575 let id = update.tool_call_id.clone();
1576
1577 let agent_telemetry_id = self.connection().telemetry_id();
1578 let session = self.session_id();
1579 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1580 let status = if matches!(status, ToolCallStatus::Completed) {
1581 "completed"
1582 } else {
1583 "failed"
1584 };
1585 telemetry::event!(
1586 "Agent Tool Call Completed",
1587 agent_telemetry_id,
1588 session,
1589 status
1590 );
1591 }
1592
1593 if let Some(ix) = self.index_for_tool_call(&id) {
1594 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1595 unreachable!()
1596 };
1597
1598 call.update_fields(
1599 update.fields,
1600 update.meta,
1601 language_registry,
1602 path_style,
1603 &self.terminals,
1604 cx,
1605 )?;
1606 call.status = status;
1607
1608 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1609 } else {
1610 let call = ToolCall::from_acp(
1611 update.try_into()?,
1612 status,
1613 language_registry,
1614 self.project.read(cx).path_style(cx),
1615 &self.terminals,
1616 cx,
1617 )?;
1618 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1619 };
1620
1621 self.resolve_locations(id, cx);
1622 Ok(())
1623 }
1624
1625 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1626 self.entries
1627 .iter()
1628 .enumerate()
1629 .rev()
1630 .find_map(|(index, entry)| {
1631 if let AgentThreadEntry::ToolCall(tool_call) = entry
1632 && &tool_call.id == id
1633 {
1634 Some(index)
1635 } else {
1636 None
1637 }
1638 })
1639 }
1640
1641 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1642 // The tool call we are looking for is typically the last one, or very close to the end.
1643 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1644 self.entries
1645 .iter_mut()
1646 .enumerate()
1647 .rev()
1648 .find_map(|(index, tool_call)| {
1649 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1650 && &tool_call.id == id
1651 {
1652 Some((index, tool_call))
1653 } else {
1654 None
1655 }
1656 })
1657 }
1658
1659 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1660 self.entries
1661 .iter()
1662 .enumerate()
1663 .rev()
1664 .find_map(|(index, tool_call)| {
1665 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1666 && &tool_call.id == id
1667 {
1668 Some((index, tool_call))
1669 } else {
1670 None
1671 }
1672 })
1673 }
1674
    /// Resolves the locations of tool call `id` in the background; does nothing if no tool
    /// call has that id.
    ///
    /// Once resolution finishes:
    /// - each resolved buffer's current snapshot is recorded in `shared_buffers`;
    /// - the project's agent location moves to the last resolved location, unless that would
    ///   move it backwards within the same row of the same buffer;
    /// - the tool call's `resolved_locations` are replaced, emitting `EntryUpdated` only when
    ///   they actually changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Look the call up again: entries may have changed while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // Don't move the agent location backwards within the same row, so
                            // that updates from the edit tool don't reset it to the start of
                            // the line.
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1731
1732 pub fn request_tool_call_authorization(
1733 &mut self,
1734 tool_call: acp::ToolCallUpdate,
1735 options: PermissionOptions,
1736 respect_always_allow_setting: bool,
1737 cx: &mut Context<Self>,
1738 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1739 let (tx, rx) = oneshot::channel();
1740
1741 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1742 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1743 // some tools would (incorrectly) continue to auto-accept.
1744 if let Some(allow_once_option) = options.allow_once_option_id() {
1745 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1746 return Ok(async {
1747 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1748 allow_once_option,
1749 ))
1750 }
1751 .boxed());
1752 }
1753 }
1754
1755 let status = ToolCallStatus::WaitingForConfirmation {
1756 options,
1757 respond_tx: tx,
1758 };
1759
1760 self.upsert_tool_call_inner(tool_call, status, cx)?;
1761 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1762
1763 let fut = async {
1764 match rx.await {
1765 Ok(option) => acp::RequestPermissionOutcome::Selected(
1766 acp::SelectedPermissionOutcome::new(option),
1767 ),
1768 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1769 }
1770 }
1771 .boxed();
1772
1773 Ok(fut)
1774 }
1775
1776 pub fn authorize_tool_call(
1777 &mut self,
1778 id: acp::ToolCallId,
1779 option_id: acp::PermissionOptionId,
1780 option_kind: acp::PermissionOptionKind,
1781 cx: &mut Context<Self>,
1782 ) {
1783 let Some((ix, call)) = self.tool_call_mut(&id) else {
1784 return;
1785 };
1786
1787 let new_status = match option_kind {
1788 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1789 ToolCallStatus::Rejected
1790 }
1791 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1792 ToolCallStatus::InProgress
1793 }
1794 _ => ToolCallStatus::InProgress,
1795 };
1796
1797 let curr_status = mem::replace(&mut call.status, new_status);
1798
1799 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1800 respond_tx.send(option_id).log_err();
1801 } else if cfg!(debug_assertions) {
1802 panic!("tried to authorize an already authorized tool call");
1803 }
1804
1805 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1806 }
1807
1808 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1809 let mut first_tool_call = None;
1810
1811 for entry in self.entries.iter().rev() {
1812 match &entry {
1813 AgentThreadEntry::ToolCall(call) => {
1814 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1815 first_tool_call = Some(call);
1816 } else {
1817 continue;
1818 }
1819 }
1820 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1821 // Reached the beginning of the turn.
1822 // If we had pending permission requests in the previous turn, they have been cancelled.
1823 break;
1824 }
1825 }
1826 }
1827
1828 first_tool_call
1829 }
1830
    /// Returns the agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1834
    /// Replaces the plan's entries with those of `request` and notifies observers.
    ///
    /// Entries are matched by position:
    /// - existing entries are updated in place, reusing their markdown;
    /// - extra incoming entries are appended;
    /// - surplus old entries are dropped.
    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
        let new_entries_len = request.entries.len();
        let mut new_entries = request.entries.into_iter();

        // Reuse existing markdown to prevent flickering
        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
            let PlanEntry {
                content,
                priority,
                status,
            } = old;
            content.update(cx, |old, cx| {
                old.replace(new.content, cx);
            });
            *priority = new.priority;
            *status = new.status;
        }
        // Whatever the zip above did not consume is genuinely new.
        for new in new_entries {
            self.plan.entries.push(PlanEntry::from_acp(new, cx))
        }
        // Drop old entries beyond the new length (a no-op when the plan grew).
        self.plan.entries.truncate(new_entries_len);

        cx.notify();
    }
1859
    /// Removes plan entries whose status is `Completed` and notifies observers.
    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
        self.plan
            .entries
            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
        cx.notify();
    }
1866
    /// Test-only convenience wrapper around `send` for a single message built from `&str`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn send_raw(
        &mut self,
        message: &str,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        self.send(vec![message.into()], cx)
    }
1875
    /// Sends `message` as a new user prompt and runs an agent turn for it.
    ///
    /// Inside the turn (after any previous turn has been cancelled by `run_turn`), this:
    /// 1. pushes the message as a user entry;
    /// 2. takes a git checkpoint and attaches it to the latest user message (a failure is
    ///    logged and leaves the message without a checkpoint);
    /// 3. forwards the prompt to the connection.
    ///
    /// The returned future completes when the turn ends.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Only assign a message id when the connection supports truncation (used by `rewind`).
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` finds the tree actually changed.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1928
    /// Returns whether the connection supports retrying this session.
    pub fn can_retry(&self, cx: &App) -> bool {
        self.connection.retry(&self.session_id, cx).is_some()
    }
1932
    /// Runs a new turn that asks the connection to retry this session.
    ///
    /// # Errors
    /// The returned future fails with "retrying a session is not supported" when the connection
    /// has no retry support, or propagates the retry's own error.
    pub fn retry(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .retry(&this.session_id, cx)
                    .map(|retry| retry.run(cx))
            })?
            .context("retrying a session is not supported")?
            .await
        })
    }
1944
1945 fn run_turn(
1946 &mut self,
1947 cx: &mut Context<Self>,
1948 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1949 ) -> BoxFuture<'static, Result<()>> {
1950 self.clear_completed_plan_entries(cx);
1951
1952 let (tx, rx) = oneshot::channel();
1953 let cancel_task = self.cancel(cx);
1954
1955 self.send_task = Some(cx.spawn(async move |this, cx| {
1956 cancel_task.await;
1957 tx.send(f(this, cx).await).ok();
1958 }));
1959
1960 cx.spawn(async move |this, cx| {
1961 let response = rx.await;
1962
1963 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1964 .await?;
1965
1966 this.update(cx, |this, cx| {
1967 this.project
1968 .update(cx, |project, cx| project.set_agent_location(None, cx));
1969 match response {
1970 Ok(Err(e)) => {
1971 this.send_task.take();
1972 cx.emit(AcpThreadEvent::Error);
1973 Err(e)
1974 }
1975 result => {
1976 let canceled = matches!(
1977 result,
1978 Ok(Ok(acp::PromptResponse {
1979 stop_reason: acp::StopReason::Cancelled,
1980 ..
1981 }))
1982 );
1983
1984 // We only take the task if the current prompt wasn't canceled.
1985 //
1986 // This prompt may have been canceled because another one was sent
1987 // while it was still generating. In these cases, dropping `send_task`
1988 // would cause the next generation to be canceled.
1989 if !canceled {
1990 this.send_task.take();
1991 }
1992
1993 // Handle refusal - distinguish between user prompt and tool call refusals
1994 if let Ok(Ok(acp::PromptResponse {
1995 stop_reason: acp::StopReason::Refusal,
1996 ..
1997 })) = result
1998 {
1999 if let Some((user_msg_ix, _)) = this.last_user_message() {
2000 // Check if there's a completed tool call with results after the last user message
2001 // This indicates the refusal is in response to tool output, not the user's prompt
2002 let has_completed_tool_call_after_user_msg =
2003 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
2004 if let AgentThreadEntry::ToolCall(tool_call) = entry {
2005 // Check if the tool call has completed and has output
2006 matches!(tool_call.status, ToolCallStatus::Completed)
2007 && tool_call.raw_output.is_some()
2008 } else {
2009 false
2010 }
2011 });
2012
2013 if has_completed_tool_call_after_user_msg {
2014 // Refusal is due to tool output - don't truncate, just notify
2015 // The model refused based on what the tool returned
2016 cx.emit(AcpThreadEvent::Refusal);
2017 } else {
2018 // User prompt was refused - truncate back to before the user message
2019 let range = user_msg_ix..this.entries.len();
2020 if range.start < range.end {
2021 this.entries.truncate(user_msg_ix);
2022 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2023 }
2024 cx.emit(AcpThreadEvent::Refusal);
2025 }
2026 } else {
2027 // No user message found, treat as general refusal
2028 cx.emit(AcpThreadEvent::Refusal);
2029 }
2030 }
2031
2032 cx.emit(AcpThreadEvent::Stopped);
2033 Ok(())
2034 }
2035 }
2036 })?
2037 })
2038 .boxed()
2039 }
2040
2041 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2042 let Some(send_task) = self.send_task.take() else {
2043 return Task::ready(());
2044 };
2045
2046 for entry in self.entries.iter_mut() {
2047 if let AgentThreadEntry::ToolCall(call) = entry {
2048 let cancel = matches!(
2049 call.status,
2050 ToolCallStatus::Pending
2051 | ToolCallStatus::WaitingForConfirmation { .. }
2052 | ToolCallStatus::InProgress
2053 );
2054
2055 if cancel {
2056 call.status = ToolCallStatus::Canceled;
2057 }
2058 }
2059 }
2060
2061 self.connection.cancel(&self.session_id, cx);
2062
2063 // Wait for the send task to complete
2064 cx.foreground_executor().spawn(send_task)
2065 }
2066
    /// Restores the git working tree to the state at the given checkpoint (if one exists)
    ///
    /// Cancels any running turn and rewinds the thread to just before the user message `id`
    /// (see `rewind`). It then restores that message's git checkpoint, if it has one.
    ///
    /// # Errors
    /// Returns "message not found" immediately if no user message has `id`. Otherwise
    /// propagates errors from the rewind or the git restore.
    pub fn restore_checkpoint(
        &mut self,
        id: UserMessageId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some((_, message)) = self.user_message_mut(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        // Cancel any in-progress generation before restoring
        let cancel_task = self.cancel(cx);
        // NOTE(review): `rewind` spawns its task right here, so its truncation may start
        // before `cancel_task` finishes; only the awaits below are ordered. Confirm this is
        // acceptable.
        let rewind = self.rewind(id.clone(), cx);
        let git_store = self.project.read(cx).git_store().clone();

        cx.spawn(async move |_, cx| {
            cancel_task.await;
            rewind.await?;
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
                    .await?;
            }

            Ok(())
        })
    }
2099
    /// Rewinds this thread to just before the user message `id`.
    ///
    /// After the connection truncates the session, the thread:
    /// - removes that message and every later entry, killing the terminals those entries own;
    /// - rejects all pending edits in the action log.
    ///
    /// Unlike `restore_checkpoint`, this method does not restore from git.
    ///
    /// # Errors
    /// Returns "not supported" if the connection cannot truncate, or propagates the truncate
    /// error.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx)).await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                // Runs even when no user message matched `id`.
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
2141
2142 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2143 let git_store = self.project.read(cx).git_store().clone();
2144
2145 let Some((_, message)) = self.last_user_message() else {
2146 return Task::ready(Ok(()));
2147 };
2148 let Some(user_message_id) = message.id.clone() else {
2149 return Task::ready(Ok(()));
2150 };
2151 let Some(checkpoint) = message.checkpoint.as_ref() else {
2152 return Task::ready(Ok(()));
2153 };
2154 let old_checkpoint = checkpoint.git_checkpoint.clone();
2155
2156 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2157 cx.spawn(async move |this, cx| {
2158 let Some(new_checkpoint) = new_checkpoint
2159 .await
2160 .context("failed to get new checkpoint")
2161 .log_err()
2162 else {
2163 return Ok(());
2164 };
2165
2166 let equal = git_store
2167 .update(cx, |git, cx| {
2168 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2169 })
2170 .await
2171 .unwrap_or(true);
2172
2173 this.update(cx, |this, cx| {
2174 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2175 if let Some(checkpoint) = message.checkpoint.as_mut() {
2176 checkpoint.show = !equal;
2177 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2178 }
2179 }
2180 })?;
2181
2182 Ok(())
2183 })
2184 }
2185
2186 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2187 self.entries
2188 .iter_mut()
2189 .enumerate()
2190 .rev()
2191 .find_map(|(ix, entry)| {
2192 if let AgentThreadEntry::UserMessage(message) = entry {
2193 Some((ix, message))
2194 } else {
2195 None
2196 }
2197 })
2198 }
2199
2200 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2201 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2202 if let AgentThreadEntry::UserMessage(message) = entry {
2203 if message.id.as_ref() == Some(id) {
2204 Some((ix, message))
2205 } else {
2206 None
2207 }
2208 } else {
2209 None
2210 }
2211 })
2212 }
2213
2214 pub fn read_text_file(
2215 &self,
2216 path: PathBuf,
2217 line: Option<u32>,
2218 limit: Option<u32>,
2219 reuse_shared_snapshot: bool,
2220 cx: &mut Context<Self>,
2221 ) -> Task<Result<String, acp::Error>> {
2222 // Args are 1-based, move to 0-based
2223 let line = line.unwrap_or_default().saturating_sub(1);
2224 let limit = limit.unwrap_or(u32::MAX);
2225 let project = self.project.clone();
2226 let action_log = self.action_log.clone();
2227 cx.spawn(async move |this, cx| {
2228 let load = project.update(cx, |project, cx| {
2229 let path = project
2230 .project_path_for_absolute_path(&path, cx)
2231 .ok_or_else(|| {
2232 acp::Error::resource_not_found(Some(path.display().to_string()))
2233 })?;
2234 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2235 })?;
2236
2237 let buffer = load.await?;
2238
2239 let snapshot = if reuse_shared_snapshot {
2240 this.read_with(cx, |this, _| {
2241 this.shared_buffers.get(&buffer.clone()).cloned()
2242 })
2243 .log_err()
2244 .flatten()
2245 } else {
2246 None
2247 };
2248
2249 let snapshot = if let Some(snapshot) = snapshot {
2250 snapshot
2251 } else {
2252 action_log.update(cx, |action_log, cx| {
2253 action_log.buffer_read(buffer.clone(), cx);
2254 });
2255
2256 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2257 this.update(cx, |this, _| {
2258 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2259 })?;
2260 snapshot
2261 };
2262
2263 let max_point = snapshot.max_point();
2264 let start_position = Point::new(line, 0);
2265
2266 if start_position > max_point {
2267 return Err(acp::Error::invalid_params().data(format!(
2268 "Attempting to read beyond the end of the file, line {}:{}",
2269 max_point.row + 1,
2270 max_point.column
2271 )));
2272 }
2273
2274 let start = snapshot.anchor_before(start_position);
2275 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2276
2277 project.update(cx, |project, cx| {
2278 project.set_agent_location(
2279 Some(AgentLocation {
2280 buffer: buffer.downgrade(),
2281 position: start,
2282 }),
2283 cx,
2284 );
2285 });
2286
2287 Ok(snapshot.text_for_range(start..end).collect::<String>())
2288 })
2289 }
2290
    /// Replaces the contents of the file at `path` (an absolute path inside the project)
    /// with `content`, then saves it.
    ///
    /// The buffer is not overwritten wholesale:
    /// - `content` is diffed against the snapshot previously shared with the agent (see
    ///   `shared_buffers`), falling back to the buffer's current snapshot.
    /// - The resulting hunks are anchored in that snapshot and applied as edits. This lets
    ///   changes made to the buffer after that snapshot survive the write.
    ///
    /// The write is reported to the action log. When format-on-save is not `Off` for the
    /// buffer, the buffer is formatted before saving.
    ///
    /// # Errors
    ///
    /// - `path` does not map to a project path ("invalid path").
    /// - Opening the buffer fails.
    /// - Updating this thread fails.
    /// - Saving fails.
    ///
    /// Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against the snapshot the agent last read, if there is one, rather than the
            // live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the text diff on the background executor. Anchors are created in the
            // old snapshot so they resolve against whatever the buffer contains when applied.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at the start of the
            // buffer when the diff is empty.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // In a single update:
            // - record the read in the action log;
            // - apply the edits;
            // - record the edit in the action log.
            // It yields whether format-on-save is enabled in the buffer's language settings.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                // Formatting is best-effort: a failure is logged and the save still proceeds.
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                format_task.await.log_err();

                // Report the edit again so the action log sees the post-format buffer state.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2387
2388 pub fn create_terminal(
2389 &self,
2390 command: String,
2391 args: Vec<String>,
2392 extra_env: Vec<acp::EnvVariable>,
2393 cwd: Option<PathBuf>,
2394 output_byte_limit: Option<u64>,
2395 cx: &mut Context<Self>,
2396 ) -> Task<Result<Entity<Terminal>>> {
2397 let env = match &cwd {
2398 Some(dir) => self.project.update(cx, |project, cx| {
2399 project.environment().update(cx, |env, cx| {
2400 env.directory_environment(dir.as_path().into(), cx)
2401 })
2402 }),
2403 None => Task::ready(None).shared(),
2404 };
2405 let env = cx.spawn(async move |_, _| {
2406 let mut env = env.await.unwrap_or_default();
2407 // Disables paging for `git` and hopefully other commands
2408 env.insert("PAGER".into(), "".into());
2409 for var in extra_env {
2410 env.insert(var.name, var.value);
2411 }
2412 env
2413 });
2414
2415 let project = self.project.clone();
2416 let language_registry = project.read(cx).languages().clone();
2417 let is_windows = project.read(cx).path_style(cx).is_windows();
2418
2419 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2420 let terminal_task = cx.spawn({
2421 let terminal_id = terminal_id.clone();
2422 async move |_this, cx| {
2423 let env = env.await;
2424 let shell = project
2425 .update(cx, |project, cx| {
2426 project
2427 .remote_client()
2428 .and_then(|r| r.read(cx).default_system_shell())
2429 })
2430 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2431 let (task_command, task_args) =
2432 ShellBuilder::new(&Shell::Program(shell), is_windows)
2433 .redirect_stdin_to_dev_null()
2434 .build(Some(command.clone()), &args);
2435 let terminal = project
2436 .update(cx, |project, cx| {
2437 project.create_terminal_task(
2438 task::SpawnInTerminal {
2439 command: Some(task_command),
2440 args: task_args,
2441 cwd: cwd.clone(),
2442 env,
2443 ..Default::default()
2444 },
2445 cx,
2446 )
2447 })
2448 .await?;
2449
2450 anyhow::Ok(cx.new(|cx| {
2451 Terminal::new(
2452 terminal_id,
2453 &format!("{} {}", command, args.join(" ")),
2454 cwd,
2455 output_byte_limit.map(|l| l as usize),
2456 terminal,
2457 language_registry,
2458 cx,
2459 )
2460 }))
2461 }
2462 });
2463
2464 cx.spawn(async move |this, cx| {
2465 let terminal = terminal_task.await?;
2466 this.update(cx, |this, _cx| {
2467 this.terminals.insert(terminal_id, terminal.clone());
2468 terminal
2469 })
2470 })
2471 }
2472
2473 pub fn kill_terminal(
2474 &mut self,
2475 terminal_id: acp::TerminalId,
2476 cx: &mut Context<Self>,
2477 ) -> Result<()> {
2478 self.terminals
2479 .get(&terminal_id)
2480 .context("Terminal not found")?
2481 .update(cx, |terminal, cx| {
2482 terminal.kill(cx);
2483 });
2484
2485 Ok(())
2486 }
2487
2488 pub fn release_terminal(
2489 &mut self,
2490 terminal_id: acp::TerminalId,
2491 cx: &mut Context<Self>,
2492 ) -> Result<()> {
2493 self.terminals
2494 .remove(&terminal_id)
2495 .context("Terminal not found")?
2496 .update(cx, |terminal, cx| {
2497 terminal.kill(cx);
2498 });
2499
2500 Ok(())
2501 }
2502
2503 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2504 self.terminals
2505 .get(&terminal_id)
2506 .context("Terminal not found")
2507 .cloned()
2508 }
2509
2510 pub fn to_markdown(&self, cx: &App) -> String {
2511 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2512 }
2513
2514 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2515 cx.emit(AcpThreadEvent::LoadError(error));
2516 }
2517
2518 pub fn register_terminal_created(
2519 &mut self,
2520 terminal_id: acp::TerminalId,
2521 command_label: String,
2522 working_dir: Option<PathBuf>,
2523 output_byte_limit: Option<u64>,
2524 terminal: Entity<::terminal::Terminal>,
2525 cx: &mut Context<Self>,
2526 ) -> Entity<Terminal> {
2527 let language_registry = self.project.read(cx).languages().clone();
2528
2529 let entity = cx.new(|cx| {
2530 Terminal::new(
2531 terminal_id.clone(),
2532 &command_label,
2533 working_dir.clone(),
2534 output_byte_limit.map(|l| l as usize),
2535 terminal,
2536 language_registry,
2537 cx,
2538 )
2539 });
2540 self.terminals.insert(terminal_id.clone(), entity.clone());
2541 entity
2542 }
2543}
2544
2545fn markdown_for_raw_output(
2546 raw_output: &serde_json::Value,
2547 language_registry: &Arc<LanguageRegistry>,
2548 cx: &mut App,
2549) -> Option<Entity<Markdown>> {
2550 match raw_output {
2551 serde_json::Value::Null => None,
2552 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2553 Markdown::new(
2554 value.to_string().into(),
2555 Some(language_registry.clone()),
2556 None,
2557 cx,
2558 )
2559 })),
2560 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2561 Markdown::new(
2562 value.to_string().into(),
2563 Some(language_registry.clone()),
2564 None,
2565 cx,
2566 )
2567 })),
2568 serde_json::Value::String(value) => Some(cx.new(|cx| {
2569 Markdown::new(
2570 value.clone().into(),
2571 Some(language_registry.clone()),
2572 None,
2573 cx,
2574 )
2575 })),
2576 value => Some(cx.new(|cx| {
2577 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2578
2579 Markdown::new(
2580 format!("```json\n{}\n```", pretty_json).into(),
2581 Some(language_registry.clone()),
2582 None,
2583 cx,
2584 )
2585 })),
2586 }
2587}
2588
2589#[cfg(test)]
2590mod tests {
2591 use super::*;
2592 use anyhow::anyhow;
2593 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2594 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2595 use indoc::indoc;
2596 use project::{FakeFs, Fs};
2597 use rand::{distr, prelude::*};
2598 use serde_json::json;
2599 use settings::SettingsStore;
2600 use smol::stream::StreamExt as _;
2601 use std::{
2602 any::Any,
2603 cell::RefCell,
2604 path::Path,
2605 rc::Rc,
2606 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2607 time::Duration,
2608 };
2609 use util::path;
2610
2611 fn init_test(cx: &mut TestAppContext) {
2612 env_logger::try_init().ok();
2613 cx.update(|cx| {
2614 let settings_store = SettingsStore::test(cx);
2615 cx.set_global(settings_store);
2616 });
2617 }
2618
2619 #[gpui::test]
2620 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2621 init_test(cx);
2622
2623 let fs = FakeFs::new(cx.executor());
2624 let project = Project::test(fs, [], cx).await;
2625 let connection = Rc::new(FakeAgentConnection::new());
2626 let thread = cx
2627 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2628 .await
2629 .unwrap();
2630
2631 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2632
2633 // Send Output BEFORE Created - should be buffered by acp_thread
2634 thread.update(cx, |thread, cx| {
2635 thread.on_terminal_provider_event(
2636 TerminalProviderEvent::Output {
2637 terminal_id: terminal_id.clone(),
2638 data: b"hello buffered".to_vec(),
2639 },
2640 cx,
2641 );
2642 });
2643
2644 // Create a display-only terminal and then send Created
2645 let lower = cx.new(|cx| {
2646 let builder = ::terminal::TerminalBuilder::new_display_only(
2647 ::terminal::terminal_settings::CursorShape::default(),
2648 ::terminal::terminal_settings::AlternateScroll::On,
2649 None,
2650 0,
2651 cx.background_executor(),
2652 PathStyle::local(),
2653 )
2654 .unwrap();
2655 builder.subscribe(cx)
2656 });
2657
2658 thread.update(cx, |thread, cx| {
2659 thread.on_terminal_provider_event(
2660 TerminalProviderEvent::Created {
2661 terminal_id: terminal_id.clone(),
2662 label: "Buffered Test".to_string(),
2663 cwd: None,
2664 output_byte_limit: None,
2665 terminal: lower.clone(),
2666 },
2667 cx,
2668 );
2669 });
2670
2671 // After Created, buffered Output should have been flushed into the renderer
2672 let content = thread.read_with(cx, |thread, cx| {
2673 let term = thread.terminal(terminal_id.clone()).unwrap();
2674 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2675 });
2676
2677 assert!(
2678 content.contains("hello buffered"),
2679 "expected buffered output to render, got: {content}"
2680 );
2681 }
2682
2683 #[gpui::test]
2684 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2685 init_test(cx);
2686
2687 let fs = FakeFs::new(cx.executor());
2688 let project = Project::test(fs, [], cx).await;
2689 let connection = Rc::new(FakeAgentConnection::new());
2690 let thread = cx
2691 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2692 .await
2693 .unwrap();
2694
2695 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2696
2697 // Send Output BEFORE Created
2698 thread.update(cx, |thread, cx| {
2699 thread.on_terminal_provider_event(
2700 TerminalProviderEvent::Output {
2701 terminal_id: terminal_id.clone(),
2702 data: b"pre-exit data".to_vec(),
2703 },
2704 cx,
2705 );
2706 });
2707
2708 // Send Exit BEFORE Created
2709 thread.update(cx, |thread, cx| {
2710 thread.on_terminal_provider_event(
2711 TerminalProviderEvent::Exit {
2712 terminal_id: terminal_id.clone(),
2713 status: acp::TerminalExitStatus::new().exit_code(0),
2714 },
2715 cx,
2716 );
2717 });
2718
2719 // Now create a display-only lower-level terminal and send Created
2720 let lower = cx.new(|cx| {
2721 let builder = ::terminal::TerminalBuilder::new_display_only(
2722 ::terminal::terminal_settings::CursorShape::default(),
2723 ::terminal::terminal_settings::AlternateScroll::On,
2724 None,
2725 0,
2726 cx.background_executor(),
2727 PathStyle::local(),
2728 )
2729 .unwrap();
2730 builder.subscribe(cx)
2731 });
2732
2733 thread.update(cx, |thread, cx| {
2734 thread.on_terminal_provider_event(
2735 TerminalProviderEvent::Created {
2736 terminal_id: terminal_id.clone(),
2737 label: "Buffered Exit Test".to_string(),
2738 cwd: None,
2739 output_byte_limit: None,
2740 terminal: lower.clone(),
2741 },
2742 cx,
2743 );
2744 });
2745
2746 // Output should be present after Created (flushed from buffer)
2747 let content = thread.read_with(cx, |thread, cx| {
2748 let term = thread.terminal(terminal_id.clone()).unwrap();
2749 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2750 });
2751
2752 assert!(
2753 content.contains("pre-exit data"),
2754 "expected pre-exit data to render, got: {content}"
2755 );
2756 }
2757
2758 /// Test that killing a terminal via Terminal::kill properly:
2759 /// 1. Causes wait_for_exit to complete (doesn't hang forever)
2760 /// 2. The underlying terminal still has the output that was written before the kill
2761 ///
2762 /// This test verifies that the fix to kill_active_task (which now also kills
2763 /// the shell process in addition to the foreground process) properly allows
2764 /// wait_for_exit to complete instead of hanging indefinitely.
2765 #[cfg(unix)]
2766 #[gpui::test]
2767 async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
2768 use std::collections::HashMap;
2769 use task::Shell;
2770 use util::shell_builder::ShellBuilder;
2771
2772 init_test(cx);
2773 cx.executor().allow_parking();
2774
2775 let fs = FakeFs::new(cx.executor());
2776 let project = Project::test(fs, [], cx).await;
2777 let connection = Rc::new(FakeAgentConnection::new());
2778 let thread = cx
2779 .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
2780 .await
2781 .unwrap();
2782
2783 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2784
2785 // Create a real PTY terminal that runs a command which prints output then sleeps
2786 // We use printf instead of echo and chain with && sleep to ensure proper execution
2787 let (completion_tx, _completion_rx) = smol::channel::unbounded();
2788 let (program, args) = ShellBuilder::new(&Shell::System, false).build(
2789 Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
2790 &[],
2791 );
2792
2793 let builder = cx
2794 .update(|cx| {
2795 ::terminal::TerminalBuilder::new(
2796 None,
2797 None,
2798 task::Shell::WithArguments {
2799 program,
2800 args,
2801 title_override: None,
2802 },
2803 HashMap::default(),
2804 ::terminal::terminal_settings::CursorShape::default(),
2805 ::terminal::terminal_settings::AlternateScroll::On,
2806 None,
2807 vec![],
2808 0,
2809 false,
2810 0,
2811 Some(completion_tx),
2812 cx,
2813 vec![],
2814 PathStyle::local(),
2815 )
2816 })
2817 .await
2818 .unwrap();
2819
2820 let lower_terminal = cx.new(|cx| builder.subscribe(cx));
2821
2822 // Create the acp_thread Terminal wrapper
2823 thread.update(cx, |thread, cx| {
2824 thread.on_terminal_provider_event(
2825 TerminalProviderEvent::Created {
2826 terminal_id: terminal_id.clone(),
2827 label: "printf output_before_kill && sleep 60".to_string(),
2828 cwd: None,
2829 output_byte_limit: None,
2830 terminal: lower_terminal.clone(),
2831 },
2832 cx,
2833 );
2834 });
2835
2836 // Wait for the printf command to execute and produce output
2837 // Use real time since parking is enabled
2838 cx.executor().timer(Duration::from_millis(500)).await;
2839
2840 // Get the acp_thread Terminal and kill it
2841 let wait_for_exit = thread.update(cx, |thread, cx| {
2842 let term = thread.terminals.get(&terminal_id).unwrap();
2843 let wait_for_exit = term.read(cx).wait_for_exit();
2844 term.update(cx, |term, cx| {
2845 term.kill(cx);
2846 });
2847 wait_for_exit
2848 });
2849
2850 // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
2851 // Before the fix to kill_active_task, this would hang forever because
2852 // only the foreground process was killed, not the shell, so the PTY
2853 // child never exited and wait_for_completed_task never completed.
2854 let exit_result = futures::select! {
2855 result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
2856 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
2857 };
2858
2859 assert!(
2860 exit_result.is_some(),
2861 "wait_for_exit should complete after kill, but it timed out. \
2862 This indicates kill_active_task is not properly killing the shell process."
2863 );
2864
2865 // Give the system a chance to process any pending updates
2866 cx.run_until_parked();
2867
2868 // Verify that the underlying terminal still has the output that was
2869 // written before the kill. This verifies that killing doesn't lose output.
2870 let inner_content = thread.read_with(cx, |thread, cx| {
2871 let term = thread.terminals.get(&terminal_id).unwrap();
2872 term.read(cx).inner().read(cx).get_content()
2873 });
2874
2875 assert!(
2876 inner_content.contains("output_before_kill"),
2877 "Underlying terminal should contain output from before kill, got: {}",
2878 inner_content
2879 );
2880 }
2881
2882 #[gpui::test]
2883 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2884 init_test(cx);
2885
2886 let fs = FakeFs::new(cx.executor());
2887 let project = Project::test(fs, [], cx).await;
2888 let connection = Rc::new(FakeAgentConnection::new());
2889 let thread = cx
2890 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2891 .await
2892 .unwrap();
2893
2894 // Test creating a new user message
2895 thread.update(cx, |thread, cx| {
2896 thread.push_user_content_block(None, "Hello, ".into(), cx);
2897 });
2898
2899 thread.update(cx, |thread, cx| {
2900 assert_eq!(thread.entries.len(), 1);
2901 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2902 assert_eq!(user_msg.id, None);
2903 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2904 } else {
2905 panic!("Expected UserMessage");
2906 }
2907 });
2908
2909 // Test appending to existing user message
2910 let message_1_id = UserMessageId::new();
2911 thread.update(cx, |thread, cx| {
2912 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2913 });
2914
2915 thread.update(cx, |thread, cx| {
2916 assert_eq!(thread.entries.len(), 1);
2917 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2918 assert_eq!(user_msg.id, Some(message_1_id));
2919 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2920 } else {
2921 panic!("Expected UserMessage");
2922 }
2923 });
2924
2925 // Test creating new user message after assistant message
2926 thread.update(cx, |thread, cx| {
2927 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2928 });
2929
2930 let message_2_id = UserMessageId::new();
2931 thread.update(cx, |thread, cx| {
2932 thread.push_user_content_block(
2933 Some(message_2_id.clone()),
2934 "New user message".into(),
2935 cx,
2936 );
2937 });
2938
2939 thread.update(cx, |thread, cx| {
2940 assert_eq!(thread.entries.len(), 3);
2941 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2942 assert_eq!(user_msg.id, Some(message_2_id));
2943 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2944 } else {
2945 panic!("Expected UserMessage at index 2");
2946 }
2947 });
2948 }
2949
2950 #[gpui::test]
2951 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2952 init_test(cx);
2953
2954 let fs = FakeFs::new(cx.executor());
2955 let project = Project::test(fs, [], cx).await;
2956 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2957 |_, thread, mut cx| {
2958 async move {
2959 thread.update(&mut cx, |thread, cx| {
2960 thread
2961 .handle_session_update(
2962 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2963 "Thinking ".into(),
2964 )),
2965 cx,
2966 )
2967 .unwrap();
2968 thread
2969 .handle_session_update(
2970 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2971 "hard!".into(),
2972 )),
2973 cx,
2974 )
2975 .unwrap();
2976 })?;
2977 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2978 }
2979 .boxed_local()
2980 },
2981 ));
2982
2983 let thread = cx
2984 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2985 .await
2986 .unwrap();
2987
2988 thread
2989 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2990 .await
2991 .unwrap();
2992
2993 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2994 assert_eq!(
2995 output,
2996 indoc! {r#"
2997 ## User
2998
2999 Hello from Zed!
3000
3001 ## Assistant
3002
3003 <thinking>
3004 Thinking hard!
3005 </thinking>
3006
3007 "#}
3008 );
3009 }
3010
3011 #[gpui::test]
3012 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
3013 init_test(cx);
3014
3015 let fs = FakeFs::new(cx.executor());
3016 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
3017 .await;
3018 let project = Project::test(fs.clone(), [], cx).await;
3019 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
3020 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
3021 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
3022 move |_, thread, mut cx| {
3023 let read_file_tx = read_file_tx.clone();
3024 async move {
3025 let content = thread
3026 .update(&mut cx, |thread, cx| {
3027 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3028 })
3029 .unwrap()
3030 .await
3031 .unwrap();
3032 assert_eq!(content, "one\ntwo\nthree\n");
3033 read_file_tx.take().unwrap().send(()).unwrap();
3034 thread
3035 .update(&mut cx, |thread, cx| {
3036 thread.write_text_file(
3037 path!("/tmp/foo").into(),
3038 "one\ntwo\nthree\nfour\nfive\n".to_string(),
3039 cx,
3040 )
3041 })
3042 .unwrap()
3043 .await
3044 .unwrap();
3045 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3046 }
3047 .boxed_local()
3048 },
3049 ));
3050
3051 let (worktree, pathbuf) = project
3052 .update(cx, |project, cx| {
3053 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3054 })
3055 .await
3056 .unwrap();
3057 let buffer = project
3058 .update(cx, |project, cx| {
3059 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
3060 })
3061 .await
3062 .unwrap();
3063
3064 let thread = cx
3065 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3066 .await
3067 .unwrap();
3068
3069 let request = thread.update(cx, |thread, cx| {
3070 thread.send_raw("Extend the count in /tmp/foo", cx)
3071 });
3072 read_file_rx.await.ok();
3073 buffer.update(cx, |buffer, cx| {
3074 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
3075 });
3076 cx.run_until_parked();
3077 assert_eq!(
3078 buffer.read_with(cx, |buffer, _| buffer.text()),
3079 "zero\none\ntwo\nthree\nfour\nfive\n"
3080 );
3081 assert_eq!(
3082 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
3083 "zero\none\ntwo\nthree\nfour\nfive\n"
3084 );
3085 request.await.unwrap();
3086 }
3087
3088 #[gpui::test]
3089 async fn test_reading_from_line(cx: &mut TestAppContext) {
3090 init_test(cx);
3091
3092 let fs = FakeFs::new(cx.executor());
3093 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
3094 .await;
3095 let project = Project::test(fs.clone(), [], cx).await;
3096 project
3097 .update(cx, |project, cx| {
3098 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3099 })
3100 .await
3101 .unwrap();
3102
3103 let connection = Rc::new(FakeAgentConnection::new());
3104
3105 let thread = cx
3106 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3107 .await
3108 .unwrap();
3109
3110 // Whole file
3111 let content = thread
3112 .update(cx, |thread, cx| {
3113 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3114 })
3115 .await
3116 .unwrap();
3117
3118 assert_eq!(content, "one\ntwo\nthree\nfour\n");
3119
3120 // Only start line
3121 let content = thread
3122 .update(cx, |thread, cx| {
3123 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
3124 })
3125 .await
3126 .unwrap();
3127
3128 assert_eq!(content, "three\nfour\n");
3129
3130 // Only limit
3131 let content = thread
3132 .update(cx, |thread, cx| {
3133 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3134 })
3135 .await
3136 .unwrap();
3137
3138 assert_eq!(content, "one\ntwo\n");
3139
3140 // Range
3141 let content = thread
3142 .update(cx, |thread, cx| {
3143 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
3144 })
3145 .await
3146 .unwrap();
3147
3148 assert_eq!(content, "two\nthree\n");
3149
3150 // Invalid
3151 let err = thread
3152 .update(cx, |thread, cx| {
3153 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
3154 })
3155 .await
3156 .unwrap_err();
3157
3158 assert_eq!(
3159 err.to_string(),
3160 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
3161 );
3162 }
3163
3164 #[gpui::test]
3165 async fn test_reading_empty_file(cx: &mut TestAppContext) {
3166 init_test(cx);
3167
3168 let fs = FakeFs::new(cx.executor());
3169 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
3170 let project = Project::test(fs.clone(), [], cx).await;
3171 project
3172 .update(cx, |project, cx| {
3173 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3174 })
3175 .await
3176 .unwrap();
3177
3178 let connection = Rc::new(FakeAgentConnection::new());
3179
3180 let thread = cx
3181 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3182 .await
3183 .unwrap();
3184
3185 // Whole file
3186 let content = thread
3187 .update(cx, |thread, cx| {
3188 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3189 })
3190 .await
3191 .unwrap();
3192
3193 assert_eq!(content, "");
3194
3195 // Only start line
3196 let content = thread
3197 .update(cx, |thread, cx| {
3198 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
3199 })
3200 .await
3201 .unwrap();
3202
3203 assert_eq!(content, "");
3204
3205 // Only limit
3206 let content = thread
3207 .update(cx, |thread, cx| {
3208 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3209 })
3210 .await
3211 .unwrap();
3212
3213 assert_eq!(content, "");
3214
3215 // Range
3216 let content = thread
3217 .update(cx, |thread, cx| {
3218 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3219 })
3220 .await
3221 .unwrap();
3222
3223 assert_eq!(content, "");
3224
3225 // Invalid
3226 let err = thread
3227 .update(cx, |thread, cx| {
3228 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3229 })
3230 .await
3231 .unwrap_err();
3232
3233 assert_eq!(
3234 err.to_string(),
3235 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3236 );
3237 }
3238 #[gpui::test]
3239 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3240 init_test(cx);
3241
3242 let fs = FakeFs::new(cx.executor());
3243 fs.insert_tree(path!("/tmp"), json!({})).await;
3244 let project = Project::test(fs.clone(), [], cx).await;
3245 project
3246 .update(cx, |project, cx| {
3247 project.find_or_create_worktree(path!("/tmp"), true, cx)
3248 })
3249 .await
3250 .unwrap();
3251
3252 let connection = Rc::new(FakeAgentConnection::new());
3253
3254 let thread = cx
3255 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3256 .await
3257 .unwrap();
3258
3259 // Out of project file
3260 let err = thread
3261 .update(cx, |thread, cx| {
3262 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3263 })
3264 .await
3265 .unwrap_err();
3266
3267 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3268 }
3269
3270 #[gpui::test]
3271 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3272 init_test(cx);
3273
3274 let fs = FakeFs::new(cx.executor());
3275 let project = Project::test(fs, [], cx).await;
3276 let id = acp::ToolCallId::new("test");
3277
3278 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3279 let id = id.clone();
3280 move |_, thread, mut cx| {
3281 let id = id.clone();
3282 async move {
3283 thread
3284 .update(&mut cx, |thread, cx| {
3285 thread.handle_session_update(
3286 acp::SessionUpdate::ToolCall(
3287 acp::ToolCall::new(id.clone(), "Label")
3288 .kind(acp::ToolKind::Fetch)
3289 .status(acp::ToolCallStatus::InProgress),
3290 ),
3291 cx,
3292 )
3293 })
3294 .unwrap()
3295 .unwrap();
3296 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3297 }
3298 .boxed_local()
3299 }
3300 }));
3301
3302 let thread = cx
3303 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3304 .await
3305 .unwrap();
3306
3307 let request = thread.update(cx, |thread, cx| {
3308 thread.send_raw("Fetch https://example.com", cx)
3309 });
3310
3311 run_until_first_tool_call(&thread, cx).await;
3312
3313 thread.read_with(cx, |thread, _| {
3314 assert!(matches!(
3315 thread.entries[1],
3316 AgentThreadEntry::ToolCall(ToolCall {
3317 status: ToolCallStatus::InProgress,
3318 ..
3319 })
3320 ));
3321 });
3322
3323 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3324
3325 thread.read_with(cx, |thread, _| {
3326 assert!(matches!(
3327 &thread.entries[1],
3328 AgentThreadEntry::ToolCall(ToolCall {
3329 status: ToolCallStatus::Canceled,
3330 ..
3331 })
3332 ));
3333 });
3334
3335 thread
3336 .update(cx, |thread, cx| {
3337 thread.handle_session_update(
3338 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3339 id,
3340 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3341 )),
3342 cx,
3343 )
3344 })
3345 .unwrap();
3346
3347 request.await.unwrap();
3348
3349 thread.read_with(cx, |thread, _| {
3350 assert!(matches!(
3351 thread.entries[1],
3352 AgentThreadEntry::ToolCall(ToolCall {
3353 status: ToolCallStatus::Completed,
3354 ..
3355 })
3356 ));
3357 });
3358 }
3359
3360 #[gpui::test]
3361 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3362 init_test(cx);
3363 let fs = FakeFs::new(cx.background_executor.clone());
3364 fs.insert_tree(path!("/test"), json!({})).await;
3365 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3366
3367 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3368 move |_, thread, mut cx| {
3369 async move {
3370 thread
3371 .update(&mut cx, |thread, cx| {
3372 thread.handle_session_update(
3373 acp::SessionUpdate::ToolCall(
3374 acp::ToolCall::new("test", "Label")
3375 .kind(acp::ToolKind::Edit)
3376 .status(acp::ToolCallStatus::Completed)
3377 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3378 "/test/test.txt",
3379 "foo",
3380 ))]),
3381 ),
3382 cx,
3383 )
3384 })
3385 .unwrap()
3386 .unwrap();
3387 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3388 }
3389 .boxed_local()
3390 }
3391 }));
3392
3393 let thread = cx
3394 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3395 .await
3396 .unwrap();
3397
3398 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3399 .await
3400 .unwrap();
3401
3402 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3403 }
3404
    /// Verifies that a user message is shown as "(checkpoint)" only when its turn changed
    /// files on disk, and that restoring a message's checkpoint truncates the conversation
    /// at that message and reverts the files written from that turn onward.
    ///
    /// Runs for 10 iterations (`iterations = 10`), presumably to vary task scheduling.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // NOTE(review): the `.git` directory presumably makes the worktree a repository so
        // checkpoints can be taken — confirm.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is true, each prompt writes a new empty file
        // `/test/file-N` (N taken from `next_filename`) before the agent replies.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // NOTE(review): this path is not wrapped in `path!`, while the
                        // `fs.files()` assertions below are; confirm they agree on Windows.
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply by echoing the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn writes file-0, so the message is checkpointed.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn writes file-1, so this message is checkpointed too.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message (0 = "Lorem", 1 = "LOREM"). Restoring it
        // drops that message and everything after it, and removes file-1, which was
        // written during the "ipsum" turn.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3582
3583 #[gpui::test]
3584 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3585 use std::sync::atomic::AtomicUsize;
3586 init_test(cx);
3587
3588 let fs = FakeFs::new(cx.executor());
3589 let project = Project::test(fs, None, cx).await;
3590
3591 // Create a connection that simulates refusal after tool result
3592 let prompt_count = Arc::new(AtomicUsize::new(0));
3593 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3594 let prompt_count = prompt_count.clone();
3595 move |_request, thread, mut cx| {
3596 let count = prompt_count.fetch_add(1, SeqCst);
3597 async move {
3598 if count == 0 {
3599 // First prompt: Generate a tool call with result
3600 thread.update(&mut cx, |thread, cx| {
3601 thread
3602 .handle_session_update(
3603 acp::SessionUpdate::ToolCall(
3604 acp::ToolCall::new("tool1", "Test Tool")
3605 .kind(acp::ToolKind::Fetch)
3606 .status(acp::ToolCallStatus::Completed)
3607 .raw_input(serde_json::json!({"query": "test"}))
3608 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3609 ),
3610 cx,
3611 )
3612 .unwrap();
3613 })?;
3614
3615 // Now return refusal because of the tool result
3616 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3617 } else {
3618 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3619 }
3620 }
3621 .boxed_local()
3622 }
3623 }));
3624
3625 let thread = cx
3626 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3627 .await
3628 .unwrap();
3629
3630 // Track if we see a Refusal event
3631 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3632 let saw_refusal_event_captured = saw_refusal_event.clone();
3633 thread.update(cx, |_thread, cx| {
3634 cx.subscribe(
3635 &thread,
3636 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3637 if matches!(event, AcpThreadEvent::Refusal) {
3638 *saw_refusal_event_captured.lock().unwrap() = true;
3639 }
3640 },
3641 )
3642 .detach();
3643 });
3644
3645 // Send a user message - this will trigger tool call and then refusal
3646 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3647 cx.background_executor.spawn(send_task).detach();
3648 cx.run_until_parked();
3649
3650 // Verify that:
3651 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3652 // 2. The user message was NOT truncated
3653 assert!(
3654 *saw_refusal_event.lock().unwrap(),
3655 "Refusal event should be emitted for tool result refusals"
3656 );
3657
3658 thread.read_with(cx, |thread, _| {
3659 let entries = thread.entries();
3660 assert!(entries.len() >= 2, "Should have user message and tool call");
3661
3662 // Verify user message is still there
3663 assert!(
3664 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3665 "User message should not be truncated"
3666 );
3667
3668 // Verify tool call is there with result
3669 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3670 assert!(
3671 tool_call.raw_output.is_some(),
3672 "Tool call should have output"
3673 );
3674 } else {
3675 panic!("Expected tool call at index 1");
3676 }
3677 });
3678 }
3679
3680 #[gpui::test]
3681 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3682 init_test(cx);
3683
3684 let fs = FakeFs::new(cx.executor());
3685 let project = Project::test(fs, None, cx).await;
3686
3687 let refuse_next = Arc::new(AtomicBool::new(false));
3688 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3689 let refuse_next = refuse_next.clone();
3690 move |_request, _thread, _cx| {
3691 if refuse_next.load(SeqCst) {
3692 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3693 .boxed_local()
3694 } else {
3695 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3696 .boxed_local()
3697 }
3698 }
3699 }));
3700
3701 let thread = cx
3702 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3703 .await
3704 .unwrap();
3705
3706 // Track if we see a Refusal event
3707 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3708 let saw_refusal_event_captured = saw_refusal_event.clone();
3709 thread.update(cx, |_thread, cx| {
3710 cx.subscribe(
3711 &thread,
3712 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3713 if matches!(event, AcpThreadEvent::Refusal) {
3714 *saw_refusal_event_captured.lock().unwrap() = true;
3715 }
3716 },
3717 )
3718 .detach();
3719 });
3720
3721 // Send a message that will be refused
3722 refuse_next.store(true, SeqCst);
3723 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3724 .await
3725 .unwrap();
3726
3727 // Verify that a Refusal event WAS emitted for user prompt refusal
3728 assert!(
3729 *saw_refusal_event.lock().unwrap(),
3730 "Refusal event should be emitted for user prompt refusals"
3731 );
3732
3733 // Verify the message was truncated (user prompt refusal)
3734 thread.read_with(cx, |thread, cx| {
3735 assert_eq!(thread.to_markdown(cx), "");
3736 });
3737 }
3738
3739 #[gpui::test]
3740 async fn test_refusal(cx: &mut TestAppContext) {
3741 init_test(cx);
3742 let fs = FakeFs::new(cx.background_executor.clone());
3743 fs.insert_tree(path!("/"), json!({})).await;
3744 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3745
3746 let refuse_next = Arc::new(AtomicBool::new(false));
3747 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3748 let refuse_next = refuse_next.clone();
3749 move |request, thread, mut cx| {
3750 let refuse_next = refuse_next.clone();
3751 async move {
3752 if refuse_next.load(SeqCst) {
3753 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3754 }
3755
3756 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3757 panic!("expected text content block");
3758 };
3759 thread.update(&mut cx, |thread, cx| {
3760 thread
3761 .handle_session_update(
3762 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3763 content.text.to_uppercase().into(),
3764 )),
3765 cx,
3766 )
3767 .unwrap();
3768 })?;
3769 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3770 }
3771 .boxed_local()
3772 }
3773 }));
3774 let thread = cx
3775 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3776 .await
3777 .unwrap();
3778
3779 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3780 .await
3781 .unwrap();
3782 thread.read_with(cx, |thread, cx| {
3783 assert_eq!(
3784 thread.to_markdown(cx),
3785 indoc! {"
3786 ## User
3787
3788 hello
3789
3790 ## Assistant
3791
3792 HELLO
3793
3794 "}
3795 );
3796 });
3797
3798 // Simulate refusing the second message. The message should be truncated
3799 // when a user prompt is refused.
3800 refuse_next.store(true, SeqCst);
3801 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3802 .await
3803 .unwrap();
3804 thread.read_with(cx, |thread, cx| {
3805 assert_eq!(
3806 thread.to_markdown(cx),
3807 indoc! {"
3808 ## User
3809
3810 hello
3811
3812 ## Assistant
3813
3814 HELLO
3815
3816 "}
3817 );
3818 });
3819 }
3820
3821 async fn run_until_first_tool_call(
3822 thread: &Entity<AcpThread>,
3823 cx: &mut TestAppContext,
3824 ) -> usize {
3825 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3826
3827 let subscription = cx.update(|cx| {
3828 cx.subscribe(thread, move |thread, _, cx| {
3829 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3830 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3831 return tx.try_send(ix).unwrap();
3832 }
3833 }
3834 })
3835 });
3836
3837 select! {
3838 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3839 panic!("Timeout waiting for tool call")
3840 }
3841 ix = rx.next().fuse() => {
3842 drop(subscription);
3843 ix.unwrap()
3844 }
3845 }
3846 }
3847
3848 #[derive(Clone, Default)]
3849 struct FakeAgentConnection {
3850 auth_methods: Vec<acp::AuthMethod>,
3851 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3852 on_user_message: Option<
3853 Rc<
3854 dyn Fn(
3855 acp::PromptRequest,
3856 WeakEntity<AcpThread>,
3857 AsyncApp,
3858 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3859 + 'static,
3860 >,
3861 >,
3862 }
3863
3864 impl FakeAgentConnection {
3865 fn new() -> Self {
3866 Self {
3867 auth_methods: Vec::new(),
3868 on_user_message: None,
3869 sessions: Arc::default(),
3870 }
3871 }
3872
3873 #[expect(unused)]
3874 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3875 self.auth_methods = auth_methods;
3876 self
3877 }
3878
3879 fn on_user_message(
3880 mut self,
3881 handler: impl Fn(
3882 acp::PromptRequest,
3883 WeakEntity<AcpThread>,
3884 AsyncApp,
3885 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3886 + 'static,
3887 ) -> Self {
3888 self.on_user_message.replace(Rc::new(handler));
3889 self
3890 }
3891 }
3892
3893 impl AgentConnection for FakeAgentConnection {
3894 fn telemetry_id(&self) -> SharedString {
3895 "fake".into()
3896 }
3897
3898 fn auth_methods(&self) -> &[acp::AuthMethod] {
3899 &self.auth_methods
3900 }
3901
3902 fn new_session(
3903 self: Rc<Self>,
3904 project: Entity<Project>,
3905 _cwd: &Path,
3906 cx: &mut App,
3907 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3908 let session_id = acp::SessionId::new(
3909 rand::rng()
3910 .sample_iter(&distr::Alphanumeric)
3911 .take(7)
3912 .map(char::from)
3913 .collect::<String>(),
3914 );
3915 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3916 let thread = cx.new(|cx| {
3917 AcpThread::new(
3918 None,
3919 "Test",
3920 self.clone(),
3921 project,
3922 action_log,
3923 session_id.clone(),
3924 watch::Receiver::constant(
3925 acp::PromptCapabilities::new()
3926 .image(true)
3927 .audio(true)
3928 .embedded_context(true),
3929 ),
3930 cx,
3931 )
3932 });
3933 self.sessions.lock().insert(session_id, thread.downgrade());
3934 Task::ready(Ok(thread))
3935 }
3936
3937 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3938 if self.auth_methods().iter().any(|m| m.id == method) {
3939 Task::ready(Ok(()))
3940 } else {
3941 Task::ready(Err(anyhow!("Invalid Auth Method")))
3942 }
3943 }
3944
3945 fn prompt(
3946 &self,
3947 _id: Option<UserMessageId>,
3948 params: acp::PromptRequest,
3949 cx: &mut App,
3950 ) -> Task<gpui::Result<acp::PromptResponse>> {
3951 let sessions = self.sessions.lock();
3952 let thread = sessions.get(¶ms.session_id).unwrap();
3953 if let Some(handler) = &self.on_user_message {
3954 let handler = handler.clone();
3955 let thread = thread.clone();
3956 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3957 } else {
3958 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3959 }
3960 }
3961
3962 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3963 let sessions = self.sessions.lock();
3964 let thread = sessions.get(session_id).unwrap().clone();
3965
3966 cx.spawn(async move |cx| {
3967 thread
3968 .update(cx, |thread, cx| thread.cancel(cx))
3969 .unwrap()
3970 .await
3971 })
3972 .detach();
3973 }
3974
3975 fn truncate(
3976 &self,
3977 session_id: &acp::SessionId,
3978 _cx: &App,
3979 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3980 Some(Rc::new(FakeAgentSessionEditor {
3981 _session_id: session_id.clone(),
3982 }))
3983 }
3984
3985 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3986 self
3987 }
3988 }
3989
3990 struct FakeAgentSessionEditor {
3991 _session_id: acp::SessionId,
3992 }
3993
3994 impl AgentSessionTruncate for FakeAgentSessionEditor {
3995 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3996 Task::ready(Ok(()))
3997 }
3998 }
3999
4000 #[gpui::test]
4001 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
4002 init_test(cx);
4003
4004 let fs = FakeFs::new(cx.executor());
4005 let project = Project::test(fs, [], cx).await;
4006 let connection = Rc::new(FakeAgentConnection::new());
4007 let thread = cx
4008 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4009 .await
4010 .unwrap();
4011
4012 // Try to update a tool call that doesn't exist
4013 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
4014 thread.update(cx, |thread, cx| {
4015 let result = thread.handle_session_update(
4016 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4017 nonexistent_id.clone(),
4018 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4019 )),
4020 cx,
4021 );
4022
4023 // The update should succeed (not return an error)
4024 assert!(result.is_ok());
4025
4026 // There should now be exactly one entry in the thread
4027 assert_eq!(thread.entries.len(), 1);
4028
4029 // The entry should be a failed tool call
4030 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4031 assert_eq!(tool_call.id, nonexistent_id);
4032 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4033 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4034
4035 // Check that the content contains the error message
4036 assert_eq!(tool_call.content.len(), 1);
4037 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4038 match content_block {
4039 ContentBlock::Markdown { markdown } => {
4040 let markdown_text = markdown.read(cx).source();
4041 assert!(markdown_text.contains("Tool call not found"));
4042 }
4043 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4044 ContentBlock::ResourceLink { .. } => {
4045 panic!("Expected markdown content, got resource link")
4046 }
4047 ContentBlock::Image { .. } => {
4048 panic!("Expected markdown content, got image")
4049 }
4050 }
4051 } else {
4052 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4053 }
4054 } else {
4055 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4056 }
4057 });
4058 }
4059
    /// Tests that restoring a checkpoint properly cleans up terminals that were
    /// created after that checkpoint, and cancels any in-progress generation.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// NOTE(review): all three terminals are registered after the second message was
    /// sent. Only the third one (`terminal_id`) is referenced by a thread entry, namely
    /// the tool call pushed below, and only it is expected to be removed. The other two
    /// are expected to survive even though they are not older than the restored
    /// checkpoint. Confirm this is the intended scenario.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // No `on_user_message` handler, so every prompt ends its turn immediately.
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that run to completion (Created -> Output -> Exit).
        // NOTE(review): despite the "before checkpoint" wording in the assertion messages
        // below, these are registered after the second message was sent. What sets them
        // apart is that no thread entry references them.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a terminal AFTER the checkpoint we'll restore to.
        // This simulates the AI agent starting a long-running terminal command.
        // Unlike the two terminals above, it never exits and a tool call entry below
        // references it.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry that references this terminal
        // This represents the agent requesting a terminal command
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (via the cancel() call)
        // 2. Remove the terminal that was created after that point
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (cancel() clears the send_task)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed, unreferenced terminals still exist
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal (referenced by the truncated tool call) was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify the total count is 2 (the two unreferenced, completed terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4361
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // A git repository with one file, presumably so the first message records a
        // checkpoint — confirm.
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // `handler_done` is set synchronously when the prompt handler is invoked, before
        // the future it returns is polled.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Run the send in a spawned task so the test can interleave its own updates.
        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        // NOTE(review): the flag marks that the handler was *called*, not that its future
        // finished, and the 5 extra ticks are a timing heuristic. If scheduling changes,
        // the injected message may miss the window this test targets.
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message with no checkpoint while the first message's checkpoint
        // update is (presumably) still pending. This is the situation that used to fail.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4424}