1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Meta key under which a tool's programmatic name is stored on an ACP `ToolCall`.
///
/// ACP's `ToolCall` has no dedicated name field, so the name is carried in its meta map.
pub const TOOL_NAME_META_KEY: &str = "tool_name";
9
/// Meta key under which an ACP `ToolCall` records the session id of a spawned subagent.
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A message sent by the user, as stored in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one was assigned.
    pub id: Option<UserMessageId>,
    /// Displayable content of the message (rendered by `to_markdown`).
    pub content: ContentBlock,
    /// The raw ACP content blocks of the message.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to the message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether the entry is indented (reported by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
75
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the owning message's markdown heading reads "## User (checkpoint)".
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// An assistant message in the thread, made up of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The message's chunks, rendered in order by `to_markdown`.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Whether the entry is indented (reported by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
/// One piece of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message from the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool call and its current state.
    ToolCall(ToolCall),
}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
/// A tool call shown in the thread, built from ACP tool call messages.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Label built from the tool call's title. For non-`Execute` kinds only the first line
    /// of a multi-line title is kept, followed by "…".
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Displayable content (blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index
    /// (`AgentThreadEntry::location`); starts empty.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the tool call.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, produced by `markdown_for_raw_output`.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the tool call.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from the ACP meta (`TOOL_NAME_META_KEY`).
    pub tool_name: Option<SharedString>,
    /// Session id of a spawned subagent, read from the ACP meta.
    pub subagent_session_id: Option<acp::SessionId>,
}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 if self.kind == acp::ToolKind::Execute {
318 for terminal in self.terminals() {
319 terminal.update(cx, |terminal, cx| {
320 terminal.update_command_label(&title, cx);
321 });
322 }
323 }
324 self.label.update(cx, |label, cx| {
325 if self.kind == acp::ToolKind::Execute {
326 label.replace(title, cx);
327 } else if let Some((first_line, _)) = title.split_once("\n") {
328 label.replace(first_line.to_owned() + "…", cx);
329 } else {
330 label.replace(title, cx);
331 }
332 });
333 }
334
335 if let Some(content) = content {
336 let mut new_content_len = content.len();
337 let mut content = content.into_iter();
338
339 // Reuse existing content if we can
340 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
341 let valid_content =
342 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
343 if !valid_content {
344 new_content_len -= 1;
345 }
346 }
347 for new in content {
348 if let Some(new) = ToolCallContent::from_acp(
349 new,
350 language_registry.clone(),
351 path_style,
352 terminals,
353 cx,
354 )? {
355 self.content.push(new);
356 } else {
357 new_content_len -= 1;
358 }
359 }
360 self.content.truncate(new_content_len);
361 }
362
363 if let Some(locations) = locations {
364 self.locations = locations;
365 }
366
367 if let Some(raw_input) = raw_input {
368 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
369 self.raw_input = Some(raw_input);
370 }
371
372 if let Some(raw_output) = raw_output {
373 if self.content.is_empty()
374 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
375 {
376 self.content
377 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
378 markdown,
379 }));
380 }
381 self.raw_output = Some(raw_output);
382 }
383 Ok(())
384 }
385
386 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
387 self.content.iter().filter_map(|content| match content {
388 ToolCallContent::Diff(diff) => Some(diff),
389 ToolCallContent::ContentBlock(_) => None,
390 ToolCallContent::Terminal(_) => None,
391 })
392 }
393
394 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
395 self.content.iter().filter_map(|content| match content {
396 ToolCallContent::Terminal(terminal) => Some(terminal),
397 ToolCallContent::ContentBlock(_) => None,
398 ToolCallContent::Diff(_) => None,
399 })
400 }
401
402 pub fn is_subagent(&self) -> bool {
403 self.tool_name.as_ref().is_some_and(|s| s == "spawn_agent")
404 || self.subagent_session_id.is_some()
405 }
406
407 pub fn to_markdown(&self, cx: &App) -> String {
408 let mut markdown = format!(
409 "**Tool Call: {}**\nStatus: {}\n\n",
410 self.label.read(cx).source(),
411 self.status
412 );
413 for content in &self.content {
414 markdown.push_str(content.to_markdown(cx).as_str());
415 markdown.push_str("\n\n");
416 }
417 markdown
418 }
419
420 async fn resolve_location(
421 location: acp::ToolCallLocation,
422 project: WeakEntity<Project>,
423 cx: &mut AsyncApp,
424 ) -> Option<ResolvedLocation> {
425 let buffer = project
426 .update(cx, |project, cx| {
427 project
428 .project_path_for_absolute_path(&location.path, cx)
429 .map(|path| project.open_buffer(path, cx))
430 })
431 .ok()??;
432 let buffer = buffer.await.log_err()?;
433 let position = buffer.update(cx, |buffer, _| {
434 let snapshot = buffer.snapshot();
435 if let Some(row) = location.line {
436 let column = snapshot.indent_size_for_line(row).len;
437 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
438 snapshot.anchor_before(point)
439 } else {
440 Anchor::min_for_buffer(snapshot.remote_id())
441 }
442 });
443
444 Some(ResolvedLocation { buffer, position })
445 }
446
447 fn resolve_locations(
448 &self,
449 project: Entity<Project>,
450 cx: &mut App,
451 ) -> Task<Vec<Option<ResolvedLocation>>> {
452 let locations = self.locations.clone();
453 project.update(cx, |_, cx| {
454 cx.spawn(async move |project, cx| {
455 let mut new_locations = Vec::new();
456 for location in locations {
457 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
458 }
459 new_locations
460 })
461 })
462 }
463}
464
// Kept separate from `AgentLocation` so we can hold a strong reference to the buffer
// for saving on the thread; `AgentLocation` only stores a weak (downgraded) handle.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer containing the location.
    buffer: Entity<Buffer>,
    /// Anchor of the location within `buffer`.
    position: Anchor,
}
472
473impl From<&ResolvedLocation> for AgentLocation {
474 fn from(value: &ResolvedLocation) -> Self {
475 Self {
476 buffer: value.buffer.downgrade(),
477 position: value.position,
478 }
479 }
480}
481
/// Status of a tool call as tracked by the thread.
///
/// Beyond the states produced from `acp::ToolCallStatus`, it also covers waiting for
/// confirmation, rejection, and cancellation.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options offered to the user.
        options: PermissionOptions,
        /// Sender for the id of the option the user chose.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
