1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6/// Key used in ACP ToolCall meta to store the tool's programmatic name.
7/// This is a workaround since ACP's ToolCall doesn't have a dedicated name field.
8pub const TOOL_NAME_META_KEY: &str = "tool_name";
9
10/// Key used in ACP ToolCall meta to store the session id when a subagent is spawned.
11pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A message authored by the user, stored as an [`AgentThreadEntry::UserMessage`].
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message; `None` when no id is attached.
    pub id: Option<UserMessageId>,
    /// Rendered content; used by `to_markdown` when exporting the message.
    // NOTE(review): presumably assembled from `chunks` — confirm at the construction site.
    pub content: ContentBlock,
    /// The raw ACP content blocks of the message.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to the message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
75
/// A git checkpoint associated with a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// When true, the markdown export heads the message with `## User (checkpoint)`.
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// A message produced by the agent, made of ordered chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks; `to_markdown` renders them in this order.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
/// A single entry of an agent thread, as returned by `AcpThread::entries`.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation made by the agent.
    ToolCall(ToolCall),
}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
/// A tool call made by the agent, together with its rendered content and status.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the title. Titles of non-`Execute` calls are cut to
    /// their first line plus `…`.
    pub label: Entity<Markdown>,
    /// ACP tool kind; `Execute` also controls how the title is displayed.
    pub kind: acp::ToolKind,
    /// Supported content items (content blocks, diffs, terminals) in order.
    pub content: Vec<ToolCallContent>,
    /// Local status of the call.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Read index-aligned with `locations` by `AgentThreadEntry::location`; `None` where
    /// a location has no resolved counterpart.
    // NOTE(review): populated outside the visible code (see `resolve_locations`) — confirm.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// `raw_input` rendered via `markdown_for_raw_output`, when that yields markdown.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from meta under `TOOL_NAME_META_KEY`.
    pub tool_name: Option<SharedString>,
    /// Subagent session id read from meta under `SUBAGENT_SESSION_ID_META_KEY`.
    pub subagent_session_id: Option<acp::SessionId>,
}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 if self.kind == acp::ToolKind::Execute {
318 for terminal in self.terminals() {
319 terminal.update(cx, |terminal, cx| {
320 terminal.update_command_label(&title, cx);
321 });
322 }
323 }
324 self.label.update(cx, |label, cx| {
325 if self.kind == acp::ToolKind::Execute {
326 label.replace(title, cx);
327 } else if let Some((first_line, _)) = title.split_once("\n") {
328 label.replace(first_line.to_owned() + "…", cx);
329 } else {
330 label.replace(title, cx);
331 }
332 });
333 }
334
335 if let Some(content) = content {
336 let mut new_content_len = content.len();
337 let mut content = content.into_iter();
338
339 // Reuse existing content if we can
340 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
341 let valid_content =
342 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
343 if !valid_content {
344 new_content_len -= 1;
345 }
346 }
347 for new in content {
348 if let Some(new) = ToolCallContent::from_acp(
349 new,
350 language_registry.clone(),
351 path_style,
352 terminals,
353 cx,
354 )? {
355 self.content.push(new);
356 } else {
357 new_content_len -= 1;
358 }
359 }
360 self.content.truncate(new_content_len);
361 }
362
363 if let Some(locations) = locations {
364 self.locations = locations;
365 }
366
367 if let Some(raw_input) = raw_input {
368 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
369 self.raw_input = Some(raw_input);
370 }
371
372 if let Some(raw_output) = raw_output {
373 if self.content.is_empty()
374 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
375 {
376 self.content
377 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
378 markdown,
379 }));
380 }
381 self.raw_output = Some(raw_output);
382 }
383 Ok(())
384 }
385
386 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
387 self.content.iter().filter_map(|content| match content {
388 ToolCallContent::Diff(diff) => Some(diff),
389 ToolCallContent::ContentBlock(_) => None,
390 ToolCallContent::Terminal(_) => None,
391 })
392 }
393
394 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
395 self.content.iter().filter_map(|content| match content {
396 ToolCallContent::Terminal(terminal) => Some(terminal),
397 ToolCallContent::ContentBlock(_) => None,
398 ToolCallContent::Diff(_) => None,
399 })
400 }
401
402 pub fn is_subagent(&self) -> bool {
403 self.tool_name.as_ref().is_some_and(|s| s == "subagent")
404 || self.subagent_session_id.is_some()
405 }
406
407 pub fn to_markdown(&self, cx: &App) -> String {
408 let mut markdown = format!(
409 "**Tool Call: {}**\nStatus: {}\n\n",
410 self.label.read(cx).source(),
411 self.status
412 );
413 for content in &self.content {
414 markdown.push_str(content.to_markdown(cx).as_str());
415 markdown.push_str("\n\n");
416 }
417 markdown
418 }
419
420 async fn resolve_location(
421 location: acp::ToolCallLocation,
422 project: WeakEntity<Project>,
423 cx: &mut AsyncApp,
424 ) -> Option<ResolvedLocation> {
425 let buffer = project
426 .update(cx, |project, cx| {
427 project
428 .project_path_for_absolute_path(&location.path, cx)
429 .map(|path| project.open_buffer(path, cx))
430 })
431 .ok()??;
432 let buffer = buffer.await.log_err()?;
433 let position = buffer.update(cx, |buffer, _| {
434 let snapshot = buffer.snapshot();
435 if let Some(row) = location.line {
436 let column = snapshot.indent_size_for_line(row).len;
437 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
438 snapshot.anchor_before(point)
439 } else {
440 Anchor::min_for_buffer(snapshot.remote_id())
441 }
442 });
443
444 Some(ResolvedLocation { buffer, position })
445 }
446
447 fn resolve_locations(
448 &self,
449 project: Entity<Project>,
450 cx: &mut App,
451 ) -> Task<Vec<Option<ResolvedLocation>>> {
452 let locations = self.locations.clone();
453 project.update(cx, |_, cx| {
454 cx.spawn(async move |project, cx| {
455 let mut new_locations = Vec::new();
456 for location in locations {
457 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
458 }
459 new_locations
460 })
461 })
462 }
463}
464
465// Separate so we can hold a strong reference to the buffer
466// for saving on the thread
467#[derive(Clone, Debug, PartialEq, Eq)]
468struct ResolvedLocation {
469 buffer: Entity<Buffer>,
470 position: Anchor,
471}
472
473impl From<&ResolvedLocation> for AgentLocation {
474 fn from(value: &ResolvedLocation) -> Self {
475 Self {
476 buffer: value.buffer.downgrade(),
477 position: value.position,
478 }
479 }
480}
481
/// Lifecycle state of a tool call as tracked locally.
///
/// This superset of ACP's status adds confirmation, rejection and cancellation states.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered to the user.
        options: PermissionOptions,
        /// One-shot channel that receives the id of the chosen permission option.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