503
504impl From<acp::ToolCallStatus> for ToolCallStatus {
505 fn from(status: acp::ToolCallStatus) -> Self {
506 match status {
507 acp::ToolCallStatus::Pending => Self::Pending,
508 acp::ToolCallStatus::InProgress => Self::InProgress,
509 acp::ToolCallStatus::Completed => Self::Completed,
510 acp::ToolCallStatus::Failed => Self::Failed,
511 _ => Self::Pending,
512 }
513 }
514}
515
516impl Display for ToolCallStatus {
517 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
518 write!(
519 f,
520 "{}",
521 match self {
522 ToolCallStatus::Pending => "Pending",
523 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
524 ToolCallStatus::InProgress => "In Progress",
525 ToolCallStatus::Completed => "Completed",
526 ToolCallStatus::Failed => "Failed",
527 ToolCallStatus::Rejected => "Rejected",
528 ToolCallStatus::Canceled => "Canceled",
529 }
530 )
531 }
532}
533
/// Displayable content assembled from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text; further appended blocks are added to it.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; appending anything turns it into `Markdown`.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A lone decoded image; appending anything turns it into `Markdown`.
    Image { image: Arc<gpui::Image> },
}
541
542impl ContentBlock {
543 pub fn new(
544 block: acp::ContentBlock,
545 language_registry: &Arc<LanguageRegistry>,
546 path_style: PathStyle,
547 cx: &mut App,
548 ) -> Self {
549 let mut this = Self::Empty;
550 this.append(block, language_registry, path_style, cx);
551 this
552 }
553
554 pub fn new_combined(
555 blocks: impl IntoIterator<Item = acp::ContentBlock>,
556 language_registry: Arc<LanguageRegistry>,
557 path_style: PathStyle,
558 cx: &mut App,
559 ) -> Self {
560 let mut this = Self::Empty;
561 for block in blocks {
562 this.append(block, &language_registry, path_style, cx);
563 }
564 this
565 }
566
567 pub fn append(
568 &mut self,
569 block: acp::ContentBlock,
570 language_registry: &Arc<LanguageRegistry>,
571 path_style: PathStyle,
572 cx: &mut App,
573 ) {
574 match (&mut *self, &block) {
575 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
576 *self = ContentBlock::ResourceLink {
577 resource_link: resource_link.clone(),
578 };
579 }
580 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
581 if let Some(image) = Self::decode_image(image_content) {
582 *self = ContentBlock::Image { image };
583 } else {
584 let new_content = Self::image_md(image_content);
585 *self = Self::create_markdown_block(new_content, language_registry, cx);
586 }
587 }
588 (ContentBlock::Empty, _) => {
589 let new_content = Self::block_string_contents(&block, path_style);
590 *self = Self::create_markdown_block(new_content, language_registry, cx);
591 }
592 (ContentBlock::Markdown { markdown }, _) => {
593 let new_content = Self::block_string_contents(&block, path_style);
594 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
595 }
596 (ContentBlock::ResourceLink { resource_link }, _) => {
597 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
598 let new_content = Self::block_string_contents(&block, path_style);
599 let combined = format!("{}\n{}", existing_content, new_content);
600 *self = Self::create_markdown_block(combined, language_registry, cx);
601 }
602 (ContentBlock::Image { .. }, _) => {
603 let new_content = Self::block_string_contents(&block, path_style);
604 let combined = format!("`Image`\n{}", new_content);
605 *self = Self::create_markdown_block(combined, language_registry, cx);
606 }
607 }
608 }
609
610 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
611 use base64::Engine as _;
612
613 let bytes = base64::engine::general_purpose::STANDARD
614 .decode(image_content.data.as_bytes())
615 .ok()?;
616 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
617 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
618 }
619
620 fn create_markdown_block(
621 content: String,
622 language_registry: &Arc<LanguageRegistry>,
623 cx: &mut App,
624 ) -> ContentBlock {
625 ContentBlock::Markdown {
626 markdown: cx
627 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
628 }
629 }
630
631 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
632 match block {
633 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
634 acp::ContentBlock::ResourceLink(resource_link) => {
635 Self::resource_link_md(&resource_link.uri, path_style)
636 }
637 acp::ContentBlock::Resource(acp::EmbeddedResource {
638 resource:
639 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
640 uri,
641 ..
642 }),
643 ..
644 }) => Self::resource_link_md(uri, path_style),
645 acp::ContentBlock::Image(image) => Self::image_md(image),
646 _ => String::new(),
647 }
648 }
649
650 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
651 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
652 uri.as_link().to_string()
653 } else {
654 uri.to_string()
655 }
656 }
657
658 fn image_md(_image: &acp::ImageContent) -> String {
659 "`Image`".into()
660 }
661
662 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
663 match self {
664 ContentBlock::Empty => "",
665 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
666 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
667 ContentBlock::Image { .. } => "`Image`",
668 }
669 }
670
671 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
672 match self {
673 ContentBlock::Empty => None,
674 ContentBlock::Markdown { markdown } => Some(markdown),
675 ContentBlock::ResourceLink { .. } => None,
676 ContentBlock::Image { .. } => None,
677 }
678 }
679
680 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
681 match self {
682 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
683 _ => None,
684 }
685 }
686
687 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
688 match self {
689 ContentBlock::Image { image } => Some(image),
690 _ => None,
691 }
692 }
693}
694
/// One displayable item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, resource link, or image content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal registered with the thread.
    Terminal(Entity<Terminal>),
}
701
702impl ToolCallContent {
703 pub fn from_acp(
704 content: acp::ToolCallContent,
705 language_registry: Arc<LanguageRegistry>,
706 path_style: PathStyle,
707 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
708 cx: &mut App,
709 ) -> Result<Option<Self>> {
710 match content {
711 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
712 Ok(Some(Self::ContentBlock(ContentBlock::new(
713 content,
714 &language_registry,
715 path_style,
716 cx,
717 ))))
718 }
719 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
720 Diff::finalized(
721 diff.path.to_string_lossy().into_owned(),
722 diff.old_text,
723 diff.new_text,
724 language_registry,
725 cx,
726 )
727 })))),
728 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
729 .get(&terminal_id)
730 .cloned()
731 .map(|terminal| Some(Self::Terminal(terminal)))
732 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
733 _ => Ok(None),
734 }
735 }
736
737 pub fn update_from_acp(
738 &mut self,
739 new: acp::ToolCallContent,
740 language_registry: Arc<LanguageRegistry>,
741 path_style: PathStyle,
742 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
743 cx: &mut App,
744 ) -> Result<bool> {
745 let needs_update = match (&self, &new) {
746 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
747 old_diff.read(cx).needs_update(
748 new_diff.old_text.as_deref().unwrap_or(""),
749 &new_diff.new_text,
750 cx,
751 )
752 }
753 _ => true,
754 };
755
756 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
757 if needs_update {
758 *self = update;
759 }
760 Ok(true)
761 } else {
762 Ok(false)
763 }
764 }
765
766 pub fn to_markdown(&self, cx: &App) -> String {
767 match self {
768 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
769 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
770 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
771 }
772 }
773
774 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
775 match self {
776 Self::ContentBlock(content) => content.image(),
777 _ => None,
778 }
779 }
780}
781
/// An update targeting an existing tool call, identified by `ToolCallUpdate::id`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field updates received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// A diff entity for the tool call.
    UpdateDiff(ToolCallUpdateDiff),
    /// A terminal entity for the tool call.
    UpdateTerminal(ToolCallUpdateTerminal),
}
788
789impl ToolCallUpdate {
790 fn id(&self) -> &acp::ToolCallId {
791 match self {
792 Self::UpdateFields(update) => &update.tool_call_id,
793 Self::UpdateDiff(diff) => &diff.id,
794 Self::UpdateTerminal(terminal) => &terminal.id,
795 }
796 }
797}
798
799impl From<acp::ToolCallUpdate> for ToolCallUpdate {
800 fn from(update: acp::ToolCallUpdate) -> Self {
801 Self::UpdateFields(update)
802 }
803}
804
805impl From<ToolCallUpdateDiff> for ToolCallUpdate {
806 fn from(diff: ToolCallUpdateDiff) -> Self {
807 Self::UpdateDiff(diff)
808 }
809}
810
/// A diff entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the target tool call.
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
816
817impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
818 fn from(terminal: ToolCallUpdateTerminal) -> Self {
819 Self::UpdateTerminal(terminal)
820 }
821}
822
/// A terminal entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the target tool call.
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
828
/// The agent's plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
833
/// Summary of a [`Plan`], produced by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is `Pending`.
    pub pending: u32,
    /// Number of entries whose status is `Completed`.
    pub completed: u32,
}
840
841impl Plan {
842 pub fn is_empty(&self) -> bool {
843 self.entries.is_empty()
844 }
845
846 pub fn stats(&self) -> PlanStats<'_> {
847 let mut stats = PlanStats {
848 in_progress_entry: None,
849 pending: 0,
850 completed: 0,
851 };
852
853 for entry in &self.entries {
854 match &entry.status {
855 acp::PlanEntryStatus::Pending => {
856 stats.pending += 1;
857 }
858 acp::PlanEntryStatus::InProgress => {
859 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
860 }
861 acp::PlanEntryStatus::Completed => {
862 stats.completed += 1;
863 }
864 _ => {}
865 }
866 }
867
868 stats
869 }
870}
871
/// A single step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, rendered as markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
878
879impl PlanEntry {
880 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
881 Self {
882 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
883 priority: entry.priority,
884 status: entry.status,
885 }
886 }
887}
888
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `ratio` treats 0 as "unknown".
    pub max_tokens: u64,
    /// Tokens used so far, compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub max_output_tokens: Option<u64>,
}
897
/// Fraction of `max_tokens` at which `TokenUsage::ratio` starts reporting a warning.
pub const TOKEN_USAGE_WARNING_THRESHOLD: f32 = 0.8;
899
900impl TokenUsage {
901 pub fn ratio(&self) -> TokenUsageRatio {
902 #[cfg(debug_assertions)]
903 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
904 .unwrap_or(TOKEN_USAGE_WARNING_THRESHOLD.to_string())
905 .parse()
906 .unwrap();
907 #[cfg(not(debug_assertions))]
908 let warning_threshold: f32 = TOKEN_USAGE_WARNING_THRESHOLD;
909
910 // When the maximum is unknown because there is no selected model,
911 // avoid showing the token limit warning.
912 if self.max_tokens == 0 {
913 TokenUsageRatio::Normal
914 } else if self.used_tokens >= self.max_tokens {
915 TokenUsageRatio::Exceeded
916 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
917 TokenUsageRatio::Warning
918 } else {
919 TokenUsageRatio::Normal
920 }
921 }
922}
923
/// Severity of token usage, as computed by `TokenUsage::ratio`.
///
/// The derived ordering follows declaration order: `Normal < Warning < Exceeded`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}
930
/// Retry information carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// The most recent error.
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// NOTE(review): presumably when the retry began — confirm at the construction site.
    pub started_at: Instant,
    /// NOTE(review): presumably the delay before the next attempt — confirm.
    pub duration: Duration,
}
939
/// The turn currently being generated. While one is stored on the thread, `status()`
/// reports `ThreadStatus::Generating`.
struct RunningTurn {
    /// Identifier of the turn (presumably taken from `AcpThread::turn_id` — confirm).
    id: u32,
    /// Task driving the turn.
    send_task: Task<()>,
}
944
/// State of a single ACP agent session: its entries, plan, terminals, and connection.
pub struct AcpThread {
    /// Session id of the parent thread, if one was supplied at construction.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    /// Entries of the thread, in order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// NOTE(review): presumably snapshots of buffers shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Turn counter (starts at 0); presumably used to assign `RunningTurn::id` — confirm.
    turn_id: u32,
    /// The turn in progress; `Some` means the thread status is `Generating`.
    running_turn: Option<RunningTurn>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest token usage, if any has been reported.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies updates from the prompt-capabilities watch channel into
    /// `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals registered with this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids that were not registered yet.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids that were not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    /// Error flag, initialized to `false` (presumably set when a turn fails — confirm).
    had_error: bool,
}
965
966impl From<&AcpThread> for ActionLogTelemetry {
967 fn from(value: &AcpThread) -> Self {
968 Self {
969 agent_telemetry_id: value.connection().telemetry_id(),
970 session_id: value.session_id.0.clone(),
971 }
972 }
973}
974
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Presumably carries the index of the updated entry — confirm at emit sites.
    EntryUpdated(usize),
    /// Presumably the index range of the removed entries — confirm at emit sites.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequested(acp::ToolCallId),
    ToolAuthorizationReceived(acp::ToolCallId),
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted whenever new prompt capabilities arrive on the watch channel given to `new`.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}
995
// Lets `AcpThread` emit `AcpThreadEvent`s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
997
/// Terminal events handled by `AcpThread::on_terminal_provider_event`.
///
/// `Output` and `Exit` may arrive before the matching `Created`; the thread buffers them
/// until the terminal is registered.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered with the thread.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes to write into the terminal.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title; stored as the terminal's breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
1020
/// Terminal commands; NOTE(review): presumably sent to the terminal provider — not used in
/// the visible code, confirm at the use sites.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Bytes to write to the terminal's input.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// New terminal dimensions.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1036
1037impl AcpThread {
1038 pub fn on_terminal_provider_event(
1039 &mut self,
1040 event: TerminalProviderEvent,
1041 cx: &mut Context<Self>,
1042 ) {
1043 match event {
1044 TerminalProviderEvent::Created {
1045 terminal_id,
1046 label,
1047 cwd,
1048 output_byte_limit,
1049 terminal,
1050 } => {
1051 let entity = self.register_terminal_created(
1052 terminal_id.clone(),
1053 label,
1054 cwd,
1055 output_byte_limit,
1056 terminal,
1057 cx,
1058 );
1059
1060 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1061 for data in chunks.drain(..) {
1062 entity.update(cx, |term, cx| {
1063 term.inner().update(cx, |inner, cx| {
1064 inner.write_output(&data, cx);
1065 })
1066 });
1067 }
1068 }
1069
1070 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1071 entity.update(cx, |_term, cx| {
1072 cx.notify();
1073 });
1074 }
1075
1076 cx.notify();
1077 }
1078 TerminalProviderEvent::Output { terminal_id, data } => {
1079 if let Some(entity) = self.terminals.get(&terminal_id) {
1080 entity.update(cx, |term, cx| {
1081 term.inner().update(cx, |inner, cx| {
1082 inner.write_output(&data, cx);
1083 })
1084 });
1085 } else {
1086 self.pending_terminal_output
1087 .entry(terminal_id)
1088 .or_default()
1089 .push(data);
1090 }
1091 }
1092 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1093 if let Some(entity) = self.terminals.get(&terminal_id) {
1094 entity.update(cx, |term, cx| {
1095 term.inner().update(cx, |inner, cx| {
1096 inner.breadcrumb_text = title;
1097 cx.emit(::terminal::Event::BreadcrumbsChanged);
1098 })
1099 });
1100 }
1101 }
1102 TerminalProviderEvent::Exit {
1103 terminal_id,
1104 status,
1105 } => {
1106 if let Some(entity) = self.terminals.get(&terminal_id) {
1107 entity.update(cx, |_term, cx| {
1108 cx.notify();
1109 });
1110 } else {
1111 self.pending_terminal_exit.insert(terminal_id, status);
1112 }
1113 }
1114 }
1115 }
1116}
1117
/// Whether a thread is currently running a turn (see `AcpThread::status`).
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running.
    Generating,
}
1123
/// Reasons an agent failed to load; see the `Display` impl for the user-facing messages.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent at `command` is older than the minimum supported version.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The server process exited with `status`.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by the message.
    Other(SharedString),
}
1137
1138impl Display for LoadError {
1139 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1140 match self {
1141 LoadError::Unsupported {
1142 command: path,
1143 current_version,
1144 minimum_version,
1145 } => {
1146 write!(
1147 f,
1148 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1149 )
1150 }
1151 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1152 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1153 LoadError::Other(msg) => write!(f, "{msg}"),
1154 }
1155 }
1156}
1157
// `LoadError` is a standard error whose message comes from `Display`; it reports no
// underlying `source`.
impl Error for LoadError {}
1159
1160impl AcpThread {
    /// Creates a thread for `session_id` backed by `connection`.
    ///
    /// The initial prompt capabilities are read from `prompt_capabilities_rx`. A task is
    /// spawned that keeps `prompt_capabilities` in sync with later values on that channel and
    /// emits `AcpThreadEvent::PromptCapabilitiesUpdated` for each one. That task stops when
    /// the channel's `recv` fails or the thread entity can no longer be updated.
    pub fn new(
        parent_session_id: Option<acp::SessionId>,
        title: impl Into<SharedString>,
        connection: Rc<dyn AgentConnection>,
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        session_id: acp::SessionId,
        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
        cx: &mut Context<Self>,
    ) -> Self {
        // Seed with the value currently in the channel so the thread starts up to date.
        let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
            loop {
                // `?` ends the loop when the channel errors or the entity is gone.
                let caps = prompt_capabilities_rx.recv().await?;
                this.update(cx, |this, cx| {
                    this.prompt_capabilities = caps;
                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
                })?;
            }
        });

        Self {
            parent_session_id,
            action_log,
            shared_buffers: Default::default(),
            entries: Default::default(),
            plan: Default::default(),
            title: title.into(),
            project,
            running_turn: None,
            turn_id: 0,
            connection,
            session_id,
            token_usage: None,
            prompt_capabilities,
            // Keep the observer task in a field so it is not detached.
            _observe_prompt_capabilities: task,
            terminals: HashMap::default(),
            pending_terminal_output: HashMap::default(),
            pending_terminal_exit: HashMap::default(),
            had_error: false,
        }
    }
1203
1204 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1205 self.parent_session_id.as_ref()
1206 }
1207
1208 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1209 self.prompt_capabilities.clone()
1210 }
1211
1212 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1213 &self.connection
1214 }
1215
1216 pub fn action_log(&self) -> &Entity<ActionLog> {
1217 &self.action_log
1218 }
1219
1220 pub fn project(&self) -> &Entity<Project> {
1221 &self.project
1222 }
1223
1224 pub fn title(&self) -> SharedString {
1225 self.title.clone()
1226 }
1227
1228 pub fn entries(&self) -> &[AgentThreadEntry] {
1229 &self.entries
1230 }
1231
1232 pub fn session_id(&self) -> &acp::SessionId {
1233 &self.session_id
1234 }
1235
1236 pub fn status(&self) -> ThreadStatus {
1237 if self.running_turn.is_some() {
1238 ThreadStatus::Generating
1239 } else {
1240 ThreadStatus::Idle
1241 }
1242 }
1243
1244 pub fn had_error(&self) -> bool {
1245 self.had_error
1246 }
1247
1248 pub fn is_waiting_for_confirmation(&self) -> bool {
1249 for entry in self.entries.iter().rev() {
1250 match entry {
1251 AgentThreadEntry::UserMessage(_) => return false,
1252 AgentThreadEntry::ToolCall(ToolCall {
1253 status: ToolCallStatus::WaitingForConfirmation { .. },
1254 ..
1255 }) => return true,
1256 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1257 }
1258 }
1259 false
1260 }
1261
1262 pub fn token_usage(&self) -> Option<&TokenUsage> {
1263 self.token_usage.as_ref()
1264 }
1265
1266 pub fn has_pending_edit_tool_calls(&self) -> bool {
1267 for entry in self.entries.iter().rev() {
1268 match entry {
1269 AgentThreadEntry::UserMessage(_) => return false,
1270 AgentThreadEntry::ToolCall(
1271 call @ ToolCall {
1272 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1273 ..
1274 },
1275 ) if call.diffs().next().is_some() => {
1276 return true;
1277 }
1278 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1279 }
1280 }
1281
1282 false
1283 }
1284
1285 pub fn has_in_progress_tool_calls(&self) -> bool {
1286 for entry in self.entries.iter().rev() {
1287 match entry {
1288 AgentThreadEntry::UserMessage(_) => return false,
1289 AgentThreadEntry::ToolCall(ToolCall {
1290 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1291 ..
1292 }) => {
1293 return true;
1294 }
1295 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1296 }
1297 }
1298
1299 false
1300 }
1301
1302 pub fn used_tools_since_last_user_message(&self) -> bool {
1303 for entry in self.entries.iter().rev() {
1304 match entry {
1305 AgentThreadEntry::UserMessage(..) => return false,
1306 AgentThreadEntry::AssistantMessage(..) => continue,
1307 AgentThreadEntry::ToolCall(..) => return true,
1308 }
1309 }
1310
1311 false
1312 }
1313
1314 pub fn handle_session_update(
1315 &mut self,
1316 update: acp::SessionUpdate,
1317 cx: &mut Context<Self>,
1318 ) -> Result<(), acp::Error> {
1319 match update {
1320 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1321 self.push_user_content_block(None, content, cx);
1322 }
1323 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1324 self.push_assistant_content_block(content, false, cx);
1325 }
1326 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1327 self.push_assistant_content_block(content, true, cx);
1328 }
1329 acp::SessionUpdate::ToolCall(tool_call) => {
1330 self.upsert_tool_call(tool_call, cx)?;
1331 }
1332 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1333 self.update_tool_call(tool_call_update, cx)?;
1334 }
1335 acp::SessionUpdate::Plan(plan) => {
1336 self.update_plan(plan, cx);
1337 }
1338 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1339 available_commands,
1340 ..
1341 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1342 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1343 current_mode_id,
1344 ..
1345 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1346 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1347 config_options,
1348 ..
1349 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1350 _ => {}
1351 }
1352 Ok(())
1353 }
1354
1355 pub fn push_user_content_block(
1356 &mut self,
1357 message_id: Option<UserMessageId>,
1358 chunk: acp::ContentBlock,
1359 cx: &mut Context<Self>,
1360 ) {
1361 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1362 }
1363
1364 pub fn push_user_content_block_with_indent(
1365 &mut self,
1366 message_id: Option<UserMessageId>,
1367 chunk: acp::ContentBlock,
1368 indented: bool,
1369 cx: &mut Context<Self>,
1370 ) {
1371 let language_registry = self.project.read(cx).languages().clone();
1372 let path_style = self.project.read(cx).path_style(cx);
1373 let entries_len = self.entries.len();
1374
1375 if let Some(last_entry) = self.entries.last_mut()
1376 && let AgentThreadEntry::UserMessage(UserMessage {
1377 id,
1378 content,
1379 chunks,
1380 indented: existing_indented,
1381 ..
1382 }) = last_entry
1383 && *existing_indented == indented
1384 {
1385 *id = message_id.or(id.take());
1386 content.append(chunk.clone(), &language_registry, path_style, cx);
1387 chunks.push(chunk);
1388 let idx = entries_len - 1;
1389 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1390 } else {
1391 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1392 self.push_entry(
1393 AgentThreadEntry::UserMessage(UserMessage {
1394 id: message_id,
1395 content,
1396 chunks: vec![chunk],
1397 checkpoint: None,
1398 indented,
1399 }),
1400 cx,
1401 );
1402 }
1403 }
1404
1405 pub fn push_assistant_content_block(
1406 &mut self,
1407 chunk: acp::ContentBlock,
1408 is_thought: bool,
1409 cx: &mut Context<Self>,
1410 ) {
1411 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1412 }
1413
1414 pub fn push_assistant_content_block_with_indent(
1415 &mut self,
1416 chunk: acp::ContentBlock,
1417 is_thought: bool,
1418 indented: bool,
1419 cx: &mut Context<Self>,
1420 ) {
1421 let language_registry = self.project.read(cx).languages().clone();
1422 let path_style = self.project.read(cx).path_style(cx);
1423 let entries_len = self.entries.len();
1424 if let Some(last_entry) = self.entries.last_mut()
1425 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1426 chunks,
1427 indented: existing_indented,
1428 }) = last_entry
1429 && *existing_indented == indented
1430 {
1431 let idx = entries_len - 1;
1432 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1433 match (chunks.last_mut(), is_thought) {
1434 (Some(AssistantMessageChunk::Message { block }), false)
1435 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1436 block.append(chunk, &language_registry, path_style, cx)
1437 }
1438 _ => {
1439 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1440 if is_thought {
1441 chunks.push(AssistantMessageChunk::Thought { block })
1442 } else {
1443 chunks.push(AssistantMessageChunk::Message { block })
1444 }
1445 }
1446 }
1447 } else {
1448 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1449 let chunk = if is_thought {
1450 AssistantMessageChunk::Thought { block }
1451 } else {
1452 AssistantMessageChunk::Message { block }
1453 };
1454
1455 self.push_entry(
1456 AgentThreadEntry::AssistantMessage(AssistantMessage {
1457 chunks: vec![chunk],
1458 indented,
1459 }),
1460 cx,
1461 );
1462 }
1463 }
1464
1465 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1466 self.entries.push(entry);
1467 cx.emit(AcpThreadEvent::NewEntry);
1468 }
1469
1470 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1471 self.connection.set_title(&self.session_id, cx).is_some()
1472 }
1473
1474 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1475 if title != self.title {
1476 self.title = title.clone();
1477 cx.emit(AcpThreadEvent::TitleUpdated);
1478 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1479 return set_title.run(title, cx);
1480 }
1481 }
1482 Task::ready(Ok(()))
1483 }
1484
1485 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1486 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1487 }
1488
1489 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1490 self.token_usage = usage;
1491 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1492 }
1493
1494 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1495 cx.emit(AcpThreadEvent::Retry(status));
1496 }
1497
    /// Applies `update` to the tool call entry with the same id and emits `EntryUpdated`.
    ///
    /// - Field updates are merged into the tool call, and its locations are re-resolved when
    ///   the update carries locations.
    /// - Diff and terminal updates replace the tool call's entire content with that single
    ///   item.
    ///
    /// If no tool call has the id, a placeholder tool call with status `Failed` and the text
    /// "Tool call not found" is pushed instead and `Ok(())` is returned.
    ///
    /// # Errors
    /// Propagates errors from `ToolCall::update_fields`.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                    tool_name: None,
                    subagent_session_id: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(
                    update.fields,
                    update.meta,
                    languages,
                    path_style,
                    &self.terminals,
                    cx,
                )?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                // A diff update replaces whatever content the tool call had.
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                // Likewise, a terminal update replaces all existing content.
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1568
1569 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1570 pub fn upsert_tool_call(
1571 &mut self,
1572 tool_call: acp::ToolCall,
1573 cx: &mut Context<Self>,
1574 ) -> Result<(), acp::Error> {
1575 let status = tool_call.status.into();
1576 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1577 }
1578
1579 /// Fails if id does not match an existing entry.
1580 pub fn upsert_tool_call_inner(
1581 &mut self,
1582 update: acp::ToolCallUpdate,
1583 status: ToolCallStatus,
1584 cx: &mut Context<Self>,
1585 ) -> Result<(), acp::Error> {
1586 let language_registry = self.project.read(cx).languages().clone();
1587 let path_style = self.project.read(cx).path_style(cx);
1588 let id = update.tool_call_id.clone();
1589
1590 let agent_telemetry_id = self.connection().telemetry_id();
1591 let session = self.session_id();
1592 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1593 let status = if matches!(status, ToolCallStatus::Completed) {
1594 "completed"
1595 } else {
1596 "failed"
1597 };
1598 telemetry::event!(
1599 "Agent Tool Call Completed",
1600 agent_telemetry_id,
1601 session,
1602 status
1603 );
1604 }
1605
1606 if let Some(ix) = self.index_for_tool_call(&id) {
1607 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1608 unreachable!()
1609 };
1610
1611 call.update_fields(
1612 update.fields,
1613 update.meta,
1614 language_registry,
1615 path_style,
1616 &self.terminals,
1617 cx,
1618 )?;
1619 call.status = status;
1620
1621 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1622 } else {
1623 let call = ToolCall::from_acp(
1624 update.try_into()?,
1625 status,
1626 language_registry,
1627 self.project.read(cx).path_style(cx),
1628 &self.terminals,
1629 cx,
1630 )?;
1631 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1632 };
1633
1634 self.resolve_locations(id, cx);
1635 Ok(())
1636 }
1637
1638 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1639 self.entries
1640 .iter()
1641 .enumerate()
1642 .rev()
1643 .find_map(|(index, entry)| {
1644 if let AgentThreadEntry::ToolCall(tool_call) = entry
1645 && &tool_call.id == id
1646 {
1647 Some(index)
1648 } else {
1649 None
1650 }
1651 })
1652 }
1653
1654 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1655 // The tool call we are looking for is typically the last one, or very close to the end.
1656 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1657 self.entries
1658 .iter_mut()
1659 .enumerate()
1660 .rev()
1661 .find_map(|(index, tool_call)| {
1662 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1663 && &tool_call.id == id
1664 {
1665 Some((index, tool_call))
1666 } else {
1667 None
1668 }
1669 })
1670 }
1671
1672 pub fn tool_call(&self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1673 self.entries
1674 .iter()
1675 .enumerate()
1676 .rev()
1677 .find_map(|(index, tool_call)| {
1678 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1679 && &tool_call.id == id
1680 {
1681 Some((index, tool_call))
1682 } else {
1683 None
1684 }
1685 })
1686 }
1687
1688 pub fn tool_call_for_subagent(&self, session_id: &acp::SessionId) -> Option<&ToolCall> {
1689 self.entries.iter().find_map(|entry| match entry {
1690 AgentThreadEntry::ToolCall(tool_call)
1691 if tool_call.subagent_session_id.as_ref() == Some(session_id) =>
1692 {
1693 Some(tool_call)
1694 }
1695 _ => None,
1696 })
1697 }
1698
    /// Resolves the locations referenced by tool call `id` in a detached background task.
    ///
    /// Does nothing if no tool call has this id. Once resolution finishes:
    /// - the snapshot of every resolved buffer is stored in `shared_buffers`;
    /// - the project's agent location is moved to the last resolved location, unless that
    ///   would only move it to an earlier column on the same row of the same buffer;
    /// - the tool call's `resolved_locations` is replaced, with `EntryUpdated` emitted only
    ///   when the value changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call is looked up again because entries may have changed while
                // resolution was running.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Avoid re-rendering the entry when nothing changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1755