503
504impl From<acp::ToolCallStatus> for ToolCallStatus {
505 fn from(status: acp::ToolCallStatus) -> Self {
506 match status {
507 acp::ToolCallStatus::Pending => Self::Pending,
508 acp::ToolCallStatus::InProgress => Self::InProgress,
509 acp::ToolCallStatus::Completed => Self::Completed,
510 acp::ToolCallStatus::Failed => Self::Failed,
511 _ => Self::Pending,
512 }
513 }
514}
515
516impl Display for ToolCallStatus {
517 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
518 write!(
519 f,
520 "{}",
521 match self {
522 ToolCallStatus::Pending => "Pending",
523 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
524 ToolCallStatus::InProgress => "In Progress",
525 ToolCallStatus::Completed => "Completed",
526 ToolCallStatus::Failed => "Failed",
527 ToolCallStatus::Rejected => "Rejected",
528 ToolCallStatus::Canceled => "Canceled",
529 }
530 )
531 }
532}
533
/// Renderable content assembled from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text; later appended blocks are added to it in place.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link that has not been combined with other content.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A lone image that decoded successfully.
    Image { image: Arc<gpui::Image> },
}
541
542impl ContentBlock {
543 pub fn new(
544 block: acp::ContentBlock,
545 language_registry: &Arc<LanguageRegistry>,
546 path_style: PathStyle,
547 cx: &mut App,
548 ) -> Self {
549 let mut this = Self::Empty;
550 this.append(block, language_registry, path_style, cx);
551 this
552 }
553
554 pub fn new_combined(
555 blocks: impl IntoIterator<Item = acp::ContentBlock>,
556 language_registry: Arc<LanguageRegistry>,
557 path_style: PathStyle,
558 cx: &mut App,
559 ) -> Self {
560 let mut this = Self::Empty;
561 for block in blocks {
562 this.append(block, &language_registry, path_style, cx);
563 }
564 this
565 }
566
567 pub fn append(
568 &mut self,
569 block: acp::ContentBlock,
570 language_registry: &Arc<LanguageRegistry>,
571 path_style: PathStyle,
572 cx: &mut App,
573 ) {
574 match (&mut *self, &block) {
575 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
576 *self = ContentBlock::ResourceLink {
577 resource_link: resource_link.clone(),
578 };
579 }
580 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
581 if let Some(image) = Self::decode_image(image_content) {
582 *self = ContentBlock::Image { image };
583 } else {
584 let new_content = Self::image_md(image_content);
585 *self = Self::create_markdown_block(new_content, language_registry, cx);
586 }
587 }
588 (ContentBlock::Empty, _) => {
589 let new_content = Self::block_string_contents(&block, path_style);
590 *self = Self::create_markdown_block(new_content, language_registry, cx);
591 }
592 (ContentBlock::Markdown { markdown }, _) => {
593 let new_content = Self::block_string_contents(&block, path_style);
594 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
595 }
596 (ContentBlock::ResourceLink { resource_link }, _) => {
597 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
598 let new_content = Self::block_string_contents(&block, path_style);
599 let combined = format!("{}\n{}", existing_content, new_content);
600 *self = Self::create_markdown_block(combined, language_registry, cx);
601 }
602 (ContentBlock::Image { .. }, _) => {
603 let new_content = Self::block_string_contents(&block, path_style);
604 let combined = format!("`Image`\n{}", new_content);
605 *self = Self::create_markdown_block(combined, language_registry, cx);
606 }
607 }
608 }
609
610 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
611 use base64::Engine as _;
612
613 let bytes = base64::engine::general_purpose::STANDARD
614 .decode(image_content.data.as_bytes())
615 .ok()?;
616 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
617 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
618 }
619
620 fn create_markdown_block(
621 content: String,
622 language_registry: &Arc<LanguageRegistry>,
623 cx: &mut App,
624 ) -> ContentBlock {
625 ContentBlock::Markdown {
626 markdown: cx
627 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
628 }
629 }
630
631 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
632 match block {
633 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
634 acp::ContentBlock::ResourceLink(resource_link) => {
635 Self::resource_link_md(&resource_link.uri, path_style)
636 }
637 acp::ContentBlock::Resource(acp::EmbeddedResource {
638 resource:
639 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
640 uri,
641 ..
642 }),
643 ..
644 }) => Self::resource_link_md(uri, path_style),
645 acp::ContentBlock::Image(image) => Self::image_md(image),
646 _ => String::new(),
647 }
648 }
649
650 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
651 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
652 uri.as_link().to_string()
653 } else {
654 uri.to_string()
655 }
656 }
657
658 fn image_md(_image: &acp::ImageContent) -> String {
659 "`Image`".into()
660 }
661
662 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
663 match self {
664 ContentBlock::Empty => "",
665 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
666 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
667 ContentBlock::Image { .. } => "`Image`",
668 }
669 }
670
671 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
672 match self {
673 ContentBlock::Empty => None,
674 ContentBlock::Markdown { markdown } => Some(markdown),
675 ContentBlock::ResourceLink { .. } => None,
676 ContentBlock::Image { .. } => None,
677 }
678 }
679
680 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
681 match self {
682 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
683 _ => None,
684 }
685 }
686
687 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
688 match self {
689 ContentBlock::Image { image } => Some(image),
690 _ => None,
691 }
692 }
693}
694
/// One supported content item of a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Regular content (markdown, resource link, or image).
    ContentBlock(ContentBlock),
    /// A file diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity looked up by its ACP terminal id.
    Terminal(Entity<Terminal>),
}
701
702impl ToolCallContent {
703 pub fn from_acp(
704 content: acp::ToolCallContent,
705 language_registry: Arc<LanguageRegistry>,
706 path_style: PathStyle,
707 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
708 cx: &mut App,
709 ) -> Result<Option<Self>> {
710 match content {
711 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
712 Ok(Some(Self::ContentBlock(ContentBlock::new(
713 content,
714 &language_registry,
715 path_style,
716 cx,
717 ))))
718 }
719 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
720 Diff::finalized(
721 diff.path.to_string_lossy().into_owned(),
722 diff.old_text,
723 diff.new_text,
724 language_registry,
725 cx,
726 )
727 })))),
728 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
729 .get(&terminal_id)
730 .cloned()
731 .map(|terminal| Some(Self::Terminal(terminal)))
732 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
733 _ => Ok(None),
734 }
735 }
736
737 pub fn update_from_acp(
738 &mut self,
739 new: acp::ToolCallContent,
740 language_registry: Arc<LanguageRegistry>,
741 path_style: PathStyle,
742 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
743 cx: &mut App,
744 ) -> Result<bool> {
745 let needs_update = match (&self, &new) {
746 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
747 old_diff.read(cx).needs_update(
748 new_diff.old_text.as_deref().unwrap_or(""),
749 &new_diff.new_text,
750 cx,
751 )
752 }
753 _ => true,
754 };
755
756 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
757 if needs_update {
758 *self = update;
759 }
760 Ok(true)
761 } else {
762 Ok(false)
763 }
764 }
765
766 pub fn to_markdown(&self, cx: &App) -> String {
767 match self {
768 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
769 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
770 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
771 }
772 }
773
774 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
775 match self {
776 Self::ContentBlock(content) => content.image(),
777 _ => None,
778 }
779 }
780}
781
782#[derive(Debug, PartialEq)]
783pub enum ToolCallUpdate {
784 UpdateFields(acp::ToolCallUpdate),
785 UpdateDiff(ToolCallUpdateDiff),
786 UpdateTerminal(ToolCallUpdateTerminal),
787}
788
789impl ToolCallUpdate {
790 fn id(&self) -> &acp::ToolCallId {
791 match self {
792 Self::UpdateFields(update) => &update.tool_call_id,
793 Self::UpdateDiff(diff) => &diff.id,
794 Self::UpdateTerminal(terminal) => &terminal.id,
795 }
796 }
797}
798
799impl From<acp::ToolCallUpdate> for ToolCallUpdate {
800 fn from(update: acp::ToolCallUpdate) -> Self {
801 Self::UpdateFields(update)
802 }
803}
804
805impl From<ToolCallUpdateDiff> for ToolCallUpdate {
806 fn from(diff: ToolCallUpdateDiff) -> Self {
807 Self::UpdateDiff(diff)
808 }
809}
810
811#[derive(Debug, PartialEq)]
812pub struct ToolCallUpdateDiff {
813 pub id: acp::ToolCallId,
814 pub diff: Entity<Diff>,
815}
816
817impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
818 fn from(terminal: ToolCallUpdateTerminal) -> Self {
819 Self::UpdateTerminal(terminal)
820 }
821}
822
823#[derive(Debug, PartialEq)]
824pub struct ToolCallUpdateTerminal {
825 pub id: acp::ToolCallId,
826 pub terminal: Entity<Terminal>,
827}
828
829#[derive(Debug, Default)]
830pub struct Plan {
831 pub entries: Vec<PlanEntry>,
832}
833
834#[derive(Debug)]
835pub struct PlanStats<'a> {
836 pub in_progress_entry: Option<&'a PlanEntry>,
837 pub pending: u32,
838 pub completed: u32,
839}
840
841impl Plan {
842 pub fn is_empty(&self) -> bool {
843 self.entries.is_empty()
844 }
845
846 pub fn stats(&self) -> PlanStats<'_> {
847 let mut stats = PlanStats {
848 in_progress_entry: None,
849 pending: 0,
850 completed: 0,
851 };
852
853 for entry in &self.entries {
854 match &entry.status {
855 acp::PlanEntryStatus::Pending => {
856 stats.pending += 1;
857 }
858 acp::PlanEntryStatus::InProgress => {
859 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
860 }
861 acp::PlanEntryStatus::Completed => {
862 stats.completed += 1;
863 }
864 _ => {}
865 }
866 }
867
868 stats
869 }
870}
871
872#[derive(Debug)]
873pub struct PlanEntry {
874 pub content: Entity<Markdown>,
875 pub priority: acp::PlanEntryPriority,
876 pub status: acp::PlanEntryStatus,
877}
878
879impl PlanEntry {
880 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
881 Self {
882 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
883 priority: entry.priority,
884 status: entry.status,
885 }
886 }
887}
888
/// Token usage figures for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit. `ratio` treats 0 as "unknown" and never warns in that case.
    pub max_tokens: u64,
    /// Tokens used so far; compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    /// Input token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    // NOTE(review): presumably the output-token limit of the model, when known — confirm.
    pub max_output_tokens: Option<u64>,
}
897
898impl TokenUsage {
899 pub fn ratio(&self) -> TokenUsageRatio {
900 #[cfg(debug_assertions)]
901 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
902 .unwrap_or("0.8".to_string())
903 .parse()
904 .unwrap();
905 #[cfg(not(debug_assertions))]
906 let warning_threshold: f32 = 0.8;
907
908 // When the maximum is unknown because there is no selected model,
909 // avoid showing the token limit warning.
910 if self.max_tokens == 0 {
911 TokenUsageRatio::Normal
912 } else if self.used_tokens >= self.max_tokens {
913 TokenUsageRatio::Exceeded
914 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
915 TokenUsageRatio::Warning
916 } else {
917 TokenUsageRatio::Normal
918 }
919 }
920}
921
/// Severity bucket returned by `TokenUsage::ratio`.
///
/// Variants are ordered from least to most severe (`Normal < Warning < Exceeded`) via the
/// derived `Ord`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}
928
/// Retry details carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    // NOTE(review): presumably the error that triggered the retry — confirm.
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm.
    pub duration: Duration,
}

/// The turn currently in flight; while one is stored, `AcpThread::status` reports
/// `Generating`.
struct RunningTurn {
    // NOTE(review): presumably derived from `AcpThread::turn_id` — confirm.
    id: u32,
    /// Task driving the turn.
    send_task: Task<()>,
}
942
/// State of a single ACP agent thread: its entries, plan, terminals, and the agent
/// connection.
pub struct AcpThread {
    /// Session id of the parent thread, if any; exposed by `parent_session_id()`.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    /// Thread entries in order; exposed by `entries()`.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): presumably snapshots of buffers shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // NOTE(review): presumably a counter used to tag `RunningTurn::id` — confirm.
    turn_id: u32,
    /// `Some` while a turn is in flight; `status()` reports `Generating` in that case.
    running_turn: Option<RunningTurn>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies new values from the capabilities watch channel.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Registered terminals keyed by ACP terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids that are not registered yet.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids that are not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    /// Value returned by `had_error()`.
    had_error: bool,
}
963
964impl From<&AcpThread> for ActionLogTelemetry {
965 fn from(value: &AcpThread) -> Self {
966 Self {
967 agent_telemetry_id: value.connection().telemetry_id(),
968 session_id: value.session_id.0.clone(),
969 }
970 }
971}
972
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    // NOTE(review): presumably an index into `AcpThread::entries` — confirm.
    EntryUpdated(usize),
    // NOTE(review): presumably a range of indices into `AcpThread::entries` — confirm.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after a new value arrives on the prompt-capabilities watch channel.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
994
/// Events reported by a terminal provider, handled by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created. It is registered under `terminal_id`, and any output or
    /// exit buffered for that id is replayed.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes; buffered when the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title, stored as the terminal's breadcrumb text. Ignored for unknown ids.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; buffered when the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