1756 pub fn request_tool_call_authorization(
1757 &mut self,
1758 tool_call: acp::ToolCallUpdate,
1759 options: PermissionOptions,
1760 cx: &mut Context<Self>,
1761 ) -> Result<Task<acp::RequestPermissionOutcome>> {
1762 let (tx, rx) = oneshot::channel();
1763
1764 let status = ToolCallStatus::WaitingForConfirmation {
1765 options,
1766 respond_tx: tx,
1767 };
1768
1769 let tool_call_id = tool_call.tool_call_id.clone();
1770 self.upsert_tool_call_inner(tool_call, status, cx)?;
1771 cx.emit(AcpThreadEvent::ToolAuthorizationRequested(
1772 tool_call_id.clone(),
1773 ));
1774
1775 Ok(cx.spawn(async move |this, cx| {
1776 let outcome = match rx.await {
1777 Ok(option) => acp::RequestPermissionOutcome::Selected(
1778 acp::SelectedPermissionOutcome::new(option),
1779 ),
1780 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1781 };
1782 this.update(cx, |_this, cx| {
1783 cx.emit(AcpThreadEvent::ToolAuthorizationReceived(tool_call_id))
1784 })
1785 .ok();
1786 outcome
1787 }))
1788 }
1789
1790 pub fn authorize_tool_call(
1791 &mut self,
1792 id: acp::ToolCallId,
1793 option_id: acp::PermissionOptionId,
1794 option_kind: acp::PermissionOptionKind,
1795 cx: &mut Context<Self>,
1796 ) {
1797 let Some((ix, call)) = self.tool_call_mut(&id) else {
1798 return;
1799 };
1800
1801 let new_status = match option_kind {
1802 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1803 ToolCallStatus::Rejected
1804 }
1805 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1806 ToolCallStatus::InProgress
1807 }
1808 _ => ToolCallStatus::InProgress,
1809 };
1810
1811 let curr_status = mem::replace(&mut call.status, new_status);
1812
1813 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1814 respond_tx.send(option_id).log_err();
1815 } else if cfg!(debug_assertions) {
1816 panic!("tried to authorize an already authorized tool call");
1817 }
1818
1819 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1820 }
1821
1822 pub fn plan(&self) -> &Plan {
1823 &self.plan
1824 }
1825
    /// Replaces the plan with the entries of `request` and notifies observers.
    ///
    /// Existing entries are updated in place, so their markdown entities are reused. Any
    /// additional new entries are appended, and surplus old entries are truncated.
    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
        let new_entries_len = request.entries.len();
        let mut new_entries = request.entries.into_iter();

        // Reuse existing markdown to prevent flickering.
        // `zip` polls the old entries first, so once they run out no new entry is consumed
        // here; the remaining new entries are picked up by the loop below.
        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
            let PlanEntry {
                content,
                priority,
                status,
            } = old;
            content.update(cx, |old, cx| {
                old.replace(new.content, cx);
            });
            *priority = new.priority;
            *status = new.status;
        }
        for new in new_entries {
            self.plan.entries.push(PlanEntry::from_acp(new, cx))
        }
        // Drop stale entries when the new plan is shorter than the old one.
        self.plan.entries.truncate(new_entries_len);

        cx.notify();
    }
1850
1851 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1852 self.plan
1853 .entries
1854 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1855 cx.notify();
1856 }
1857
1858 #[cfg(any(test, feature = "test-support"))]
1859 pub fn send_raw(
1860 &mut self,
1861 message: &str,
1862 cx: &mut Context<Self>,
1863 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1864 self.send(vec![message.into()], cx)
1865 }
1866
    /// Sends a user prompt made of `message` and runs it as a new turn.
    ///
    /// Within the turn:
    /// - the prompt is pushed as a user message entry;
    /// - a git checkpoint is taken and attached to it (with `show: false`); failing to take
    ///   it is only logged;
    /// - the request is forwarded to the connection.
    ///
    /// A `UserMessageId` is generated only when the connection supports truncation.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only needed to truncate back to this message later.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` sees the tree change.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1919
1920 pub fn can_retry(&self, cx: &App) -> bool {
1921 self.connection.retry(&self.session_id, cx).is_some()
1922 }
1923
1924 pub fn retry(
1925 &mut self,
1926 cx: &mut Context<Self>,
1927 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1928 self.run_turn(cx, async move |this, cx| {
1929 this.update(cx, |this, cx| {
1930 this.connection
1931 .retry(&this.session_id, cx)
1932 .map(|retry| retry.run(cx))
1933 })?
1934 .context("retrying a session is not supported")?
1935 .await
1936 })
1937 }
1938
    /// Starts a new turn that runs `f`, cancelling any turn already in progress first.
    ///
    /// Completed plan entries and the error flag are cleared, and the turn is recorded in
    /// `running_turn` under a fresh `turn_id`. The returned future waits for `f`'s result,
    /// refreshes the last user message's checkpoint, and clears the project's agent location.
    /// It then handles the result:
    /// - The result sender was dropped without a value: returns `Ok(None)`.
    /// - `f` failed: sets `had_error`, emits `Error`, and returns the error.
    /// - `MaxTokens`: sets `had_error`, emits `Error`, and returns an error.
    /// - Any other stop reason: emits `Stopped` and returns `Ok(Some(response))`, with extra
    ///   handling for two of them:
    ///   - `Cancelled` first marks pending tool calls as canceled.
    ///   - `Refusal` first sets `had_error` and emits `Refusal`. If no completed tool call
    ///     with output follows the last user message, that message and everything after it
    ///     are also removed.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        self.clear_completed_plan_entries(cx);
        self.had_error = false;

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.turn_id += 1;
        let turn_id = self.turn_id;
        self.running_turn = Some(RunningTurn {
            id: turn_id,
            send_task: cx.spawn(async move |this, cx| {
                // Let the previous turn finish cancelling before starting this one.
                cancel_task.await;
                tx.send(f(this, cx).await).ok();
            }),
        });

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                let Ok(response) = response else {
                    // tx dropped, just return
                    return Ok(None);
                };

                let is_same_turn = this
                    .running_turn
                    .as_ref()
                    .is_some_and(|turn| turn_id == turn.id);

                // If the user submitted a follow up message, running_turn might
                // already point to a different turn. Therefore we only want to
                // take the task if it's the same turn.
                if is_same_turn {
                    this.running_turn.take();
                }

                match response {
                    Ok(r) => {
                        if r.stop_reason == acp::StopReason::MaxTokens {
                            this.had_error = true;
                            cx.emit(AcpThreadEvent::Error);
                            log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
                            return Err(anyhow!("Max tokens reached"));
                        }

                        let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
                        if canceled {
                            this.mark_pending_tools_as_canceled();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let acp::StopReason::Refusal = r.stop_reason {
                            this.had_error = true;
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(Some(r))
                    }
                    Err(e) => {
                        this.had_error = true;
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Error in run turn: {:?}", e);
                        Err(e)
                    }
                }
            })?
        })
        .boxed()
    }
2050
2051 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2052 let Some(turn) = self.running_turn.take() else {
2053 return Task::ready(());
2054 };
2055 self.connection.cancel(&self.session_id, cx);
2056
2057 self.mark_pending_tools_as_canceled();
2058
2059 // Wait for the send task to complete
2060 cx.background_spawn(turn.send_task)
2061 }
2062
2063 fn mark_pending_tools_as_canceled(&mut self) {
2064 for entry in self.entries.iter_mut() {
2065 if let AgentThreadEntry::ToolCall(call) = entry {
2066 let cancel = matches!(
2067 call.status,
2068 ToolCallStatus::Pending
2069 | ToolCallStatus::WaitingForConfirmation { .. }
2070 | ToolCallStatus::InProgress
2071 );
2072
2073 if cancel {
2074 call.status = ToolCallStatus::Canceled;
2075 }
2076 }
2077 }
2078 }
2079
    /// Rewinds the thread to before user message `id` and, if that message has a git
    /// checkpoint, restores the working tree to it.
    ///
    /// Any running turn is cancelled. The returned task awaits that cancellation, then the
    /// rewind, then the git restore.
    ///
    /// NOTE(review): `rewind` spawns its own task before `cancel_task` is awaited here, so
    /// truncation may begin before cancellation finishes. Confirm this ordering is intended.
    ///
    /// # Errors
    /// Fails immediately with "message not found" if no user message has `id`; otherwise
    /// propagates errors from the rewind and the checkpoint restore.
    pub fn restore_checkpoint(
        &mut self,
        id: UserMessageId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some((_, message)) = self.user_message_mut(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        // Cancel any in-progress generation before restoring
        let cancel_task = self.cancel(cx);
        let rewind = self.rewind(id.clone(), cx);
        let git_store = self.project.read(cx).git_store().clone();

        cx.spawn(async move |_, cx| {
            cancel_task.await;
            rewind.await?;
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
                    .await?;
            }

            Ok(())
        })
    }
2112
    /// Rewinds this thread to before the user message `id`.
    ///
    /// The connection is asked to truncate the session first. Once that succeeds:
    /// - the message and all later entries are removed;
    /// - terminals referenced by the removed entries are killed and dropped from
    ///   `terminals`;
    /// - all edits tracked by the action log are rejected.
    ///
    /// Unlike `restore_checkpoint`, this method does not restore from git.
    ///
    /// # Errors
    /// Fails with "not supported" if the connection cannot truncate, and propagates
    /// truncation errors.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx)).await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    // NOTE(review): `.into()` appears to wrap each id in `Some`, which would
                    // make this `filter_map` a plain `map`; confirm against `Terminal::id`.
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
2154
2155 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2156 let git_store = self.project.read(cx).git_store().clone();
2157
2158 let Some((_, message)) = self.last_user_message() else {
2159 return Task::ready(Ok(()));
2160 };
2161 let Some(user_message_id) = message.id.clone() else {
2162 return Task::ready(Ok(()));
2163 };
2164 let Some(checkpoint) = message.checkpoint.as_ref() else {
2165 return Task::ready(Ok(()));
2166 };
2167 let old_checkpoint = checkpoint.git_checkpoint.clone();
2168
2169 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2170 cx.spawn(async move |this, cx| {
2171 let Some(new_checkpoint) = new_checkpoint
2172 .await
2173 .context("failed to get new checkpoint")
2174 .log_err()
2175 else {
2176 return Ok(());
2177 };
2178
2179 let equal = git_store
2180 .update(cx, |git, cx| {
2181 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2182 })
2183 .await
2184 .unwrap_or(true);
2185
2186 this.update(cx, |this, cx| {
2187 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2188 if let Some(checkpoint) = message.checkpoint.as_mut() {
2189 checkpoint.show = !equal;
2190 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2191 }
2192 }
2193 })?;
2194
2195 Ok(())
2196 })
2197 }
2198
2199 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2200 self.entries
2201 .iter_mut()
2202 .enumerate()
2203 .rev()
2204 .find_map(|(ix, entry)| {
2205 if let AgentThreadEntry::UserMessage(message) = entry {
2206 Some((ix, message))
2207 } else {
2208 None
2209 }
2210 })
2211 }
2212
2213 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2214 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2215 if let AgentThreadEntry::UserMessage(message) = entry {
2216 if message.id.as_ref() == Some(id) {
2217 Some((ix, message))
2218 } else {
2219 None
2220 }
2221 } else {
2222 None
2223 }
2224 })
2225 }
2226
2227 pub fn read_text_file(
2228 &self,
2229 path: PathBuf,
2230 line: Option<u32>,
2231 limit: Option<u32>,
2232 reuse_shared_snapshot: bool,
2233 cx: &mut Context<Self>,
2234 ) -> Task<Result<String, acp::Error>> {
2235 // Args are 1-based, move to 0-based
2236 let line = line.unwrap_or_default().saturating_sub(1);
2237 let limit = limit.unwrap_or(u32::MAX);
2238 let project = self.project.clone();
2239 let action_log = self.action_log.clone();
2240 cx.spawn(async move |this, cx| {
2241 let load = project.update(cx, |project, cx| {
2242 let path = project
2243 .project_path_for_absolute_path(&path, cx)
2244 .ok_or_else(|| {
2245 acp::Error::resource_not_found(Some(path.display().to_string()))
2246 })?;
2247 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2248 })?;
2249
2250 let buffer = load.await?;
2251
2252 let snapshot = if reuse_shared_snapshot {
2253 this.read_with(cx, |this, _| {
2254 this.shared_buffers.get(&buffer.clone()).cloned()
2255 })
2256 .log_err()
2257 .flatten()
2258 } else {
2259 None
2260 };
2261
2262 let snapshot = if let Some(snapshot) = snapshot {
2263 snapshot
2264 } else {
2265 action_log.update(cx, |action_log, cx| {
2266 action_log.buffer_read(buffer.clone(), cx);
2267 });
2268
2269 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2270 this.update(cx, |this, _| {
2271 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2272 })?;
2273 snapshot
2274 };
2275
2276 let max_point = snapshot.max_point();
2277 let start_position = Point::new(line, 0);
2278
2279 if start_position > max_point {
2280 return Err(acp::Error::invalid_params().data(format!(
2281 "Attempting to read beyond the end of the file, line {}:{}",
2282 max_point.row + 1,
2283 max_point.column
2284 )));
2285 }
2286
2287 let start = snapshot.anchor_before(start_position);
2288 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2289
2290 project.update(cx, |project, cx| {
2291 project.set_agent_location(
2292 Some(AgentLocation {
2293 buffer: buffer.downgrade(),
2294 position: start,
2295 }),
2296 cx,
2297 );
2298 });
2299
2300 Ok(snapshot.text_for_range(start..end).collect::<String>())
2301 })
2302 }
2303
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent,
    /// then saves the buffer.
    ///
    /// The new content is diffed against the snapshot previously shared with the agent for
    /// this buffer. If no snapshot was shared, the buffer's current snapshot is used. The
    /// resulting edits are expressed as anchor ranges on that snapshot, so text inserted into
    /// the buffer after the snapshot was taken is preserved rather than overwritten.
    ///
    /// Side effects, in order:
    /// 1. The project's agent location moves to the end of the last edit, or to the start of
    ///    the buffer if there are no edits.
    /// 2. The buffer is marked as read in the action log, the edits are applied, and the
    ///    buffer is marked as edited.
    /// 3. If the buffer's language settings have `format_on_save` enabled (anything other
    ///    than `Off`), the buffer is formatted and marked as edited again.
    /// 4. The buffer is saved.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" if `path` does not resolve to a project path. Also fails if
    /// the buffer cannot be opened, if the thread has been dropped, or if saving fails.
    /// Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against the snapshot the agent last read, if any, so the edits describe
            // the agent's intended change relative to what it saw.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor. Each replacement range is turned
            // into an anchor range on `snapshot` so it can be applied to the live buffer.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (snapshot.anchor_range_around(range), replacement)
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent at the end of the last edit, or at the buffer start when the
            // diff is empty.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // Mark the buffer as read, apply the edits, then mark it as edited in the action
            // log. Also decide whether this buffer should be formatted.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: a failure is logged but does not fail the write.
                format_task.await.log_err();

                // Report the (possibly) reformatted buffer to the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2396
2397 pub fn create_terminal(
2398 &self,
2399 command: String,
2400 args: Vec<String>,
2401 extra_env: Vec<acp::EnvVariable>,
2402 cwd: Option<PathBuf>,
2403 output_byte_limit: Option<u64>,
2404 cx: &mut Context<Self>,
2405 ) -> Task<Result<Entity<Terminal>>> {
2406 let env = match &cwd {
2407 Some(dir) => self.project.update(cx, |project, cx| {
2408 project.environment().update(cx, |env, cx| {
2409 env.directory_environment(dir.as_path().into(), cx)
2410 })
2411 }),
2412 None => Task::ready(None).shared(),
2413 };
2414 let env = cx.spawn(async move |_, _| {
2415 let mut env = env.await.unwrap_or_default();
2416 // Disables paging for `git` and hopefully other commands
2417 env.insert("PAGER".into(), "".into());
2418 for var in extra_env {
2419 env.insert(var.name, var.value);
2420 }
2421 env
2422 });
2423
2424 let project = self.project.clone();
2425 let language_registry = project.read(cx).languages().clone();
2426 let is_windows = project.read(cx).path_style(cx).is_windows();
2427
2428 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2429 let terminal_task = cx.spawn({
2430 let terminal_id = terminal_id.clone();
2431 async move |_this, cx| {
2432 let env = env.await;
2433 let shell = project
2434 .update(cx, |project, cx| {
2435 project
2436 .remote_client()
2437 .and_then(|r| r.read(cx).default_system_shell())
2438 })
2439 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2440 let (task_command, task_args) =
2441 ShellBuilder::new(&Shell::Program(shell), is_windows)
2442 .redirect_stdin_to_dev_null()
2443 .build(Some(command.clone()), &args);
2444 let terminal = project
2445 .update(cx, |project, cx| {
2446 project.create_terminal_task(
2447 task::SpawnInTerminal {
2448 command: Some(task_command),
2449 args: task_args,
2450 cwd: cwd.clone(),
2451 env,
2452 ..Default::default()
2453 },
2454 cx,
2455 )
2456 })
2457 .await?;
2458
2459 anyhow::Ok(cx.new(|cx| {
2460 Terminal::new(
2461 terminal_id,
2462 &format!("{} {}", command, args.join(" ")),
2463 cwd,
2464 output_byte_limit.map(|l| l as usize),
2465 terminal,
2466 language_registry,
2467 cx,
2468 )
2469 }))
2470 }
2471 });
2472
2473 cx.spawn(async move |this, cx| {
2474 let terminal = terminal_task.await?;
2475 this.update(cx, |this, _cx| {
2476 this.terminals.insert(terminal_id, terminal.clone());
2477 terminal
2478 })
2479 })
2480 }
2481
2482 pub fn kill_terminal(
2483 &mut self,
2484 terminal_id: acp::TerminalId,
2485 cx: &mut Context<Self>,
2486 ) -> Result<()> {
2487 self.terminals
2488 .get(&terminal_id)
2489 .context("Terminal not found")?
2490 .update(cx, |terminal, cx| {
2491 terminal.kill(cx);
2492 });
2493
2494 Ok(())
2495 }
2496
2497 pub fn release_terminal(
2498 &mut self,
2499 terminal_id: acp::TerminalId,
2500 cx: &mut Context<Self>,
2501 ) -> Result<()> {
2502 self.terminals
2503 .remove(&terminal_id)
2504 .context("Terminal not found")?
2505 .update(cx, |terminal, cx| {
2506 terminal.kill(cx);
2507 });
2508
2509 Ok(())
2510 }
2511
2512 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2513 self.terminals
2514 .get(&terminal_id)
2515 .context("Terminal not found")
2516 .cloned()
2517 }
2518
2519 pub fn to_markdown(&self, cx: &App) -> String {
2520 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2521 }
2522
2523 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2524 cx.emit(AcpThreadEvent::LoadError(error));
2525 }
2526
2527 pub fn register_terminal_created(
2528 &mut self,
2529 terminal_id: acp::TerminalId,
2530 command_label: String,
2531 working_dir: Option<PathBuf>,
2532 output_byte_limit: Option<u64>,
2533 terminal: Entity<::terminal::Terminal>,
2534 cx: &mut Context<Self>,
2535 ) -> Entity<Terminal> {
2536 let language_registry = self.project.read(cx).languages().clone();
2537
2538 let entity = cx.new(|cx| {
2539 Terminal::new(
2540 terminal_id.clone(),
2541 &command_label,
2542 working_dir.clone(),
2543 output_byte_limit.map(|l| l as usize),
2544 terminal,
2545 language_registry,
2546 cx,
2547 )
2548 });
2549 self.terminals.insert(terminal_id.clone(), entity.clone());
2550 entity
2551 }
2552}
2553
2554fn markdown_for_raw_output(
2555 raw_output: &serde_json::Value,
2556 language_registry: &Arc<LanguageRegistry>,
2557 cx: &mut App,
2558) -> Option<Entity<Markdown>> {
2559 match raw_output {
2560 serde_json::Value::Null => None,
2561 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2562 Markdown::new(
2563 value.to_string().into(),
2564 Some(language_registry.clone()),
2565 None,
2566 cx,
2567 )
2568 })),
2569 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2570 Markdown::new(
2571 value.to_string().into(),
2572 Some(language_registry.clone()),
2573 None,
2574 cx,
2575 )
2576 })),
2577 serde_json::Value::String(value) => Some(cx.new(|cx| {
2578 Markdown::new(
2579 value.clone().into(),
2580 Some(language_registry.clone()),
2581 None,
2582 cx,
2583 )
2584 })),
2585 value => Some(cx.new(|cx| {
2586 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2587
2588 Markdown::new(
2589 format!("```json\n{}\n```", pretty_json).into(),
2590 Some(language_registry.clone()),
2591 None,
2592 cx,
2593 )
2594 })),
2595 }
2596}
2597
2598#[cfg(test)]
2599mod tests {
2600 use super::*;
2601 use anyhow::anyhow;
2602 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2603 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2604 use indoc::indoc;
2605 use project::{FakeFs, Fs};
2606 use rand::{distr, prelude::*};
2607 use serde_json::json;
2608 use settings::SettingsStore;
2609 use smol::stream::StreamExt as _;
2610 use std::{
2611 any::Any,
2612 cell::RefCell,
2613 path::Path,
2614 rc::Rc,
2615 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2616 time::Duration,
2617 };
2618 use util::path;
2619
2620 fn init_test(cx: &mut TestAppContext) {
2621 env_logger::try_init().ok();
2622 cx.update(|cx| {
2623 let settings_store = SettingsStore::test(cx);
2624 cx.set_global(settings_store);
2625 });
2626 }
2627
    /// Terminal output that arrives before the terminal's `Created` event is buffered by the
    /// thread. It must appear in the terminal's content once `Created` registers the terminal.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2691
    /// Both `Output` and `Exit` events may arrive before the terminal's `Created` event. The
    /// buffered output must still appear in the terminal's content after `Created`.
    ///
    /// NOTE(review): only the buffered output is asserted. Whether the buffered exit status
    /// was applied to the terminal is not checked here.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2766
    /// Test that killing a terminal via Terminal::kill properly:
    /// 1. Causes wait_for_exit to complete (doesn't hang forever)
    /// 2. The underlying terminal still has the output that was written before the kill
    ///
    /// This test verifies that the fix to kill_active_task (which now also kills
    /// the shell process in addition to the foreground process) properly allows
    /// wait_for_exit to complete instead of hanging indefinitely.
    ///
    /// Uses a real PTY and real timers (`allow_parking`), so it is Unix-only.
    #[cfg(unix)]
    #[gpui::test]
    async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
        use std::collections::HashMap;
        use task::Shell;
        use util::shell_builder::ShellBuilder;

        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Create a real PTY terminal that runs a command which prints output then sleeps
        // We use printf instead of echo and chain with && sleep to ensure proper execution
        // The receiver is bound to a named `_completion_rx` (not `_`) so it is not dropped
        // immediately.
        let (completion_tx, _completion_rx) = smol::channel::unbounded();
        let (program, args) = ShellBuilder::new(&Shell::System, false).build(
            Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
            &[],
        );

        let builder = cx
            .update(|cx| {
                ::terminal::TerminalBuilder::new(
                    None,
                    None,
                    task::Shell::WithArguments {
                        program,
                        args,
                        title_override: None,
                    },
                    HashMap::default(),
                    ::terminal::terminal_settings::CursorShape::default(),
                    ::terminal::terminal_settings::AlternateScroll::On,
                    None,
                    vec![],
                    0,
                    false,
                    0,
                    Some(completion_tx),
                    cx,
                    vec![],
                    PathStyle::local(),
                )
            })
            .await
            .unwrap();

        let lower_terminal = cx.new(|cx| builder.subscribe(cx));

        // Create the acp_thread Terminal wrapper
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "printf output_before_kill && sleep 60".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower_terminal.clone(),
                },
                cx,
            );
        });

        // Wait for the printf command to execute and produce output
        // Use real time since parking is enabled
        cx.executor().timer(Duration::from_millis(500)).await;

        // Get the acp_thread Terminal and kill it
        // The exit future is obtained before killing so that the exit cannot be missed.
        let wait_for_exit = thread.update(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            let wait_for_exit = term.read(cx).wait_for_exit();
            term.update(cx, |term, cx| {
                term.kill(cx);
            });
            wait_for_exit
        });

        // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
        // Before the fix to kill_active_task, this would hang forever because
        // only the foreground process was killed, not the shell, so the PTY
        // child never exited and wait_for_completed_task never completed.
        let exit_result = futures::select! {
            result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
            _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
        };

        assert!(
            exit_result.is_some(),
            "wait_for_exit should complete after kill, but it timed out. \
            This indicates kill_active_task is not properly killing the shell process."
        );

        // Give the system a chance to process any pending updates
        cx.run_until_parked();

        // Verify that the underlying terminal still has the output that was
        // written before the kill. This verifies that killing doesn't lose output.
        let inner_content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            term.read(cx).inner().read(cx).get_content()
        });

        assert!(
            inner_content.contains("output_before_kill"),
            "Underlying terminal should contain output from before kill, got: {}",
            inner_content
        );
    }