/// Commands addressed to a terminal provider.
// NOTE(review): not handled in the visible part of this file; presumably sent from the
// UI to the provider — confirm.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Bytes to write to the terminal's input.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// New terminal size in columns and rows.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1033
1034impl AcpThread {
1035 pub fn on_terminal_provider_event(
1036 &mut self,
1037 event: TerminalProviderEvent,
1038 cx: &mut Context<Self>,
1039 ) {
1040 match event {
1041 TerminalProviderEvent::Created {
1042 terminal_id,
1043 label,
1044 cwd,
1045 output_byte_limit,
1046 terminal,
1047 } => {
1048 let entity = self.register_terminal_created(
1049 terminal_id.clone(),
1050 label,
1051 cwd,
1052 output_byte_limit,
1053 terminal,
1054 cx,
1055 );
1056
1057 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1058 for data in chunks.drain(..) {
1059 entity.update(cx, |term, cx| {
1060 term.inner().update(cx, |inner, cx| {
1061 inner.write_output(&data, cx);
1062 })
1063 });
1064 }
1065 }
1066
1067 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1068 entity.update(cx, |_term, cx| {
1069 cx.notify();
1070 });
1071 }
1072
1073 cx.notify();
1074 }
1075 TerminalProviderEvent::Output { terminal_id, data } => {
1076 if let Some(entity) = self.terminals.get(&terminal_id) {
1077 entity.update(cx, |term, cx| {
1078 term.inner().update(cx, |inner, cx| {
1079 inner.write_output(&data, cx);
1080 })
1081 });
1082 } else {
1083 self.pending_terminal_output
1084 .entry(terminal_id)
1085 .or_default()
1086 .push(data);
1087 }
1088 }
1089 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1090 if let Some(entity) = self.terminals.get(&terminal_id) {
1091 entity.update(cx, |term, cx| {
1092 term.inner().update(cx, |inner, cx| {
1093 inner.breadcrumb_text = title;
1094 cx.emit(::terminal::Event::BreadcrumbsChanged);
1095 })
1096 });
1097 }
1098 }
1099 TerminalProviderEvent::Exit {
1100 terminal_id,
1101 status,
1102 } => {
1103 if let Some(entity) = self.terminals.get(&terminal_id) {
1104 entity.update(cx, |_term, cx| {
1105 cx.notify();
1106 });
1107 } else {
1108 self.pending_terminal_exit.insert(terminal_id, status);
1109 }
1110 }
1111 }
1112 }
1113}
1114
/// Whether the thread is running a turn; returned by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is in flight.
    Generating,
}
1120
1121#[derive(Debug, Clone)]
1122pub enum LoadError {
1123 Unsupported {
1124 command: SharedString,
1125 current_version: SharedString,
1126 minimum_version: SharedString,
1127 },
1128 FailedToInstall(SharedString),
1129 Exited {
1130 status: ExitStatus,
1131 },
1132 Other(SharedString),
1133}
1134
1135impl Display for LoadError {
1136 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1137 match self {
1138 LoadError::Unsupported {
1139 command: path,
1140 current_version,
1141 minimum_version,
1142 } => {
1143 write!(
1144 f,
1145 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1146 )
1147 }
1148 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1149 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1150 LoadError::Other(msg) => write!(f, "{msg}"),
1151 }
1152 }
1153}
1154
1155impl Error for LoadError {}
1156
1157impl AcpThread {
1158 pub fn new(
1159 parent_session_id: Option<acp::SessionId>,
1160 title: impl Into<SharedString>,
1161 connection: Rc<dyn AgentConnection>,
1162 project: Entity<Project>,
1163 action_log: Entity<ActionLog>,
1164 session_id: acp::SessionId,
1165 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1166 cx: &mut Context<Self>,
1167 ) -> Self {
1168 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1169 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1170 loop {
1171 let caps = prompt_capabilities_rx.recv().await?;
1172 this.update(cx, |this, cx| {
1173 this.prompt_capabilities = caps;
1174 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1175 })?;
1176 }
1177 });
1178
1179 Self {
1180 parent_session_id,
1181 action_log,
1182 shared_buffers: Default::default(),
1183 entries: Default::default(),
1184 plan: Default::default(),
1185 title: title.into(),
1186 project,
1187 running_turn: None,
1188 turn_id: 0,
1189 connection,
1190 session_id,
1191 token_usage: None,
1192 prompt_capabilities,
1193 _observe_prompt_capabilities: task,
1194 terminals: HashMap::default(),
1195 pending_terminal_output: HashMap::default(),
1196 pending_terminal_exit: HashMap::default(),
1197 had_error: false,
1198 }
1199 }
1200
1201 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1202 self.parent_session_id.as_ref()
1203 }
1204
1205 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1206 self.prompt_capabilities.clone()
1207 }
1208
1209 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1210 &self.connection
1211 }
1212
1213 pub fn action_log(&self) -> &Entity<ActionLog> {
1214 &self.action_log
1215 }
1216
1217 pub fn project(&self) -> &Entity<Project> {
1218 &self.project
1219 }
1220
1221 pub fn title(&self) -> SharedString {
1222 self.title.clone()
1223 }
1224
1225 pub fn entries(&self) -> &[AgentThreadEntry] {
1226 &self.entries
1227 }
1228
1229 pub fn session_id(&self) -> &acp::SessionId {
1230 &self.session_id
1231 }
1232
1233 pub fn status(&self) -> ThreadStatus {
1234 if self.running_turn.is_some() {
1235 ThreadStatus::Generating
1236 } else {
1237 ThreadStatus::Idle
1238 }
1239 }
1240
1241 pub fn had_error(&self) -> bool {
1242 self.had_error
1243 }
1244
1245 pub fn is_waiting_for_confirmation(&self) -> bool {
1246 for entry in self.entries.iter().rev() {
1247 match entry {
1248 AgentThreadEntry::UserMessage(_) => return false,
1249 AgentThreadEntry::ToolCall(ToolCall {
1250 status: ToolCallStatus::WaitingForConfirmation { .. },
1251 ..
1252 }) => return true,
1253 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1254 }
1255 }
1256 false
1257 }
1258
1259 pub fn token_usage(&self) -> Option<&TokenUsage> {
1260 self.token_usage.as_ref()
1261 }
1262
1263 pub fn has_pending_edit_tool_calls(&self) -> bool {
1264 for entry in self.entries.iter().rev() {
1265 match entry {
1266 AgentThreadEntry::UserMessage(_) => return false,
1267 AgentThreadEntry::ToolCall(
1268 call @ ToolCall {
1269 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1270 ..
1271 },
1272 ) if call.diffs().next().is_some() => {
1273 return true;
1274 }
1275 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1276 }
1277 }
1278
1279 false
1280 }
1281
1282 pub fn has_in_progress_tool_calls(&self) -> bool {
1283 for entry in self.entries.iter().rev() {
1284 match entry {
1285 AgentThreadEntry::UserMessage(_) => return false,
1286 AgentThreadEntry::ToolCall(ToolCall {
1287 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1288 ..
1289 }) => {
1290 return true;
1291 }
1292 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1293 }
1294 }
1295
1296 false
1297 }
1298
1299 pub fn used_tools_since_last_user_message(&self) -> bool {
1300 for entry in self.entries.iter().rev() {
1301 match entry {
1302 AgentThreadEntry::UserMessage(..) => return false,
1303 AgentThreadEntry::AssistantMessage(..) => continue,
1304 AgentThreadEntry::ToolCall(..) => return true,
1305 }
1306 }
1307
1308 false
1309 }
1310
1311 pub fn handle_session_update(
1312 &mut self,
1313 update: acp::SessionUpdate,
1314 cx: &mut Context<Self>,
1315 ) -> Result<(), acp::Error> {
1316 match update {
1317 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1318 self.push_user_content_block(None, content, cx);
1319 }
1320 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1321 self.push_assistant_content_block(content, false, cx);
1322 }
1323 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1324 self.push_assistant_content_block(content, true, cx);
1325 }
1326 acp::SessionUpdate::ToolCall(tool_call) => {
1327 self.upsert_tool_call(tool_call, cx)?;
1328 }
1329 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1330 self.update_tool_call(tool_call_update, cx)?;
1331 }
1332 acp::SessionUpdate::Plan(plan) => {
1333 self.update_plan(plan, cx);
1334 }
1335 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1336 available_commands,
1337 ..
1338 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1339 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1340 current_mode_id,
1341 ..
1342 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1343 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1344 config_options,
1345 ..
1346 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1347 _ => {}
1348 }
1349 Ok(())
1350 }
1351
1352 pub fn push_user_content_block(
1353 &mut self,
1354 message_id: Option<UserMessageId>,
1355 chunk: acp::ContentBlock,
1356 cx: &mut Context<Self>,
1357 ) {
1358 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1359 }
1360
    /// Adds a user content chunk to the thread.
    ///
    /// If the last entry is a user message with the same `indented` flag, the chunk is
    /// appended to it and `EntryUpdated` is emitted for that entry; the message's id is
    /// replaced by `message_id` only when one is given. Otherwise a new user message
    /// containing just this chunk (and no checkpoint) is pushed.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front: `self.entries` is mutably borrowed inside the branch below.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Keep the existing id unless the caller supplied a new one.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1401
1402 pub fn push_assistant_content_block(
1403 &mut self,
1404 chunk: acp::ContentBlock,
1405 is_thought: bool,
1406 cx: &mut Context<Self>,
1407 ) {
1408 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1409 }
1410
    /// Adds an assistant content chunk to the thread: a regular message chunk, or a thought
    /// when `is_thought` is true.
    ///
    /// If the last entry is an assistant message with the same `indented` flag, the chunk is
    /// merged into it. It is appended to the message's last chunk when that chunk has the
    /// same kind (message vs. thought), or added as a new chunk otherwise. `EntryUpdated` is
    /// emitted for that entry. If the last entry does not match, a new assistant message
    /// holding just this chunk is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front: `self.entries` is mutably borrowed inside the branch below.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): the event is emitted before the chunk is merged; this presumes
            // gpui delivers emitted events after this update returns — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the last chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1461
1462 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1463 self.entries.push(entry);
1464 cx.emit(AcpThreadEvent::NewEntry);
1465 }
1466
1467 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1468 self.connection.set_title(&self.session_id, cx).is_some()
1469 }
1470
1471 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1472 if title != self.title {
1473 self.title = title.clone();
1474 cx.emit(AcpThreadEvent::TitleUpdated);
1475 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1476 return set_title.run(title, cx);
1477 }
1478 }
1479 Task::ready(Ok(()))
1480 }
1481
1482 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1483 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1484 }
1485
1486 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1487 self.token_usage = usage;
1488 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1489 }
1490
1491 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1492 cx.emit(AcpThreadEvent::Retry(status));
1493 }
1494
    /// Applies `update` to the tool call with the matching id.
    ///
    /// If no such tool call exists, a placeholder entry with `Failed` status, whose label and
    /// content read "Tool call not found", is pushed instead and `Ok(())` is returned.
    ///
    /// Otherwise the update is applied according to its kind, and then `EntryUpdated` is
    /// emitted for the call's index:
    /// - `UpdateFields` merges the fields (and meta) into the call. If the update carried
    ///   locations, they are re-resolved.
    /// - `UpdateDiff` / `UpdateTerminal` replace the call's content with that single diff
    ///   or terminal.
    ///
    /// Returns an error if merging the fields fails.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                // NOTE(review): `ToolKind::Fetch` looks like an arbitrary kind for this
                // placeholder entry — confirm it is intended.
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                    tool_name: None,
                    subagent_session_id: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(
                    update.fields,
                    update.meta,
                    languages,
                    path_style,
                    &self.terminals,
                    cx,
                )?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                // A diff update replaces all previous content of the call.
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                // A terminal update replaces all previous content of the call.
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1565
1566 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1567 pub fn upsert_tool_call(
1568 &mut self,
1569 tool_call: acp::ToolCall,
1570 cx: &mut Context<Self>,
1571 ) -> Result<(), acp::Error> {
1572 let status = tool_call.status.into();
1573 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1574 }
1575
1576 /// Fails if id does not match an existing entry.
1577 pub fn upsert_tool_call_inner(
1578 &mut self,
1579 update: acp::ToolCallUpdate,
1580 status: ToolCallStatus,
1581 cx: &mut Context<Self>,
1582 ) -> Result<(), acp::Error> {
1583 let language_registry = self.project.read(cx).languages().clone();
1584 let path_style = self.project.read(cx).path_style(cx);
1585 let id = update.tool_call_id.clone();
1586
1587 let agent_telemetry_id = self.connection().telemetry_id();
1588 let session = self.session_id();
1589 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1590 let status = if matches!(status, ToolCallStatus::Completed) {
1591 "completed"
1592 } else {
1593 "failed"
1594 };
1595 telemetry::event!(
1596 "Agent Tool Call Completed",
1597 agent_telemetry_id,
1598 session,
1599 status
1600 );
1601 }
1602
1603 if let Some(ix) = self.index_for_tool_call(&id) {
1604 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1605 unreachable!()
1606 };
1607
1608 call.update_fields(
1609 update.fields,
1610 update.meta,
1611 language_registry,
1612 path_style,
1613 &self.terminals,
1614 cx,
1615 )?;
1616 call.status = status;
1617
1618 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1619 } else {
1620 let call = ToolCall::from_acp(
1621 update.try_into()?,
1622 status,
1623 language_registry,
1624 self.project.read(cx).path_style(cx),
1625 &self.terminals,
1626 cx,
1627 )?;
1628 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1629 };
1630
1631 self.resolve_locations(id, cx);
1632 Ok(())
1633 }
1634
1635 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1636 self.entries
1637 .iter()
1638 .enumerate()
1639 .rev()
1640 .find_map(|(index, entry)| {
1641 if let AgentThreadEntry::ToolCall(tool_call) = entry
1642 && &tool_call.id == id
1643 {
1644 Some(index)
1645 } else {
1646 None
1647 }
1648 })
1649 }
1650
1651 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1652 // The tool call we are looking for is typically the last one, or very close to the end.
1653 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1654 self.entries
1655 .iter_mut()
1656 .enumerate()
1657 .rev()
1658 .find_map(|(index, tool_call)| {
1659 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1660 && &tool_call.id == id
1661 {
1662 Some((index, tool_call))
1663 } else {
1664 None
1665 }
1666 })
1667 }
1668
1669 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1670 self.entries
1671 .iter()
1672 .enumerate()
1673 .rev()
1674 .find_map(|(index, tool_call)| {
1675 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1676 && &tool_call.id == id
1677 {
1678 Some((index, tool_call))
1679 } else {
1680 None
1681 }
1682 })
1683 }
1684
    /// Resolves the locations of the tool call with the given `id` and applies the result.
    ///
    /// Resolution runs asynchronously in a detached task. When it finishes:
    /// - a snapshot of every successfully resolved buffer is cached in `shared_buffers`;
    /// - the project's agent location is moved to the last location, if that one resolved.
    ///   The move is skipped when it would only step backwards within the same row of the
    ///   same buffer;
    /// - the tool call's `resolved_locations` are replaced, and `EntryUpdated` is emitted
    ///   if they changed.
    ///
    /// Returns immediately if no tool call has `id`. If the call has disappeared by the time
    /// resolution finishes, the snapshots are still cached but nothing else is updated.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache snapshots of every buffer that resolved, so later reads can reuse them.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call must be looked up again: entries may have changed meanwhile.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify when something actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1741