2890
    /// Checks how `push_user_content_block` groups content into user messages:
    /// - With no trailing user message, it starts a new one.
    /// - It appends to the trailing user message and adopts the supplied id.
    /// - After an assistant message, it starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2958
    /// Consecutive `AgentThoughtChunk` session updates are merged into a single `<thinking>`
    /// block in the thread's markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Fake agent: replies to any prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
3019
    /// The agent reads a file, the user then inserts a line at the top of the same buffer,
    /// and the agent writes back an extended version of the text it read.
    ///
    /// `write_text_file` diffs against the snapshot recorded by `read_text_file`, so both the
    /// user's insertion and the agent's additions must survive. The result must also be
    /// saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent has finished reading, so the user edit can be
        // made between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // User edit made after the agent's read.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
3096
    /// `read_text_file` honors the 1-based `line` and the `limit` arguments. Starting past the
    /// end of the file is rejected with an `invalid_params` error.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        // The file's max point is row 4 (reported 1-based as line 5), column 0, so line 6 is
        // past the end.
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
3172
    /// Reading an empty file returns `""` for every start line/limit that does not start past
    /// the end. Starting past the end is rejected with an `invalid_params` error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        // An empty file's max point is row 0 (reported as line 1), column 0.
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    /// Reading a path that does not belong to any worktree fails with `ResourceNotFound`.
    ///
    /// NOTE(review): despite the name, this exercises a path outside the project (`/foo`
    /// while the worktree is `/tmp`). It does not test a missing file inside the worktree.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3278
    /// An in-progress tool call becomes `Canceled` when the thread is canceled. A later
    /// `Completed` update from the agent for the same tool call id still moves it to
    /// `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // Fake agent: reports one in-progress fetch tool call, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion update for the canceled tool call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3368
    /// An edit tool call that is already reported as `Completed` (carrying a diff) must not
    /// leave the thread with pending edit tool calls once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // Fake agent: reports a single completed edit tool call with a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3413
3414 #[gpui::test(iterations = 10)]
3415 async fn test_checkpoints(cx: &mut TestAppContext) {
3416 init_test(cx);
3417 let fs = FakeFs::new(cx.background_executor.clone());
3418 fs.insert_tree(
3419 path!("/test"),
3420 json!({
3421 ".git": {}
3422 }),
3423 )
3424 .await;
3425 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3426
3427 let simulate_changes = Arc::new(AtomicBool::new(true));
3428 let next_filename = Arc::new(AtomicUsize::new(0));
3429 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3430 let simulate_changes = simulate_changes.clone();
3431 let next_filename = next_filename.clone();
3432 let fs = fs.clone();
3433 move |request, thread, mut cx| {
3434 let fs = fs.clone();
3435 let simulate_changes = simulate_changes.clone();
3436 let next_filename = next_filename.clone();
3437 async move {
3438 if simulate_changes.load(SeqCst) {
3439 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3440 fs.write(Path::new(&filename), b"").await?;
3441 }
3442
3443 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3444 panic!("expected text content block");
3445 };
3446 thread.update(&mut cx, |thread, cx| {
3447 thread
3448 .handle_session_update(
3449 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3450 content.text.to_uppercase().into(),
3451 )),
3452 cx,
3453 )
3454 .unwrap();
3455 })?;
3456 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3457 }
3458 .boxed_local()
3459 }
3460 }));
3461 let thread = cx
3462 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3463 .await
3464 .unwrap();
3465
3466 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3467 .await
3468 .unwrap();
3469 thread.read_with(cx, |thread, cx| {
3470 assert_eq!(
3471 thread.to_markdown(cx),
3472 indoc! {"
3473 ## User (checkpoint)
3474
3475 Lorem
3476
3477 ## Assistant
3478
3479 LOREM
3480
3481 "}
3482 );
3483 });
3484 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3485
3486 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3487 .await
3488 .unwrap();
3489 thread.read_with(cx, |thread, cx| {
3490 assert_eq!(
3491 thread.to_markdown(cx),
3492 indoc! {"
3493 ## User (checkpoint)
3494
3495 Lorem
3496
3497 ## Assistant
3498
3499 LOREM
3500
3501 ## User (checkpoint)
3502
3503 ipsum
3504
3505 ## Assistant
3506
3507 IPSUM
3508
3509 "}
3510 );
3511 });
3512 assert_eq!(
3513 fs.files(),
3514 vec![
3515 Path::new(path!("/test/file-0")),
3516 Path::new(path!("/test/file-1"))
3517 ]
3518 );
3519
3520 // Checkpoint isn't stored when there are no changes.
3521 simulate_changes.store(false, SeqCst);
3522 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3523 .await
3524 .unwrap();
3525 thread.read_with(cx, |thread, cx| {
3526 assert_eq!(
3527 thread.to_markdown(cx),
3528 indoc! {"
3529 ## User (checkpoint)
3530
3531 Lorem
3532
3533 ## Assistant
3534
3535 LOREM
3536
3537 ## User (checkpoint)
3538
3539 ipsum
3540
3541 ## Assistant
3542
3543 IPSUM
3544
3545 ## User
3546
3547 dolor
3548
3549 ## Assistant
3550
3551 DOLOR
3552
3553 "}
3554 );
3555 });
3556 assert_eq!(
3557 fs.files(),
3558 vec![
3559 Path::new(path!("/test/file-0")),
3560 Path::new(path!("/test/file-1"))
3561 ]
3562 );
3563
3564 // Rewinding the conversation truncates the history and restores the checkpoint.
3565 thread
3566 .update(cx, |thread, cx| {
3567 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3568 panic!("unexpected entries {:?}", thread.entries)
3569 };
3570 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3571 })
3572 .await
3573 .unwrap();
3574 thread.read_with(cx, |thread, cx| {
3575 assert_eq!(
3576 thread.to_markdown(cx),
3577 indoc! {"
3578 ## User (checkpoint)
3579
3580 Lorem
3581
3582 ## Assistant
3583
3584 LOREM
3585
3586 "}
3587 );
3588 });
3589 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3590 }
3591
3592 #[gpui::test]
3593 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3594 use std::sync::atomic::AtomicUsize;
3595 init_test(cx);
3596
3597 let fs = FakeFs::new(cx.executor());
3598 let project = Project::test(fs, None, cx).await;
3599
3600 // Create a connection that simulates refusal after tool result
3601 let prompt_count = Arc::new(AtomicUsize::new(0));
3602 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3603 let prompt_count = prompt_count.clone();
3604 move |_request, thread, mut cx| {
3605 let count = prompt_count.fetch_add(1, SeqCst);
3606 async move {
3607 if count == 0 {
3608 // First prompt: Generate a tool call with result
3609 thread.update(&mut cx, |thread, cx| {
3610 thread
3611 .handle_session_update(
3612 acp::SessionUpdate::ToolCall(
3613 acp::ToolCall::new("tool1", "Test Tool")
3614 .kind(acp::ToolKind::Fetch)
3615 .status(acp::ToolCallStatus::Completed)
3616 .raw_input(serde_json::json!({"query": "test"}))
3617 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3618 ),
3619 cx,
3620 )
3621 .unwrap();
3622 })?;
3623
3624 // Now return refusal because of the tool result
3625 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3626 } else {
3627 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3628 }
3629 }
3630 .boxed_local()
3631 }
3632 }));
3633
3634 let thread = cx
3635 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3636 .await
3637 .unwrap();
3638
3639 // Track if we see a Refusal event
3640 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3641 let saw_refusal_event_captured = saw_refusal_event.clone();
3642 thread.update(cx, |_thread, cx| {
3643 cx.subscribe(
3644 &thread,
3645 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3646 if matches!(event, AcpThreadEvent::Refusal) {
3647 *saw_refusal_event_captured.lock().unwrap() = true;
3648 }
3649 },
3650 )
3651 .detach();
3652 });
3653
3654 // Send a user message - this will trigger tool call and then refusal
3655 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3656 cx.background_executor.spawn(send_task).detach();
3657 cx.run_until_parked();
3658
3659 // Verify that:
3660 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3661 // 2. The user message was NOT truncated
3662 assert!(
3663 *saw_refusal_event.lock().unwrap(),
3664 "Refusal event should be emitted for tool result refusals"
3665 );
3666
3667 thread.read_with(cx, |thread, _| {
3668 let entries = thread.entries();
3669 assert!(entries.len() >= 2, "Should have user message and tool call");
3670
3671 // Verify user message is still there
3672 assert!(
3673 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3674 "User message should not be truncated"
3675 );
3676
3677 // Verify tool call is there with result
3678 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3679 assert!(
3680 tool_call.raw_output.is_some(),
3681 "Tool call should have output"
3682 );
3683 } else {
3684 panic!("Expected tool call at index 1");
3685 }
3686 });
3687 }
3688
3689 #[gpui::test]
3690 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3691 init_test(cx);
3692
3693 let fs = FakeFs::new(cx.executor());
3694 let project = Project::test(fs, None, cx).await;
3695
3696 let refuse_next = Arc::new(AtomicBool::new(false));
3697 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3698 let refuse_next = refuse_next.clone();
3699 move |_request, _thread, _cx| {
3700 if refuse_next.load(SeqCst) {
3701 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3702 .boxed_local()
3703 } else {
3704 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3705 .boxed_local()
3706 }
3707 }
3708 }));
3709
3710 let thread = cx
3711 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3712 .await
3713 .unwrap();
3714
3715 // Track if we see a Refusal event
3716 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3717 let saw_refusal_event_captured = saw_refusal_event.clone();
3718 thread.update(cx, |_thread, cx| {
3719 cx.subscribe(
3720 &thread,
3721 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3722 if matches!(event, AcpThreadEvent::Refusal) {
3723 *saw_refusal_event_captured.lock().unwrap() = true;
3724 }
3725 },
3726 )
3727 .detach();
3728 });
3729
3730 // Send a message that will be refused
3731 refuse_next.store(true, SeqCst);
3732 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3733 .await
3734 .unwrap();
3735
3736 // Verify that a Refusal event WAS emitted for user prompt refusal
3737 assert!(
3738 *saw_refusal_event.lock().unwrap(),
3739 "Refusal event should be emitted for user prompt refusals"
3740 );
3741
3742 // Verify the message was truncated (user prompt refusal)
3743 thread.read_with(cx, |thread, cx| {
3744 assert_eq!(thread.to_markdown(cx), "");
3745 });
3746 }
3747
3748 #[gpui::test]
3749 async fn test_refusal(cx: &mut TestAppContext) {
3750 init_test(cx);
3751 let fs = FakeFs::new(cx.background_executor.clone());
3752 fs.insert_tree(path!("/"), json!({})).await;
3753 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3754
3755 let refuse_next = Arc::new(AtomicBool::new(false));
3756 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3757 let refuse_next = refuse_next.clone();
3758 move |request, thread, mut cx| {
3759 let refuse_next = refuse_next.clone();
3760 async move {
3761 if refuse_next.load(SeqCst) {
3762 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3763 }
3764
3765 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3766 panic!("expected text content block");
3767 };
3768 thread.update(&mut cx, |thread, cx| {
3769 thread
3770 .handle_session_update(
3771 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3772 content.text.to_uppercase().into(),
3773 )),
3774 cx,
3775 )
3776 .unwrap();
3777 })?;
3778 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3779 }
3780 .boxed_local()
3781 }
3782 }));
3783 let thread = cx
3784 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3785 .await
3786 .unwrap();
3787
3788 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3789 .await
3790 .unwrap();
3791 thread.read_with(cx, |thread, cx| {
3792 assert_eq!(
3793 thread.to_markdown(cx),
3794 indoc! {"
3795 ## User
3796
3797 hello
3798
3799 ## Assistant
3800
3801 HELLO
3802
3803 "}
3804 );
3805 });
3806
3807 // Simulate refusing the second message. The message should be truncated
3808 // when a user prompt is refused.
3809 refuse_next.store(true, SeqCst);
3810 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3811 .await
3812 .unwrap();
3813 thread.read_with(cx, |thread, cx| {
3814 assert_eq!(
3815 thread.to_markdown(cx),
3816 indoc! {"
3817 ## User
3818
3819 hello
3820
3821 ## Assistant
3822
3823 HELLO
3824
3825 "}
3826 );
3827 });
3828 }
3829
3830 async fn run_until_first_tool_call(
3831 thread: &Entity<AcpThread>,
3832 cx: &mut TestAppContext,
3833 ) -> usize {
3834 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3835
3836 let subscription = cx.update(|cx| {
3837 cx.subscribe(thread, move |thread, _, cx| {
3838 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3839 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3840 return tx.try_send(ix).unwrap();
3841 }
3842 }
3843 })
3844 });
3845
3846 select! {
3847 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3848 panic!("Timeout waiting for tool call")
3849 }
3850 ix = rx.next().fuse() => {
3851 drop(subscription);
3852 ix.unwrap()
3853 }
3854 }
3855 }
3856
    /// In-memory [`AgentConnection`] used by these tests.
    ///
    /// Threads created through `new_session` are tracked by session id so that
    /// `prompt` can hand each request back to its thread. An optional
    /// `on_user_message` handler scripts the agent's response.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Methods reported by `auth_methods()`; `authenticate` accepts only these ids.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_session`, keyed by session id and held as weak handles.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Invoked for every prompt; when `None`, `prompt` immediately ends the turn.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
3872
3873 impl FakeAgentConnection {
3874 fn new() -> Self {
3875 Self {
3876 auth_methods: Vec::new(),
3877 on_user_message: None,
3878 sessions: Arc::default(),
3879 }
3880 }
3881
3882 #[expect(unused)]
3883 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3884 self.auth_methods = auth_methods;
3885 self
3886 }
3887
3888 fn on_user_message(
3889 mut self,
3890 handler: impl Fn(
3891 acp::PromptRequest,
3892 WeakEntity<AcpThread>,
3893 AsyncApp,
3894 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3895 + 'static,
3896 ) -> Self {
3897 self.on_user_message.replace(Rc::new(handler));
3898 self
3899 }
3900 }
3901
3902 impl AgentConnection for FakeAgentConnection {
3903 fn telemetry_id(&self) -> SharedString {
3904 "fake".into()
3905 }
3906
3907 fn auth_methods(&self) -> &[acp::AuthMethod] {
3908 &self.auth_methods
3909 }
3910
3911 fn new_session(
3912 self: Rc<Self>,
3913 project: Entity<Project>,
3914 _cwd: &Path,
3915 cx: &mut App,
3916 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3917 let session_id = acp::SessionId::new(
3918 rand::rng()
3919 .sample_iter(&distr::Alphanumeric)
3920 .take(7)
3921 .map(char::from)
3922 .collect::<String>(),
3923 );
3924 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3925 let thread = cx.new(|cx| {
3926 AcpThread::new(
3927 None,
3928 "Test",
3929 self.clone(),
3930 project,
3931 action_log,
3932 session_id.clone(),
3933 watch::Receiver::constant(
3934 acp::PromptCapabilities::new()
3935 .image(true)
3936 .audio(true)
3937 .embedded_context(true),
3938 ),
3939 cx,
3940 )
3941 });
3942 self.sessions.lock().insert(session_id, thread.downgrade());
3943 Task::ready(Ok(thread))
3944 }
3945
3946 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3947 if self.auth_methods().iter().any(|m| m.id == method) {
3948 Task::ready(Ok(()))
3949 } else {
3950 Task::ready(Err(anyhow!("Invalid Auth Method")))
3951 }
3952 }
3953
3954 fn prompt(
3955 &self,
3956 _id: Option<UserMessageId>,
3957 params: acp::PromptRequest,
3958 cx: &mut App,
3959 ) -> Task<gpui::Result<acp::PromptResponse>> {
3960 let sessions = self.sessions.lock();
3961 let thread = sessions.get(¶ms.session_id).unwrap();
3962 if let Some(handler) = &self.on_user_message {
3963 let handler = handler.clone();
3964 let thread = thread.clone();
3965 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3966 } else {
3967 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3968 }
3969 }
3970
3971 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
3972
3973 fn truncate(
3974 &self,
3975 session_id: &acp::SessionId,
3976 _cx: &App,
3977 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3978 Some(Rc::new(FakeAgentSessionEditor {
3979 _session_id: session_id.clone(),
3980 }))
3981 }
3982
3983 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3984 self
3985 }
3986 }
3987
    /// Truncation handle handed out by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        /// Session this handle was created for; stored but never read.
        _session_id: acp::SessionId,
    }
3991
    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// No-op truncation: returns an already-completed successful task.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3997