1742 pub fn request_tool_call_authorization(
1743 &mut self,
1744 tool_call: acp::ToolCallUpdate,
1745 options: PermissionOptions,
1746 cx: &mut Context<Self>,
1747 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1748 let (tx, rx) = oneshot::channel();
1749
1750 let status = ToolCallStatus::WaitingForConfirmation {
1751 options,
1752 respond_tx: tx,
1753 };
1754
1755 self.upsert_tool_call_inner(tool_call, status, cx)?;
1756 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1757
1758 let fut = async {
1759 match rx.await {
1760 Ok(option) => acp::RequestPermissionOutcome::Selected(
1761 acp::SelectedPermissionOutcome::new(option),
1762 ),
1763 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1764 }
1765 }
1766 .boxed();
1767
1768 Ok(fut)
1769 }
1770
1771 pub fn authorize_tool_call(
1772 &mut self,
1773 id: acp::ToolCallId,
1774 option_id: acp::PermissionOptionId,
1775 option_kind: acp::PermissionOptionKind,
1776 cx: &mut Context<Self>,
1777 ) {
1778 let Some((ix, call)) = self.tool_call_mut(&id) else {
1779 return;
1780 };
1781
1782 let new_status = match option_kind {
1783 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1784 ToolCallStatus::Rejected
1785 }
1786 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1787 ToolCallStatus::InProgress
1788 }
1789 _ => ToolCallStatus::InProgress,
1790 };
1791
1792 let curr_status = mem::replace(&mut call.status, new_status);
1793
1794 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1795 respond_tx.send(option_id).log_err();
1796 } else if cfg!(debug_assertions) {
1797 panic!("tried to authorize an already authorized tool call");
1798 }
1799
1800 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1801 }
1802
1803 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1804 let mut first_tool_call = None;
1805
1806 for entry in self.entries.iter().rev() {
1807 match &entry {
1808 AgentThreadEntry::ToolCall(call) => {
1809 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1810 first_tool_call = Some(call);
1811 } else {
1812 continue;
1813 }
1814 }
1815 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1816 // Reached the beginning of the turn.
1817 // If we had pending permission requests in the previous turn, they have been cancelled.
1818 break;
1819 }
1820 }
1821 }
1822
1823 first_tool_call
1824 }
1825
1826 pub fn plan(&self) -> &Plan {
1827 &self.plan
1828 }
1829
1830 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1831 let new_entries_len = request.entries.len();
1832 let mut new_entries = request.entries.into_iter();
1833
1834 // Reuse existing markdown to prevent flickering
1835 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1836 let PlanEntry {
1837 content,
1838 priority,
1839 status,
1840 } = old;
1841 content.update(cx, |old, cx| {
1842 old.replace(new.content, cx);
1843 });
1844 *priority = new.priority;
1845 *status = new.status;
1846 }
1847 for new in new_entries {
1848 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1849 }
1850 self.plan.entries.truncate(new_entries_len);
1851
1852 cx.notify();
1853 }
1854
1855 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1856 self.plan
1857 .entries
1858 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1859 cx.notify();
1860 }
1861
1862 #[cfg(any(test, feature = "test-support"))]
1863 pub fn send_raw(
1864 &mut self,
1865 message: &str,
1866 cx: &mut Context<Self>,
1867 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1868 self.send(vec![message.into()], cx)
1869 }
1870
1871 pub fn send(
1872 &mut self,
1873 message: Vec<acp::ContentBlock>,
1874 cx: &mut Context<Self>,
1875 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1876 let block = ContentBlock::new_combined(
1877 message.clone(),
1878 self.project.read(cx).languages().clone(),
1879 self.project.read(cx).path_style(cx),
1880 cx,
1881 );
1882 let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
1883 let git_store = self.project.read(cx).git_store().clone();
1884
1885 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1886 Some(UserMessageId::new())
1887 } else {
1888 None
1889 };
1890
1891 self.run_turn(cx, async move |this, cx| {
1892 this.update(cx, |this, cx| {
1893 this.push_entry(
1894 AgentThreadEntry::UserMessage(UserMessage {
1895 id: message_id.clone(),
1896 content: block,
1897 chunks: message,
1898 checkpoint: None,
1899 indented: false,
1900 }),
1901 cx,
1902 );
1903 })
1904 .ok();
1905
1906 let old_checkpoint = git_store
1907 .update(cx, |git, cx| git.checkpoint(cx))
1908 .await
1909 .context("failed to get old checkpoint")
1910 .log_err();
1911 this.update(cx, |this, cx| {
1912 if let Some((_ix, message)) = this.last_user_message() {
1913 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1914 git_checkpoint,
1915 show: false,
1916 });
1917 }
1918 this.connection.prompt(message_id, request, cx)
1919 })?
1920 .await
1921 })
1922 }
1923
1924 pub fn can_retry(&self, cx: &App) -> bool {
1925 self.connection.retry(&self.session_id, cx).is_some()
1926 }
1927
1928 pub fn retry(
1929 &mut self,
1930 cx: &mut Context<Self>,
1931 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1932 self.run_turn(cx, async move |this, cx| {
1933 this.update(cx, |this, cx| {
1934 this.connection
1935 .retry(&this.session_id, cx)
1936 .map(|retry| retry.run(cx))
1937 })?
1938 .context("retrying a session is not supported")?
1939 .await
1940 })
1941 }
1942
    /// Starts a new turn driven by `f` and returns a future with the turn's outcome.
    ///
    /// Before starting, completed plan entries are cleared, `had_error` is reset, and any
    /// running turn is cancelled. `f` only runs after that cancellation has finished. The new
    /// turn is stored in `running_turn` under a fresh `turn_id`.
    ///
    /// When `f` finishes, the last user message's checkpoint is re-evaluated and the agent
    /// location is cleared. The response is then interpreted:
    /// - `MaxTokens`: sets `had_error`, emits `Error`, and resolves to an error.
    /// - `Cancelled`: pending tool calls are marked as canceled.
    /// - `Refusal`: sets `had_error`. If no completed tool call with output follows the last
    ///   user message, that message and every later entry are removed. `Refusal` is emitted.
    /// - Every successful response (other than `MaxTokens`) then emits `Stopped` and
    ///   resolves to `Ok(Some(response))`.
    /// - An error from `f` sets `had_error`, emits `Error`, and is returned.
    ///
    /// Resolves to `Ok(None)` if the send task is dropped before it delivers a response.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        self.clear_completed_plan_entries(cx);
        self.had_error = false;

        // `tx` carries `f`'s result from the send task to the future returned below.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.turn_id += 1;
        let turn_id = self.turn_id;
        self.running_turn = Some(RunningTurn {
            id: turn_id,
            send_task: cx.spawn(async move |this, cx| {
                // Let the previous turn finish cancelling before this one starts.
                cancel_task.await;
                tx.send(f(this, cx).await).ok();
            }),
        });

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                let Ok(response) = response else {
                    // tx dropped, just return
                    return Ok(None);
                };

                let is_same_turn = this
                    .running_turn
                    .as_ref()
                    .is_some_and(|turn| turn_id == turn.id);

                // If the user submitted a follow up message, running_turn might
                // already point to a different turn. Therefore we only want to
                // take the task if it's the same turn.
                if is_same_turn {
                    this.running_turn.take();
                }

                match response {
                    Ok(r) => {
                        if r.stop_reason == acp::StopReason::MaxTokens {
                            this.had_error = true;
                            cx.emit(AcpThreadEvent::Error);
                            log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
                            return Err(anyhow!("Max tokens reached"));
                        }

                        let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
                        if canceled {
                            this.mark_pending_tools_as_canceled();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let acp::StopReason::Refusal = r.stop_reason {
                            this.had_error = true;
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(Some(r))
                    }
                    Err(e) => {
                        this.had_error = true;
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Error in run turn: {:?}", e);
                        Err(e)
                    }
                }
            })?
        })
        .boxed()
    }