3998 #[gpui::test]
3999 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
4000 init_test(cx);
4001
4002 let fs = FakeFs::new(cx.executor());
4003 let project = Project::test(fs, [], cx).await;
4004 let connection = Rc::new(FakeAgentConnection::new());
4005 let thread = cx
4006 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4007 .await
4008 .unwrap();
4009
4010 // Try to update a tool call that doesn't exist
4011 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
4012 thread.update(cx, |thread, cx| {
4013 let result = thread.handle_session_update(
4014 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4015 nonexistent_id.clone(),
4016 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4017 )),
4018 cx,
4019 );
4020
4021 // The update should succeed (not return an error)
4022 assert!(result.is_ok());
4023
4024 // There should now be exactly one entry in the thread
4025 assert_eq!(thread.entries.len(), 1);
4026
4027 // The entry should be a failed tool call
4028 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4029 assert_eq!(tool_call.id, nonexistent_id);
4030 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4031 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4032
4033 // Check that the content contains the error message
4034 assert_eq!(tool_call.content.len(), 1);
4035 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4036 match content_block {
4037 ContentBlock::Markdown { markdown } => {
4038 let markdown_text = markdown.read(cx).source();
4039 assert!(markdown_text.contains("Tool call not found"));
4040 }
4041 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4042 ContentBlock::ResourceLink { .. } => {
4043 panic!("Expected markdown content, got resource link")
4044 }
4045 ContentBlock::Image { .. } => {
4046 panic!("Expected markdown content, got image")
4047 }
4048 }
4049 } else {
4050 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4051 }
4052 } else {
4053 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4054 }
4055 });
4056 }
4057
4058 /// Tests that restoring a checkpoint properly cleans up terminals that were
4059 /// created after that checkpoint, and cancels any in-progress generation.
4060 ///
4061 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
4062 /// that were started after that checkpoint should be terminated, and any in-progress
4063 /// AI generation should be canceled.
4064 #[gpui::test]
4065 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
4066 init_test(cx);
4067
4068 let fs = FakeFs::new(cx.executor());
4069 let project = Project::test(fs, [], cx).await;
4070 let connection = Rc::new(FakeAgentConnection::new());
4071 let thread = cx
4072 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4073 .await
4074 .unwrap();
4075
4076 // Send first user message to create a checkpoint
4077 cx.update(|cx| {
4078 thread.update(cx, |thread, cx| {
4079 thread.send(vec!["first message".into()], cx)
4080 })
4081 })
4082 .await
4083 .unwrap();
4084
4085 // Send second message (creates another checkpoint) - we'll restore to this one
4086 cx.update(|cx| {
4087 thread.update(cx, |thread, cx| {
4088 thread.send(vec!["second message".into()], cx)
4089 })
4090 })
4091 .await
4092 .unwrap();
4093
4094 // Create 2 terminals BEFORE the checkpoint that have completed running
4095 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4096 let mock_terminal_1 = cx.new(|cx| {
4097 let builder = ::terminal::TerminalBuilder::new_display_only(
4098 ::terminal::terminal_settings::CursorShape::default(),
4099 ::terminal::terminal_settings::AlternateScroll::On,
4100 None,
4101 0,
4102 cx.background_executor(),
4103 PathStyle::local(),
4104 )
4105 .unwrap();
4106 builder.subscribe(cx)
4107 });
4108
4109 thread.update(cx, |thread, cx| {
4110 thread.on_terminal_provider_event(
4111 TerminalProviderEvent::Created {
4112 terminal_id: terminal_id_1.clone(),
4113 label: "echo 'first'".to_string(),
4114 cwd: Some(PathBuf::from("/test")),
4115 output_byte_limit: None,
4116 terminal: mock_terminal_1.clone(),
4117 },
4118 cx,
4119 );
4120 });
4121
4122 thread.update(cx, |thread, cx| {
4123 thread.on_terminal_provider_event(
4124 TerminalProviderEvent::Output {
4125 terminal_id: terminal_id_1.clone(),
4126 data: b"first\n".to_vec(),
4127 },
4128 cx,
4129 );
4130 });
4131
4132 thread.update(cx, |thread, cx| {
4133 thread.on_terminal_provider_event(
4134 TerminalProviderEvent::Exit {
4135 terminal_id: terminal_id_1.clone(),
4136 status: acp::TerminalExitStatus::new().exit_code(0),
4137 },
4138 cx,
4139 );
4140 });
4141
4142 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4143 let mock_terminal_2 = cx.new(|cx| {
4144 let builder = ::terminal::TerminalBuilder::new_display_only(
4145 ::terminal::terminal_settings::CursorShape::default(),
4146 ::terminal::terminal_settings::AlternateScroll::On,
4147 None,
4148 0,
4149 cx.background_executor(),
4150 PathStyle::local(),
4151 )
4152 .unwrap();
4153 builder.subscribe(cx)
4154 });
4155
4156 thread.update(cx, |thread, cx| {
4157 thread.on_terminal_provider_event(
4158 TerminalProviderEvent::Created {
4159 terminal_id: terminal_id_2.clone(),
4160 label: "echo 'second'".to_string(),
4161 cwd: Some(PathBuf::from("/test")),
4162 output_byte_limit: None,
4163 terminal: mock_terminal_2.clone(),
4164 },
4165 cx,
4166 );
4167 });
4168
4169 thread.update(cx, |thread, cx| {
4170 thread.on_terminal_provider_event(
4171 TerminalProviderEvent::Output {
4172 terminal_id: terminal_id_2.clone(),
4173 data: b"second\n".to_vec(),
4174 },
4175 cx,
4176 );
4177 });
4178
4179 thread.update(cx, |thread, cx| {
4180 thread.on_terminal_provider_event(
4181 TerminalProviderEvent::Exit {
4182 terminal_id: terminal_id_2.clone(),
4183 status: acp::TerminalExitStatus::new().exit_code(0),
4184 },
4185 cx,
4186 );
4187 });
4188
4189 // Get the second message ID to restore to
4190 let second_message_id = thread.read_with(cx, |thread, _| {
4191 // At this point we have:
4192 // - Index 0: First user message (with checkpoint)
4193 // - Index 1: Second user message (with checkpoint)
4194 // No assistant responses because FakeAgentConnection just returns EndTurn
4195 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
4196 panic!("expected user message at index 1");
4197 };
4198 message.id.clone().unwrap()
4199 });
4200
4201 // Create a terminal AFTER the checkpoint we'll restore to.
4202 // This simulates the AI agent starting a long-running terminal command.
4203 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4204 let mock_terminal = cx.new(|cx| {
4205 let builder = ::terminal::TerminalBuilder::new_display_only(
4206 ::terminal::terminal_settings::CursorShape::default(),
4207 ::terminal::terminal_settings::AlternateScroll::On,
4208 None,
4209 0,
4210 cx.background_executor(),
4211 PathStyle::local(),
4212 )
4213 .unwrap();
4214 builder.subscribe(cx)
4215 });
4216
4217 // Register the terminal as created
4218 thread.update(cx, |thread, cx| {
4219 thread.on_terminal_provider_event(
4220 TerminalProviderEvent::Created {
4221 terminal_id: terminal_id.clone(),
4222 label: "sleep 1000".to_string(),
4223 cwd: Some(PathBuf::from("/test")),
4224 output_byte_limit: None,
4225 terminal: mock_terminal.clone(),
4226 },
4227 cx,
4228 );
4229 });
4230
4231 // Simulate the terminal producing output (still running)
4232 thread.update(cx, |thread, cx| {
4233 thread.on_terminal_provider_event(
4234 TerminalProviderEvent::Output {
4235 terminal_id: terminal_id.clone(),
4236 data: b"terminal is running...\n".to_vec(),
4237 },
4238 cx,
4239 );
4240 });
4241
4242 // Create a tool call entry that references this terminal
4243 // This represents the agent requesting a terminal command
4244 thread.update(cx, |thread, cx| {
4245 thread
4246 .handle_session_update(
4247 acp::SessionUpdate::ToolCall(
4248 acp::ToolCall::new("terminal-tool-1", "Running command")
4249 .kind(acp::ToolKind::Execute)
4250 .status(acp::ToolCallStatus::InProgress)
4251 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
4252 terminal_id.clone(),
4253 ))])
4254 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
4255 ),
4256 cx,
4257 )
4258 .unwrap();
4259 });
4260
4261 // Verify terminal exists and is in the thread
4262 let terminal_exists_before =
4263 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4264 assert!(
4265 terminal_exists_before,
4266 "Terminal should exist before checkpoint restore"
4267 );
4268
4269 // Verify the terminal's underlying task is still running (not completed)
4270 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
4271 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
4272 terminal_entity.read_with(cx, |term, _cx| {
4273 term.output().is_none() // output is None means it's still running
4274 })
4275 });
4276 assert!(
4277 terminal_running_before,
4278 "Terminal should be running before checkpoint restore"
4279 );
4280
4281 // Verify we have the expected entries before restore
4282 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4283 assert!(
4284 entry_count_before > 1,
4285 "Should have multiple entries before restore"
4286 );
4287
4288 // Restore the checkpoint to the second message.
4289 // This should:
4290 // 1. Cancel any in-progress generation (via the cancel() call)
4291 // 2. Remove the terminal that was created after that point
4292 thread
4293 .update(cx, |thread, cx| {
4294 thread.restore_checkpoint(second_message_id, cx)
4295 })
4296 .await
4297 .unwrap();
4298
4299 // Verify that no send_task is in progress after restore
4300 // (cancel() clears the send_task)
4301 let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
4302 assert!(
4303 !has_send_task_after,
4304 "Should not have a send_task after restore (cancel should have cleared it)"
4305 );
4306
4307 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4308 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4309 assert_eq!(
4310 entry_count, 1,
4311 "Should have 1 entry after restore (only the first user message)"
4312 );
4313
4314 // Verify the 2 completed terminals from before the checkpoint still exist
4315 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4316 thread.terminals.contains_key(&terminal_id_1)
4317 });
4318 assert!(
4319 terminal_1_exists,
4320 "Terminal 1 (from before checkpoint) should still exist"
4321 );
4322
4323 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4324 thread.terminals.contains_key(&terminal_id_2)
4325 });
4326 assert!(
4327 terminal_2_exists,
4328 "Terminal 2 (from before checkpoint) should still exist"
4329 );
4330
4331 // Verify they're still in completed state
4332 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4333 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4334 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4335 });
4336 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4337
4338 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4339 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4340 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4341 });
4342 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4343
4344 // Verify the running terminal (created after checkpoint) was removed
4345 let terminal_3_exists =
4346 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4347 assert!(
4348 !terminal_3_exists,
4349 "Terminal 3 (created after checkpoint) should have been removed"
4350 );
4351
4352 // Verify total count is 2 (the two from before the checkpoint)
4353 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4354 assert_eq!(
4355 terminal_count, 2,
4356 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4357 );
4358 }
4359
4360 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4361 /// even when a new user message is added while the async checkpoint comparison is in progress.
4362 ///
4363 /// This is a regression test for a bug where update_last_checkpoint would fail with
4364 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4365 /// update_last_checkpoint started and when its async closure ran.
4366 #[gpui::test]
4367 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4368 init_test(cx);
4369
4370 let fs = FakeFs::new(cx.executor());
4371 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4372 .await;
4373 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4374
4375 let handler_done = Arc::new(AtomicBool::new(false));
4376 let handler_done_clone = handler_done.clone();
4377 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4378 move |_, _thread, _cx| {
4379 handler_done_clone.store(true, SeqCst);
4380 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4381 },
4382 ));
4383
4384 let thread = cx
4385 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4386 .await
4387 .unwrap();
4388
4389 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4390 let send_task = cx.background_executor.spawn(send_future);
4391
4392 // Tick until handler completes, then a few more to let update_last_checkpoint start
4393 while !handler_done.load(SeqCst) {
4394 cx.executor().tick();
4395 }
4396 for _ in 0..5 {
4397 cx.executor().tick();
4398 }
4399
4400 thread.update(cx, |thread, cx| {
4401 thread.push_entry(
4402 AgentThreadEntry::UserMessage(UserMessage {
4403 id: Some(UserMessageId::new()),
4404 content: ContentBlock::Empty,
4405 chunks: vec!["Injected message (no checkpoint)".into()],
4406 checkpoint: None,
4407 indented: false,
4408 }),
4409 cx,
4410 );
4411 });
4412
4413 cx.run_until_parked();
4414 let result = send_task.await;
4415
4416 assert!(
4417 result.is_ok(),
4418 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4419 result.err()
4420 );
4421 }
4422
4423 /// Tests that when a follow-up message is sent during generation,
4424 /// the first turn completing does NOT clear `running_turn` because
4425 /// it now belongs to the second turn.
4426 #[gpui::test]
4427 async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
4428 init_test(cx);
4429
4430 let fs = FakeFs::new(cx.executor());
4431 let project = Project::test(fs, [], cx).await;
4432
4433 // First handler waits for this signal before completing
4434 let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
4435 let first_complete_rx = RefCell::new(Some(first_complete_rx));
4436
4437 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
4438 move |params, _thread, _cx| {
4439 let first_complete_rx = first_complete_rx.borrow_mut().take();
4440 let is_first = params
4441 .prompt
4442 .iter()
4443 .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
4444
4445 async move {
4446 if is_first {
4447 // First handler waits until signaled
4448 if let Some(rx) = first_complete_rx {
4449 rx.await.ok();
4450 }
4451 }
4452 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
4453 }
4454 .boxed_local()
4455 }
4456 }));
4457
4458 let thread = cx
4459 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4460 .await
4461 .unwrap();
4462
4463 // Send first message (turn_id=1) - handler will block
4464 let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
4465 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
4466
4467 // Send second message (turn_id=2) while first is still blocked
4468 // This calls cancel() which takes turn 1's running_turn and sets turn 2's
4469 let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
4470 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
4471
4472 let running_turn_after_second_send =
4473 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4474 assert_eq!(
4475 running_turn_after_second_send,
4476 Some(2),
4477 "running_turn should be set to turn 2 after sending second message"
4478 );
4479
4480 // Now signal first handler to complete
4481 first_complete_tx.send(()).ok();
4482
4483 // First request completes - should NOT clear running_turn
4484 // because running_turn now belongs to turn 2
4485 first_request.await.unwrap();
4486
4487 let running_turn_after_first =
4488 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4489 assert_eq!(
4490 running_turn_after_first,
4491 Some(2),
4492 "first turn completing should not clear running_turn (belongs to turn 2)"
4493 );
4494
4495 // Second request completes - SHOULD clear running_turn
4496 second_request.await.unwrap();
4497
4498 let running_turn_after_second =
4499 thread.read_with(cx, |thread, _| thread.running_turn.is_some());
4500 assert!(
4501 !running_turn_after_second,
4502 "second turn completing should clear running_turn"
4503 );
4504 }
4505
4506 #[gpui::test]
4507 async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
4508 cx: &mut TestAppContext,
4509 ) {
4510 init_test(cx);
4511
4512 let fs = FakeFs::new(cx.executor());
4513 let project = Project::test(fs, [], cx).await;
4514
4515 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4516 move |_params, thread, mut cx| {
4517 async move {
4518 thread
4519 .update(&mut cx, |thread, cx| {
4520 thread.handle_session_update(
4521 acp::SessionUpdate::ToolCall(
4522 acp::ToolCall::new(
4523 acp::ToolCallId::new("test-tool"),
4524 "Test Tool",
4525 )
4526 .kind(acp::ToolKind::Fetch)
4527 .status(acp::ToolCallStatus::InProgress),
4528 ),
4529 cx,
4530 )
4531 })
4532 .unwrap()
4533 .unwrap();
4534
4535 Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
4536 }
4537 .boxed_local()
4538 },
4539 ));
4540
4541 let thread = cx
4542 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4543 .await
4544 .unwrap();
4545
4546 let response = thread
4547 .update(cx, |thread, cx| thread.send_raw("test message", cx))
4548 .await;
4549
4550 let response = response
4551 .expect("send should succeed")
4552 .expect("should have response");
4553 assert_eq!(
4554 response.stop_reason,
4555 acp::StopReason::Cancelled,
4556 "response should have Cancelled stop_reason"
4557 );
4558
4559 thread.read_with(cx, |thread, _| {
4560 let tool_entry = thread
4561 .entries
4562 .iter()
4563 .find_map(|e| {
4564 if let AgentThreadEntry::ToolCall(call) = e {
4565 Some(call)
4566 } else {
4567 None
4568 }
4569 })
4570 .expect("should have tool call entry");
4571
4572 assert!(
4573 matches!(tool_entry.status, ToolCallStatus::Canceled),
4574 "tool should be marked as Canceled when response is Cancelled, got {:?}",
4575 tool_entry.status
4576 );
4577 });
4578 }
4579}