2054
2055 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2056 let Some(turn) = self.running_turn.take() else {
2057 return Task::ready(());
2058 };
2059 self.connection.cancel(&self.session_id, cx);
2060
2061 self.mark_pending_tools_as_canceled();
2062
2063 // Wait for the send task to complete
2064 cx.background_spawn(turn.send_task)
2065 }
2066
2067 fn mark_pending_tools_as_canceled(&mut self) {
2068 for entry in self.entries.iter_mut() {
2069 if let AgentThreadEntry::ToolCall(call) = entry {
2070 let cancel = matches!(
2071 call.status,
2072 ToolCallStatus::Pending
2073 | ToolCallStatus::WaitingForConfirmation { .. }
2074 | ToolCallStatus::InProgress
2075 );
2076
2077 if cancel {
2078 call.status = ToolCallStatus::Canceled;
2079 }
2080 }
2081 }
2082 }
2083
2084 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2085 pub fn restore_checkpoint(
2086 &mut self,
2087 id: UserMessageId,
2088 cx: &mut Context<Self>,
2089 ) -> Task<Result<()>> {
2090 let Some((_, message)) = self.user_message_mut(&id) else {
2091 return Task::ready(Err(anyhow!("message not found")));
2092 };
2093
2094 let checkpoint = message
2095 .checkpoint
2096 .as_ref()
2097 .map(|c| c.git_checkpoint.clone());
2098
2099 // Cancel any in-progress generation before restoring
2100 let cancel_task = self.cancel(cx);
2101 let rewind = self.rewind(id.clone(), cx);
2102 let git_store = self.project.read(cx).git_store().clone();
2103
2104 cx.spawn(async move |_, cx| {
2105 cancel_task.await;
2106 rewind.await?;
2107 if let Some(checkpoint) = checkpoint {
2108 git_store
2109 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2110 .await?;
2111 }
2112
2113 Ok(())
2114 })
2115 }
2116
2117 /// Rewinds this thread to before the entry at `index`, removing it and all
2118 /// subsequent entries while rejecting any action_log changes made from that point.
2119 /// Unlike `restore_checkpoint`, this method does not restore from git.
2120 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2121 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2122 return Task::ready(Err(anyhow!("not supported")));
2123 };
2124
2125 let telemetry = ActionLogTelemetry::from(&*self);
2126 cx.spawn(async move |this, cx| {
2127 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2128 this.update(cx, |this, cx| {
2129 if let Some((ix, _)) = this.user_message_mut(&id) {
2130 // Collect all terminals from entries that will be removed
2131 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2132 .iter()
2133 .flat_map(|entry| entry.terminals())
2134 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2135 .collect();
2136
2137 let range = ix..this.entries.len();
2138 this.entries.truncate(ix);
2139 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2140
2141 // Kill and remove the terminals
2142 for terminal_id in terminals_to_remove {
2143 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2144 terminal.update(cx, |terminal, cx| {
2145 terminal.kill(cx);
2146 });
2147 }
2148 }
2149 }
2150 this.action_log().update(cx, |action_log, cx| {
2151 action_log.reject_all_edits(Some(telemetry), cx)
2152 })
2153 })?
2154 .await;
2155 Ok(())
2156 })
2157 }
2158
2159 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2160 let git_store = self.project.read(cx).git_store().clone();
2161
2162 let Some((_, message)) = self.last_user_message() else {
2163 return Task::ready(Ok(()));
2164 };
2165 let Some(user_message_id) = message.id.clone() else {
2166 return Task::ready(Ok(()));
2167 };
2168 let Some(checkpoint) = message.checkpoint.as_ref() else {
2169 return Task::ready(Ok(()));
2170 };
2171 let old_checkpoint = checkpoint.git_checkpoint.clone();
2172
2173 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2174 cx.spawn(async move |this, cx| {
2175 let Some(new_checkpoint) = new_checkpoint
2176 .await
2177 .context("failed to get new checkpoint")
2178 .log_err()
2179 else {
2180 return Ok(());
2181 };
2182
2183 let equal = git_store
2184 .update(cx, |git, cx| {
2185 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2186 })
2187 .await
2188 .unwrap_or(true);
2189
2190 this.update(cx, |this, cx| {
2191 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2192 if let Some(checkpoint) = message.checkpoint.as_mut() {
2193 checkpoint.show = !equal;
2194 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2195 }
2196 }
2197 })?;
2198
2199 Ok(())
2200 })
2201 }
2202
2203 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2204 self.entries
2205 .iter_mut()
2206 .enumerate()
2207 .rev()
2208 .find_map(|(ix, entry)| {
2209 if let AgentThreadEntry::UserMessage(message) = entry {
2210 Some((ix, message))
2211 } else {
2212 None
2213 }
2214 })
2215 }
2216
2217 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2218 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2219 if let AgentThreadEntry::UserMessage(message) = entry {
2220 if message.id.as_ref() == Some(id) {
2221 Some((ix, message))
2222 } else {
2223 None
2224 }
2225 } else {
2226 None
2227 }
2228 })
2229 }
2230
2231 pub fn read_text_file(
2232 &self,
2233 path: PathBuf,
2234 line: Option<u32>,
2235 limit: Option<u32>,
2236 reuse_shared_snapshot: bool,
2237 cx: &mut Context<Self>,
2238 ) -> Task<Result<String, acp::Error>> {
2239 // Args are 1-based, move to 0-based
2240 let line = line.unwrap_or_default().saturating_sub(1);
2241 let limit = limit.unwrap_or(u32::MAX);
2242 let project = self.project.clone();
2243 let action_log = self.action_log.clone();
2244 cx.spawn(async move |this, cx| {
2245 let load = project.update(cx, |project, cx| {
2246 let path = project
2247 .project_path_for_absolute_path(&path, cx)
2248 .ok_or_else(|| {
2249 acp::Error::resource_not_found(Some(path.display().to_string()))
2250 })?;
2251 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2252 })?;
2253
2254 let buffer = load.await?;
2255
2256 let snapshot = if reuse_shared_snapshot {
2257 this.read_with(cx, |this, _| {
2258 this.shared_buffers.get(&buffer.clone()).cloned()
2259 })
2260 .log_err()
2261 .flatten()
2262 } else {
2263 None
2264 };
2265
2266 let snapshot = if let Some(snapshot) = snapshot {
2267 snapshot
2268 } else {
2269 action_log.update(cx, |action_log, cx| {
2270 action_log.buffer_read(buffer.clone(), cx);
2271 });
2272
2273 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2274 this.update(cx, |this, _| {
2275 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2276 })?;
2277 snapshot
2278 };
2279
2280 let max_point = snapshot.max_point();
2281 let start_position = Point::new(line, 0);
2282
2283 if start_position > max_point {
2284 return Err(acp::Error::invalid_params().data(format!(
2285 "Attempting to read beyond the end of the file, line {}:{}",
2286 max_point.row + 1,
2287 max_point.column
2288 )));
2289 }
2290
2291 let start = snapshot.anchor_before(start_position);
2292 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2293
2294 project.update(cx, |project, cx| {
2295 project.set_agent_location(
2296 Some(AgentLocation {
2297 buffer: buffer.downgrade(),
2298 position: start,
2299 }),
2300 cx,
2301 );
2302 });
2303
2304 Ok(snapshot.text_for_range(start..end).collect::<String>())
2305 })
2306 }
2307
/// Replaces the contents of the file at `path` with `content` and saves it.
///
/// The new text is diffed against the snapshot stored in `shared_buffers` for
/// this buffer (populated by `read_text_file`). If none is stored, the diff is
/// taken against the buffer's current snapshot.
///
/// The resulting hunks are converted to anchor ranges in that snapshot and
/// applied as buffer edits rather than as a whole-file overwrite. This
/// presumably preserves edits made to the buffer after the snapshot was taken
/// (see `test_edits_concurrently_to_user`).
///
/// After the diff is applied:
/// - the agent location moves to the end of the last edit, or to the start of
///   the buffer when there are no edits;
/// - the read and the edit are reported to the action log;
/// - the buffer is formatted if its `format_on_save` setting is not `Off`
///   (formatting errors are logged and otherwise ignored);
/// - the buffer is saved.
///
/// # Errors
///
/// Fails with "invalid path" if `path` is not inside the project. Also fails if
/// opening the buffer, updating this thread, or saving the buffer fails.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load?.await?;
        // Prefer the snapshot recorded when the agent read the file, so the
        // diff describes the agent's change relative to the text it saw.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        // Diff on the background executor. Each hunk becomes an anchor range in
        // `snapshot` so it can be applied to the live buffer below.
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        // Point the agent at the end of the last edit (or the buffer start).
        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                }),
                cx,
            );
        });

        let format_on_save = cx.update(|cx| {
            // NOTE(review): the read is reported before editing, presumably so the
            // action log captures the pre-edit state; confirm against ActionLog.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        });

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            });
            // Formatting is best-effort: a failure is logged, and the save still happens.
            format_task.await.log_err();

            // Report the formatter's changes to the action log as well.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))
            .await
    })
}
2404
2405 pub fn create_terminal(
2406 &self,
2407 command: String,
2408 args: Vec<String>,
2409 extra_env: Vec<acp::EnvVariable>,
2410 cwd: Option<PathBuf>,
2411 output_byte_limit: Option<u64>,
2412 cx: &mut Context<Self>,
2413 ) -> Task<Result<Entity<Terminal>>> {
2414 let env = match &cwd {
2415 Some(dir) => self.project.update(cx, |project, cx| {
2416 project.environment().update(cx, |env, cx| {
2417 env.directory_environment(dir.as_path().into(), cx)
2418 })
2419 }),
2420 None => Task::ready(None).shared(),
2421 };
2422 let env = cx.spawn(async move |_, _| {
2423 let mut env = env.await.unwrap_or_default();
2424 // Disables paging for `git` and hopefully other commands
2425 env.insert("PAGER".into(), "".into());
2426 for var in extra_env {
2427 env.insert(var.name, var.value);
2428 }
2429 env
2430 });
2431
2432 let project = self.project.clone();
2433 let language_registry = project.read(cx).languages().clone();
2434 let is_windows = project.read(cx).path_style(cx).is_windows();
2435
2436 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2437 let terminal_task = cx.spawn({
2438 let terminal_id = terminal_id.clone();
2439 async move |_this, cx| {
2440 let env = env.await;
2441 let shell = project
2442 .update(cx, |project, cx| {
2443 project
2444 .remote_client()
2445 .and_then(|r| r.read(cx).default_system_shell())
2446 })
2447 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2448 let (task_command, task_args) =
2449 ShellBuilder::new(&Shell::Program(shell), is_windows)
2450 .redirect_stdin_to_dev_null()
2451 .build(Some(command.clone()), &args);
2452 let terminal = project
2453 .update(cx, |project, cx| {
2454 project.create_terminal_task(
2455 task::SpawnInTerminal {
2456 command: Some(task_command),
2457 args: task_args,
2458 cwd: cwd.clone(),
2459 env,
2460 ..Default::default()
2461 },
2462 cx,
2463 )
2464 })
2465 .await?;
2466
2467 anyhow::Ok(cx.new(|cx| {
2468 Terminal::new(
2469 terminal_id,
2470 &format!("{} {}", command, args.join(" ")),
2471 cwd,
2472 output_byte_limit.map(|l| l as usize),
2473 terminal,
2474 language_registry,
2475 cx,
2476 )
2477 }))
2478 }
2479 });
2480
2481 cx.spawn(async move |this, cx| {
2482 let terminal = terminal_task.await?;
2483 this.update(cx, |this, _cx| {
2484 this.terminals.insert(terminal_id, terminal.clone());
2485 terminal
2486 })
2487 })
2488 }
2489
2490 pub fn kill_terminal(
2491 &mut self,
2492 terminal_id: acp::TerminalId,
2493 cx: &mut Context<Self>,
2494 ) -> Result<()> {
2495 self.terminals
2496 .get(&terminal_id)
2497 .context("Terminal not found")?
2498 .update(cx, |terminal, cx| {
2499 terminal.kill(cx);
2500 });
2501
2502 Ok(())
2503 }
2504
2505 pub fn release_terminal(
2506 &mut self,
2507 terminal_id: acp::TerminalId,
2508 cx: &mut Context<Self>,
2509 ) -> Result<()> {
2510 self.terminals
2511 .remove(&terminal_id)
2512 .context("Terminal not found")?
2513 .update(cx, |terminal, cx| {
2514 terminal.kill(cx);
2515 });
2516
2517 Ok(())
2518 }
2519
2520 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2521 self.terminals
2522 .get(&terminal_id)
2523 .context("Terminal not found")
2524 .cloned()
2525 }
2526
2527 pub fn to_markdown(&self, cx: &App) -> String {
2528 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2529 }
2530
2531 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2532 cx.emit(AcpThreadEvent::LoadError(error));
2533 }
2534
2535 pub fn register_terminal_created(
2536 &mut self,
2537 terminal_id: acp::TerminalId,
2538 command_label: String,
2539 working_dir: Option<PathBuf>,
2540 output_byte_limit: Option<u64>,
2541 terminal: Entity<::terminal::Terminal>,
2542 cx: &mut Context<Self>,
2543 ) -> Entity<Terminal> {
2544 let language_registry = self.project.read(cx).languages().clone();
2545
2546 let entity = cx.new(|cx| {
2547 Terminal::new(
2548 terminal_id.clone(),
2549 &command_label,
2550 working_dir.clone(),
2551 output_byte_limit.map(|l| l as usize),
2552 terminal,
2553 language_registry,
2554 cx,
2555 )
2556 });
2557 self.terminals.insert(terminal_id.clone(), entity.clone());
2558 entity
2559 }
2560}
2561
2562fn markdown_for_raw_output(
2563 raw_output: &serde_json::Value,
2564 language_registry: &Arc<LanguageRegistry>,
2565 cx: &mut App,
2566) -> Option<Entity<Markdown>> {
2567 match raw_output {
2568 serde_json::Value::Null => None,
2569 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2570 Markdown::new(
2571 value.to_string().into(),
2572 Some(language_registry.clone()),
2573 None,
2574 cx,
2575 )
2576 })),
2577 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2578 Markdown::new(
2579 value.to_string().into(),
2580 Some(language_registry.clone()),
2581 None,
2582 cx,
2583 )
2584 })),
2585 serde_json::Value::String(value) => Some(cx.new(|cx| {
2586 Markdown::new(
2587 value.clone().into(),
2588 Some(language_registry.clone()),
2589 None,
2590 cx,
2591 )
2592 })),
2593 value => Some(cx.new(|cx| {
2594 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2595
2596 Markdown::new(
2597 format!("```json\n{}\n```", pretty_json).into(),
2598 Some(language_registry.clone()),
2599 None,
2600 cx,
2601 )
2602 })),
2603 }
2604}
2605
2606#[cfg(test)]
2607mod tests {
2608 use super::*;
2609 use anyhow::anyhow;
2610 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2611 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2612 use indoc::indoc;
2613 use project::{FakeFs, Fs};
2614 use rand::{distr, prelude::*};
2615 use serde_json::json;
2616 use settings::SettingsStore;
2617 use smol::stream::StreamExt as _;
2618 use std::{
2619 any::Any,
2620 cell::RefCell,
2621 path::Path,
2622 rc::Rc,
2623 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2624 time::Duration,
2625 };
2626 use util::path;
2627
2628 fn init_test(cx: &mut TestAppContext) {
2629 env_logger::try_init().ok();
2630 cx.update(|cx| {
2631 let settings_store = SettingsStore::test(cx);
2632 cx.set_global(settings_store);
2633 });
2634 }
2635
/// Terminal `Output` events that arrive before the `Created` event for the
/// same terminal id must be held by the thread and written into the terminal
/// once `Created` is handled.
#[gpui::test]
async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // One id ties the Output and Created events to the same terminal.
    let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

    // Send Output BEFORE Created - should be buffered by acp_thread
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id.clone(),
                data: b"hello buffered".to_vec(),
            },
            cx,
        );
    });

    // Create a display-only terminal and then send Created
    let lower = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
            cx.background_executor(),
            PathStyle::local(),
        )
        .unwrap();
        builder.subscribe(cx)
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id.clone(),
                label: "Buffered Test".to_string(),
                cwd: None,
                output_byte_limit: None,
                terminal: lower.clone(),
            },
            cx,
        );
    });

    // After Created, buffered Output should have been flushed into the renderer
    let content = thread.read_with(cx, |thread, cx| {
        let term = thread.terminal(terminal_id.clone()).unwrap();
        term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
    });

    assert!(
        content.contains("hello buffered"),
        "expected buffered output to render, got: {content}"
    );
}
2699
/// Both `Output` and `Exit` may arrive before `Created`. The buffered output
/// must still be written into the terminal once it is created.
///
/// NOTE(review): only the output is asserted here; this test does not check
/// that the buffered exit status was applied.
#[gpui::test]
async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
        .await
        .unwrap();

    let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

    // Send Output BEFORE Created
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id.clone(),
                data: b"pre-exit data".to_vec(),
            },
            cx,
        );
    });

    // Send Exit BEFORE Created
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Exit {
                terminal_id: terminal_id.clone(),
                status: acp::TerminalExitStatus::new().exit_code(0),
            },
            cx,
        );
    });

    // Now create a display-only lower-level terminal and send Created
    let lower = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
            cx.background_executor(),
            PathStyle::local(),
        )
        .unwrap();
        builder.subscribe(cx)
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id.clone(),
                label: "Buffered Exit Test".to_string(),
                cwd: None,
                output_byte_limit: None,
                terminal: lower.clone(),
            },
            cx,
        );
    });

    // Output should be present after Created (flushed from buffer)
    let content = thread.read_with(cx, |thread, cx| {
        let term = thread.terminal(terminal_id.clone()).unwrap();
        term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
    });

    assert!(
        content.contains("pre-exit data"),
        "expected pre-exit data to render, got: {content}"
    );
}
2774
/// Test that killing a terminal via Terminal::kill properly:
/// 1. Causes wait_for_exit to complete (doesn't hang forever)
/// 2. The underlying terminal still has the output that was written before the kill
///
/// This test verifies that the fix to kill_active_task (which now also kills
/// the shell process in addition to the foreground process) properly allows
/// wait_for_exit to complete instead of hanging indefinitely.
///
/// Unix-only, presumably because the command relies on `printf`/`sleep` run
/// through the system shell in a real PTY.
#[cfg(unix)]
#[gpui::test]
async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
    use std::collections::HashMap;
    use task::Shell;
    use util::shell_builder::ShellBuilder;

    init_test(cx);
    // This test waits on a real PTY process and real-time timers, so the test
    // executor must be allowed to park.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
        .await
        .unwrap();

    let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

    // Create a real PTY terminal that runs a command which prints output then sleeps
    // We use printf instead of echo and chain with && sleep to ensure proper execution
    // `_completion_rx` is a named binding, so the receiver stays alive until the
    // end of the test; completion is observed through `wait_for_exit` instead.
    let (completion_tx, _completion_rx) = smol::channel::unbounded();
    let (program, args) = ShellBuilder::new(&Shell::System, false).build(
        Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
        &[],
    );

    let builder = cx
        .update(|cx| {
            ::terminal::TerminalBuilder::new(
                None,
                None,
                task::Shell::WithArguments {
                    program,
                    args,
                    title_override: None,
                },
                HashMap::default(),
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                vec![],
                0,
                false,
                0,
                Some(completion_tx),
                cx,
                vec![],
                PathStyle::local(),
            )
        })
        .await
        .unwrap();

    let lower_terminal = cx.new(|cx| builder.subscribe(cx));

    // Create the acp_thread Terminal wrapper
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id.clone(),
                label: "printf output_before_kill && sleep 60".to_string(),
                cwd: None,
                output_byte_limit: None,
                terminal: lower_terminal.clone(),
            },
            cx,
        );
    });

    // Wait for the printf command to execute and produce output
    // Use real time since parking is enabled
    cx.executor().timer(Duration::from_millis(500)).await;

    // Get the acp_thread Terminal and kill it
    // `wait_for_exit` is obtained before calling `kill`, so the future is
    // already held when the kill happens.
    let wait_for_exit = thread.update(cx, |thread, cx| {
        let term = thread.terminals.get(&terminal_id).unwrap();
        let wait_for_exit = term.read(cx).wait_for_exit();
        term.update(cx, |term, cx| {
            term.kill(cx);
        });
        wait_for_exit
    });

    // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
    // Before the fix to kill_active_task, this would hang forever because
    // only the foreground process was killed, not the shell, so the PTY
    // child never exited and wait_for_completed_task never completed.
    // Race the exit future against a 5s timer; `None` means the timer won.
    let exit_result = futures::select! {
        result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
        _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
    };

    assert!(
        exit_result.is_some(),
        "wait_for_exit should complete after kill, but it timed out. \
        This indicates kill_active_task is not properly killing the shell process."
    );

    // Give the system a chance to process any pending updates
    cx.run_until_parked();

    // Verify that the underlying terminal still has the output that was
    // written before the kill. This verifies that killing doesn't lose output.
    let inner_content = thread.read_with(cx, |thread, cx| {
        let term = thread.terminals.get(&terminal_id).unwrap();
        term.read(cx).inner().read(cx).get_content()
    });

    assert!(
        inner_content.contains("output_before_kill"),
        "Underlying terminal should contain output from before kill, got: {}",
        inner_content
    );
}
2898
/// Checks how `push_user_content_block` groups content:
/// - it appends to the trailing user message when one exists;
/// - it starts a new user message once another kind of entry (here, an
///   assistant message) has been pushed after it.
#[gpui::test]
async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // Test creating a new user message
    thread.update(cx, |thread, cx| {
        thread.push_user_content_block(None, "Hello, ".into(), cx);
    });

    thread.update(cx, |thread, cx| {
        assert_eq!(thread.entries.len(), 1);
        if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
            assert_eq!(user_msg.id, None);
            assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
        } else {
            panic!("Expected UserMessage");
        }
    });

    // Test appending to existing user message
    // The existing message had no id; the id passed here is expected to be adopted.
    let message_1_id = UserMessageId::new();
    thread.update(cx, |thread, cx| {
        thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
    });

    thread.update(cx, |thread, cx| {
        assert_eq!(thread.entries.len(), 1);
        if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
            assert_eq!(user_msg.id, Some(message_1_id));
            assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
        } else {
            panic!("Expected UserMessage");
        }
    });

    // Test creating new user message after assistant message
    thread.update(cx, |thread, cx| {
        thread.push_assistant_content_block("Assistant response".into(), false, cx);
    });

    let message_2_id = UserMessageId::new();
    thread.update(cx, |thread, cx| {
        thread.push_user_content_block(
            Some(message_2_id.clone()),
            "New user message".into(),
            cx,
        );
    });

    // Entries: [user, assistant, user]
    thread.update(cx, |thread, cx| {
        assert_eq!(thread.entries.len(), 3);
        if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
            assert_eq!(user_msg.id, Some(message_2_id));
            assert_eq!(user_msg.content.to_markdown(cx), "New user message");
        } else {
            panic!("Expected UserMessage at index 2");
        }
    });
}
2966
/// Consecutive `AgentThoughtChunk` updates within a turn are merged into a
/// single `<thinking>` block in the thread's Markdown rendering.
#[gpui::test]
async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    // The fake agent answers every prompt with two thought chunks, then ends the turn.
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        |_, thread, mut cx| {
            async move {
                thread.update(&mut cx, |thread, cx| {
                    thread
                        .handle_session_update(
                            acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                "Thinking ".into(),
                            )),
                            cx,
                        )
                        .unwrap();
                    thread
                        .handle_session_update(
                            acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                "hard!".into(),
                            )),
                            cx,
                        )
                        .unwrap();
                })?;
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            }
            .boxed_local()
        },
    ));

    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    thread
        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
        .await
        .unwrap();

    let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert_eq!(
        output,
        indoc! {r#"
        ## User

        Hello from Zed!

        ## Assistant

        <thinking>
        Thinking hard!
        </thinking>

        "#}
    );
}
3027
/// The user edits the buffer after the agent has read the file but before it
/// writes it. The agent's write must be merged with the user's edit (not
/// overwrite it), and the merged text must be saved to disk.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    // Signals the test body once the agent has read the file, so the user edit
    // below lands between the agent's read and its write.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    // Wrapped in Rc<RefCell<Option<_>>> so the handler closure can move the
    // one-shot sender out (via `take`) while itself remaining callable.
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        move |_, thread, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                let content = thread
                    .update(&mut cx, |thread, cx| {
                        thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                    })
                    .unwrap()
                    .await
                    .unwrap();
                assert_eq!(content, "one\ntwo\nthree\n");
                read_file_tx.take().unwrap().send(()).unwrap();
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.write_text_file(
                            path!("/tmp/foo").into(),
                            "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            cx,
                        )
                    })
                    .unwrap()
                    .await
                    .unwrap();
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            }
            .boxed_local()
        },
    ));

    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    read_file_rx.await.ok();
    // User edit made after the agent's read and before its write.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    cx.run_until_parked();
    // Both the user's "zero" and the agent's "four"/"five" must survive.
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
3104
3105 #[gpui::test]
3106 async fn test_reading_from_line(cx: &mut TestAppContext) {
3107 init_test(cx);
3108
3109 let fs = FakeFs::new(cx.executor());
3110 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
3111 .await;
3112 let project = Project::test(fs.clone(), [], cx).await;
3113 project
3114 .update(cx, |project, cx| {
3115 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3116 })
3117 .await
3118 .unwrap();
3119
3120 let connection = Rc::new(FakeAgentConnection::new());
3121
3122 let thread = cx
3123 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3124 .await
3125 .unwrap();
3126
3127 // Whole file
3128 let content = thread
3129 .update(cx, |thread, cx| {
3130 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3131 })
3132 .await
3133 .unwrap();
3134
3135 assert_eq!(content, "one\ntwo\nthree\nfour\n");
3136
3137 // Only start line
3138 let content = thread
3139 .update(cx, |thread, cx| {
3140 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
3141 })
3142 .await
3143 .unwrap();
3144
3145 assert_eq!(content, "three\nfour\n");
3146
3147 // Only limit
3148 let content = thread
3149 .update(cx, |thread, cx| {
3150 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3151 })
3152 .await
3153 .unwrap();
3154
3155 assert_eq!(content, "one\ntwo\n");
3156
3157 // Range
3158 let content = thread
3159 .update(cx, |thread, cx| {
3160 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
3161 })
3162 .await
3163 .unwrap();
3164
3165 assert_eq!(content, "two\nthree\n");
3166
3167 // Invalid
3168 let err = thread
3169 .update(cx, |thread, cx| {
3170 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
3171 })
3172 .await
3173 .unwrap_err();
3174
3175 assert_eq!(
3176 err.to_string(),
3177 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
3178 );
3179 }
3180
3181 #[gpui::test]
3182 async fn test_reading_empty_file(cx: &mut TestAppContext) {
3183 init_test(cx);
3184
3185 let fs = FakeFs::new(cx.executor());
3186 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
3187 let project = Project::test(fs.clone(), [], cx).await;
3188 project
3189 .update(cx, |project, cx| {
3190 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3191 })
3192 .await
3193 .unwrap();
3194
3195 let connection = Rc::new(FakeAgentConnection::new());
3196
3197 let thread = cx
3198 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3199 .await
3200 .unwrap();
3201
3202 // Whole file
3203 let content = thread
3204 .update(cx, |thread, cx| {
3205 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3206 })
3207 .await
3208 .unwrap();
3209
3210 assert_eq!(content, "");
3211
3212 // Only start line
3213 let content = thread
3214 .update(cx, |thread, cx| {
3215 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
3216 })
3217 .await
3218 .unwrap();
3219
3220 assert_eq!(content, "");
3221
3222 // Only limit
3223 let content = thread
3224 .update(cx, |thread, cx| {
3225 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3226 })
3227 .await
3228 .unwrap();
3229
3230 assert_eq!(content, "");
3231
3232 // Range
3233 let content = thread
3234 .update(cx, |thread, cx| {
3235 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3236 })
3237 .await
3238 .unwrap();
3239
3240 assert_eq!(content, "");
3241
3242 // Invalid
3243 let err = thread
3244 .update(cx, |thread, cx| {
3245 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3246 })
3247 .await
3248 .unwrap_err();
3249
3250 assert_eq!(
3251 err.to_string(),
3252 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3253 );
3254 }
/// Reading a path that lies outside every worktree of the project fails with
/// `ErrorCode::ResourceNotFound`.
#[gpui::test]
async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({})).await;
    let project = Project::test(fs.clone(), [], cx).await;
    project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp"), true, cx)
        })
        .await
        .unwrap();

    let connection = Rc::new(FakeAgentConnection::new());

    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
        .await
        .unwrap();

    // Out of project file
    // `/foo` is not under the `/tmp` worktree, so it cannot be mapped to a project path.
    let err = thread
        .update(cx, |thread, cx| {
            thread.read_text_file(path!("/foo").into(), None, None, false, cx)
        })
        .await
        .unwrap_err();

    assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
}
3286
/// A tool call that was marked `Canceled` by `cancel` can still be moved to
/// `Completed` by a later `ToolCallUpdate` from the agent.
#[gpui::test]
async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let id = acp::ToolCallId::new("test");

    // The fake agent starts an in-progress Fetch tool call and ends its turn.
    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        let id = id.clone();
        move |_, thread, mut cx| {
            let id = id.clone();
            async move {
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.handle_session_update(
                            acp::SessionUpdate::ToolCall(
                                acp::ToolCall::new(id.clone(), "Label")
                                    .kind(acp::ToolKind::Fetch)
                                    .status(acp::ToolCallStatus::InProgress),
                            ),
                            cx,
                        )
                    })
                    .unwrap()
                    .unwrap();
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            }
            .boxed_local()
        }
    }));

    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Fetch https://example.com", cx)
    });

    run_until_first_tool_call(&thread, cx).await;

    // Entry 0 is the user message; entry 1 is the tool call.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::InProgress,
                ..
            })
        ));
    });

    thread.update(cx, |thread, cx| thread.cancel(cx)).await;

    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            &thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            })
        ));
    });

    // A late completion update from the agent arrives after the cancel.
    thread
        .update(cx, |thread, cx| {
            thread.handle_session_update(
                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                    id,
                    acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                )),
                cx,
            )
        })
        .unwrap();

    request.await.unwrap();

    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Completed,
                ..
            })
        ));
    });
}
3376
/// An Edit tool call that is already `Completed` when reported must not be
/// counted by `has_pending_edit_tool_calls` once the turn finishes.
#[gpui::test]
async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

    // The fake agent reports a single, already-completed Edit tool call carrying a diff.
    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        move |_, thread, mut cx| {
            async move {
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.handle_session_update(
                            acp::SessionUpdate::ToolCall(
                                acp::ToolCall::new("test", "Label")
                                    .kind(acp::ToolKind::Edit)
                                    .status(acp::ToolCallStatus::Completed)
                                    .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                        "/test/test.txt",
                                        "foo",
                                    ))]),
                            ),
                            cx,
                        )
                    })
                    .unwrap()
                    .unwrap();
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            }
            .boxed_local()
        }
    }));

    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
        .await
        .unwrap();

    assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
}
3421
    /// Exercises per-message checkpoints across several turns.
    ///
    /// While `simulate_changes` is true, the fake agent writes a new empty file
    /// `/test/file-N` on each turn. It then echoes the prompt back upper-cased.
    /// The test checks that:
    /// - a user message renders as "## User (checkpoint)" when its turn changed
    ///   the file system;
    /// - a turn with no file changes leaves its user message without a checkpoint
    ///   (plain "## User");
    /// - restoring an earlier message's checkpoint truncates the thread at that
    ///   message and restores the files that existed before it.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // An (empty) `.git` directory makes `/test` a repository; presumably
        // required for the git-store based checkpoints — TODO confirm.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // Toggled off later to simulate a turn that makes no file changes.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        // Suffix of the next `/test/file-N` the fake agent will create.
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Mutate the worktree so that this turn has changes to checkpoint.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the prompt text upper-cased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Turn 1 writes file-0, so its user message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Turn 2 writes file-1, so it is checkpointed as well.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message. Restoring it drops that message
        // and everything after it, and removes file-1, which its turn created.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3599
3600 #[gpui::test]
3601 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3602 use std::sync::atomic::AtomicUsize;
3603 init_test(cx);
3604
3605 let fs = FakeFs::new(cx.executor());
3606 let project = Project::test(fs, None, cx).await;
3607
3608 // Create a connection that simulates refusal after tool result
3609 let prompt_count = Arc::new(AtomicUsize::new(0));
3610 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3611 let prompt_count = prompt_count.clone();
3612 move |_request, thread, mut cx| {
3613 let count = prompt_count.fetch_add(1, SeqCst);
3614 async move {
3615 if count == 0 {
3616 // First prompt: Generate a tool call with result
3617 thread.update(&mut cx, |thread, cx| {
3618 thread
3619 .handle_session_update(
3620 acp::SessionUpdate::ToolCall(
3621 acp::ToolCall::new("tool1", "Test Tool")
3622 .kind(acp::ToolKind::Fetch)
3623 .status(acp::ToolCallStatus::Completed)
3624 .raw_input(serde_json::json!({"query": "test"}))
3625 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3626 ),
3627 cx,
3628 )
3629 .unwrap();
3630 })?;
3631
3632 // Now return refusal because of the tool result
3633 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3634 } else {
3635 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3636 }
3637 }
3638 .boxed_local()
3639 }
3640 }));
3641
3642 let thread = cx
3643 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3644 .await
3645 .unwrap();
3646
3647 // Track if we see a Refusal event
3648 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3649 let saw_refusal_event_captured = saw_refusal_event.clone();
3650 thread.update(cx, |_thread, cx| {
3651 cx.subscribe(
3652 &thread,
3653 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3654 if matches!(event, AcpThreadEvent::Refusal) {
3655 *saw_refusal_event_captured.lock().unwrap() = true;
3656 }
3657 },
3658 )
3659 .detach();
3660 });
3661
3662 // Send a user message - this will trigger tool call and then refusal
3663 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3664 cx.background_executor.spawn(send_task).detach();
3665 cx.run_until_parked();
3666
3667 // Verify that:
3668 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3669 // 2. The user message was NOT truncated
3670 assert!(
3671 *saw_refusal_event.lock().unwrap(),
3672 "Refusal event should be emitted for tool result refusals"
3673 );
3674
3675 thread.read_with(cx, |thread, _| {
3676 let entries = thread.entries();
3677 assert!(entries.len() >= 2, "Should have user message and tool call");
3678
3679 // Verify user message is still there
3680 assert!(
3681 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3682 "User message should not be truncated"
3683 );
3684
3685 // Verify tool call is there with result
3686 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3687 assert!(
3688 tool_call.raw_output.is_some(),
3689 "Tool call should have output"
3690 );
3691 } else {
3692 panic!("Expected tool call at index 1");
3693 }
3694 });
3695 }
3696
3697 #[gpui::test]
3698 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3699 init_test(cx);
3700
3701 let fs = FakeFs::new(cx.executor());
3702 let project = Project::test(fs, None, cx).await;
3703
3704 let refuse_next = Arc::new(AtomicBool::new(false));
3705 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3706 let refuse_next = refuse_next.clone();
3707 move |_request, _thread, _cx| {
3708 if refuse_next.load(SeqCst) {
3709 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3710 .boxed_local()
3711 } else {
3712 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3713 .boxed_local()
3714 }
3715 }
3716 }));
3717
3718 let thread = cx
3719 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3720 .await
3721 .unwrap();
3722
3723 // Track if we see a Refusal event
3724 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3725 let saw_refusal_event_captured = saw_refusal_event.clone();
3726 thread.update(cx, |_thread, cx| {
3727 cx.subscribe(
3728 &thread,
3729 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3730 if matches!(event, AcpThreadEvent::Refusal) {
3731 *saw_refusal_event_captured.lock().unwrap() = true;
3732 }
3733 },
3734 )
3735 .detach();
3736 });
3737
3738 // Send a message that will be refused
3739 refuse_next.store(true, SeqCst);
3740 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3741 .await
3742 .unwrap();
3743
3744 // Verify that a Refusal event WAS emitted for user prompt refusal
3745 assert!(
3746 *saw_refusal_event.lock().unwrap(),
3747 "Refusal event should be emitted for user prompt refusals"
3748 );
3749
3750 // Verify the message was truncated (user prompt refusal)
3751 thread.read_with(cx, |thread, cx| {
3752 assert_eq!(thread.to_markdown(cx), "");
3753 });
3754 }
3755
3756 #[gpui::test]
3757 async fn test_refusal(cx: &mut TestAppContext) {
3758 init_test(cx);
3759 let fs = FakeFs::new(cx.background_executor.clone());
3760 fs.insert_tree(path!("/"), json!({})).await;
3761 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3762
3763 let refuse_next = Arc::new(AtomicBool::new(false));
3764 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3765 let refuse_next = refuse_next.clone();
3766 move |request, thread, mut cx| {
3767 let refuse_next = refuse_next.clone();
3768 async move {
3769 if refuse_next.load(SeqCst) {
3770 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3771 }
3772
3773 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3774 panic!("expected text content block");
3775 };
3776 thread.update(&mut cx, |thread, cx| {
3777 thread
3778 .handle_session_update(
3779 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3780 content.text.to_uppercase().into(),
3781 )),
3782 cx,
3783 )
3784 .unwrap();
3785 })?;
3786 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3787 }
3788 .boxed_local()
3789 }
3790 }));
3791 let thread = cx
3792 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3793 .await
3794 .unwrap();
3795
3796 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3797 .await
3798 .unwrap();
3799 thread.read_with(cx, |thread, cx| {
3800 assert_eq!(
3801 thread.to_markdown(cx),
3802 indoc! {"
3803 ## User
3804
3805 hello
3806
3807 ## Assistant
3808
3809 HELLO
3810
3811 "}
3812 );
3813 });
3814
3815 // Simulate refusing the second message. The message should be truncated
3816 // when a user prompt is refused.
3817 refuse_next.store(true, SeqCst);
3818 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3819 .await
3820 .unwrap();
3821 thread.read_with(cx, |thread, cx| {
3822 assert_eq!(
3823 thread.to_markdown(cx),
3824 indoc! {"
3825 ## User
3826
3827 hello
3828
3829 ## Assistant
3830
3831 HELLO
3832
3833 "}
3834 );
3835 });
3836 }
3837
3838 async fn run_until_first_tool_call(
3839 thread: &Entity<AcpThread>,
3840 cx: &mut TestAppContext,
3841 ) -> usize {
3842 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3843
3844 let subscription = cx.update(|cx| {
3845 cx.subscribe(thread, move |thread, _, cx| {
3846 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3847 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3848 return tx.try_send(ix).unwrap();
3849 }
3850 }
3851 })
3852 });
3853
3854 select! {
3855 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3856 panic!("Timeout waiting for tool call")
3857 }
3858 ix = rx.next().fuse() => {
3859 drop(subscription);
3860 ix.unwrap()
3861 }
3862 }
3863 }
3864
3865 #[derive(Clone, Default)]
3866 struct FakeAgentConnection {
3867 auth_methods: Vec<acp::AuthMethod>,
3868 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3869 on_user_message: Option<
3870 Rc<
3871 dyn Fn(
3872 acp::PromptRequest,
3873 WeakEntity<AcpThread>,
3874 AsyncApp,
3875 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3876 + 'static,
3877 >,
3878 >,
3879 }
3880
3881 impl FakeAgentConnection {
3882 fn new() -> Self {
3883 Self {
3884 auth_methods: Vec::new(),
3885 on_user_message: None,
3886 sessions: Arc::default(),
3887 }
3888 }
3889
3890 #[expect(unused)]
3891 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3892 self.auth_methods = auth_methods;
3893 self
3894 }
3895
3896 fn on_user_message(
3897 mut self,
3898 handler: impl Fn(
3899 acp::PromptRequest,
3900 WeakEntity<AcpThread>,
3901 AsyncApp,
3902 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3903 + 'static,
3904 ) -> Self {
3905 self.on_user_message.replace(Rc::new(handler));
3906 self
3907 }
3908 }
3909
3910 impl AgentConnection for FakeAgentConnection {
3911 fn telemetry_id(&self) -> SharedString {
3912 "fake".into()
3913 }
3914
3915 fn auth_methods(&self) -> &[acp::AuthMethod] {
3916 &self.auth_methods
3917 }
3918
3919 fn new_session(
3920 self: Rc<Self>,
3921 project: Entity<Project>,
3922 _cwd: &Path,
3923 cx: &mut App,
3924 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3925 let session_id = acp::SessionId::new(
3926 rand::rng()
3927 .sample_iter(&distr::Alphanumeric)
3928 .take(7)
3929 .map(char::from)
3930 .collect::<String>(),
3931 );
3932 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3933 let thread = cx.new(|cx| {
3934 AcpThread::new(
3935 None,
3936 "Test",
3937 self.clone(),
3938 project,
3939 action_log,
3940 session_id.clone(),
3941 watch::Receiver::constant(
3942 acp::PromptCapabilities::new()
3943 .image(true)
3944 .audio(true)
3945 .embedded_context(true),
3946 ),
3947 cx,
3948 )
3949 });
3950 self.sessions.lock().insert(session_id, thread.downgrade());
3951 Task::ready(Ok(thread))
3952 }
3953
3954 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3955 if self.auth_methods().iter().any(|m| m.id == method) {
3956 Task::ready(Ok(()))
3957 } else {
3958 Task::ready(Err(anyhow!("Invalid Auth Method")))
3959 }
3960 }
3961
3962 fn prompt(
3963 &self,
3964 _id: Option<UserMessageId>,
3965 params: acp::PromptRequest,
3966 cx: &mut App,
3967 ) -> Task<gpui::Result<acp::PromptResponse>> {
3968 let sessions = self.sessions.lock();
3969 let thread = sessions.get(¶ms.session_id).unwrap();
3970 if let Some(handler) = &self.on_user_message {
3971 let handler = handler.clone();
3972 let thread = thread.clone();
3973 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3974 } else {
3975 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3976 }
3977 }
3978
3979 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
3980
3981 fn truncate(
3982 &self,
3983 session_id: &acp::SessionId,
3984 _cx: &App,
3985 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3986 Some(Rc::new(FakeAgentSessionEditor {
3987 _session_id: session_id.clone(),
3988 }))
3989 }
3990
3991 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3992 self
3993 }
3994 }
3995
3996 struct FakeAgentSessionEditor {
3997 _session_id: acp::SessionId,
3998 }
3999
4000 impl AgentSessionTruncate for FakeAgentSessionEditor {
4001 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
4002 Task::ready(Ok(()))
4003 }
4004 }
4005
4006 #[gpui::test]
4007 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
4008 init_test(cx);
4009
4010 let fs = FakeFs::new(cx.executor());
4011 let project = Project::test(fs, [], cx).await;
4012 let connection = Rc::new(FakeAgentConnection::new());
4013 let thread = cx
4014 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4015 .await
4016 .unwrap();
4017
4018 // Try to update a tool call that doesn't exist
4019 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
4020 thread.update(cx, |thread, cx| {
4021 let result = thread.handle_session_update(
4022 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4023 nonexistent_id.clone(),
4024 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4025 )),
4026 cx,
4027 );
4028
4029 // The update should succeed (not return an error)
4030 assert!(result.is_ok());
4031
4032 // There should now be exactly one entry in the thread
4033 assert_eq!(thread.entries.len(), 1);
4034
4035 // The entry should be a failed tool call
4036 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4037 assert_eq!(tool_call.id, nonexistent_id);
4038 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4039 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4040
4041 // Check that the content contains the error message
4042 assert_eq!(tool_call.content.len(), 1);
4043 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4044 match content_block {
4045 ContentBlock::Markdown { markdown } => {
4046 let markdown_text = markdown.read(cx).source();
4047 assert!(markdown_text.contains("Tool call not found"));
4048 }
4049 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4050 ContentBlock::ResourceLink { .. } => {
4051 panic!("Expected markdown content, got resource link")
4052 }
4053 ContentBlock::Image { .. } => {
4054 panic!("Expected markdown content, got image")
4055 }
4056 }
4057 } else {
4058 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4059 }
4060 } else {
4061 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4062 }
4063 });
4064 }
4065
    /// Tests that restoring a checkpoint properly cleans up terminals that were
    /// created after that checkpoint, and cancels any in-progress generation.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// Setup:
    /// - two user messages (indices 0 and 1);
    /// - two terminals that have exited and are not referenced by any entry;
    /// - a still-running terminal referenced by a tool call entry pushed after
    ///   message 1.
    ///
    /// After restoring message 1's checkpoint, only that running terminal must
    /// be gone.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // No `on_user_message` handler: every prompt ends the turn immediately.
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that have completed running.
        // NOTE(review): despite the original wording ("BEFORE the checkpoint"),
        // these are registered after both messages were sent, i.e. after the
        // checkpoint restored below. What sets them apart from the terminal
        // created later is that they exit and no tool call entry references them.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        // Terminal 1 exits successfully, so it is no longer running.
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        // Terminal 2 exits successfully as well.
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a terminal AFTER the checkpoint we'll restore to.
        // This simulates the AI agent starting a long-running terminal command.
        // Unlike terminals 1 and 2, it never exits, and the tool call entry
        // below references it.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry that references this terminal
        // This represents the agent requesting a terminal command
        // (it becomes entry index 2, after the message we restore to).
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (via the cancel() call)
        // 2. Remove the terminal that was created after that point
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (cancel() clears the send_task)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed, unreferenced terminals still exist.
        // NOTE(review): the assertion messages below say "from before checkpoint";
        // see the note where these terminals are created.
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal (created after checkpoint) was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify the total count is 2 (the two completed, unreferenced terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4367
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    ///
    /// NOTE(review): the interleaving is produced by manually ticking the test
    /// executor a fixed number of times, so this test is sensitive to changes in
    /// how many tasks `send` schedules.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // A git repository with one file, so the first message gets a checkpoint
        // (presumably git-based — TODO confirm).
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // Set as soon as the fake handler is *invoked*. Its returned future is
        // already complete, so this effectively marks the end of the agent turn.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        // (the count of 5 was presumably chosen empirically; it is not derived from
        // anything visible here).
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Append a user message that has no checkpoint while the checkpoint
        // update for "First message" may still be in flight.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4430
4431 /// Tests that when a follow-up message is sent during generation,
4432 /// the first turn completing does NOT clear `running_turn` because
4433 /// it now belongs to the second turn.
4434 #[gpui::test]
4435 async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
4436 init_test(cx);
4437
4438 let fs = FakeFs::new(cx.executor());
4439 let project = Project::test(fs, [], cx).await;
4440
4441 // First handler waits for this signal before completing
4442 let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
4443 let first_complete_rx = RefCell::new(Some(first_complete_rx));
4444
4445 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
4446 move |params, _thread, _cx| {
4447 let first_complete_rx = first_complete_rx.borrow_mut().take();
4448 let is_first = params
4449 .prompt
4450 .iter()
4451 .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
4452
4453 async move {
4454 if is_first {
4455 // First handler waits until signaled
4456 if let Some(rx) = first_complete_rx {
4457 rx.await.ok();
4458 }
4459 }
4460 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
4461 }
4462 .boxed_local()
4463 }
4464 }));
4465
4466 let thread = cx
4467 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4468 .await
4469 .unwrap();
4470
4471 // Send first message (turn_id=1) - handler will block
4472 let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
4473 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
4474
4475 // Send second message (turn_id=2) while first is still blocked
4476 // This calls cancel() which takes turn 1's running_turn and sets turn 2's
4477 let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
4478 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
4479
4480 let running_turn_after_second_send =
4481 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4482 assert_eq!(
4483 running_turn_after_second_send,
4484 Some(2),
4485 "running_turn should be set to turn 2 after sending second message"
4486 );
4487
4488 // Now signal first handler to complete
4489 first_complete_tx.send(()).ok();
4490
4491 // First request completes - should NOT clear running_turn
4492 // because running_turn now belongs to turn 2
4493 first_request.await.unwrap();
4494
4495 let running_turn_after_first =
4496 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4497 assert_eq!(
4498 running_turn_after_first,
4499 Some(2),
4500 "first turn completing should not clear running_turn (belongs to turn 2)"
4501 );
4502
4503 // Second request completes - SHOULD clear running_turn
4504 second_request.await.unwrap();
4505
4506 let running_turn_after_second =
4507 thread.read_with(cx, |thread, _| thread.running_turn.is_some());
4508 assert!(
4509 !running_turn_after_second,
4510 "second turn completing should clear running_turn"
4511 );
4512 }
4513
4514 #[gpui::test]
4515 async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
4516 cx: &mut TestAppContext,
4517 ) {
4518 init_test(cx);
4519
4520 let fs = FakeFs::new(cx.executor());
4521 let project = Project::test(fs, [], cx).await;
4522
4523 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4524 move |_params, thread, mut cx| {
4525 async move {
4526 thread
4527 .update(&mut cx, |thread, cx| {
4528 thread.handle_session_update(
4529 acp::SessionUpdate::ToolCall(
4530 acp::ToolCall::new(
4531 acp::ToolCallId::new("test-tool"),
4532 "Test Tool",
4533 )
4534 .kind(acp::ToolKind::Fetch)
4535 .status(acp::ToolCallStatus::InProgress),
4536 ),
4537 cx,
4538 )
4539 })
4540 .unwrap()
4541 .unwrap();
4542
4543 Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
4544 }
4545 .boxed_local()
4546 },
4547 ));
4548
4549 let thread = cx
4550 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4551 .await
4552 .unwrap();
4553
4554 let response = thread
4555 .update(cx, |thread, cx| thread.send_raw("test message", cx))
4556 .await;
4557
4558 let response = response
4559 .expect("send should succeed")
4560 .expect("should have response");
4561 assert_eq!(
4562 response.stop_reason,
4563 acp::StopReason::Cancelled,
4564 "response should have Cancelled stop_reason"
4565 );
4566
4567 thread.read_with(cx, |thread, _| {
4568 let tool_entry = thread
4569 .entries
4570 .iter()
4571 .find_map(|e| {
4572 if let AgentThreadEntry::ToolCall(call) = e {
4573 Some(call)
4574 } else {
4575 None
4576 }
4577 })
4578 .expect("should have tool call entry");
4579
4580 assert!(
4581 matches!(tool_entry.status, ToolCallStatus::Canceled),
4582 "tool should be marked as Canceled when response is Cancelled, got {:?}",
4583 tool_entry.status
4584 );
4585 });
4586 }
4587}