1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Key used in ACP ToolCall meta to store the tool's programmatic name.
/// This is a workaround since ACP's ToolCall doesn't have a dedicated name field.
///
/// Written by [`meta_with_tool_name`] and read back by [`tool_name_from_meta`].
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Key used in ACP ToolCall meta to store the session id when a subagent is spawned.
///
/// Read by [`subagent_session_id_from_meta`]; values that are not JSON strings are ignored.
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A message sent by the user, as stored in the thread.
#[derive(Debug)]
pub struct UserMessage {
    pub id: Option<UserMessageId>,
    /// Rendered content of the message (what `to_markdown` prints).
    pub content: ContentBlock,
    /// The raw ACP content blocks of the message.
    // NOTE(review): presumably `content` is built from these chunks — confirm.
    pub chunks: Vec<acp::ContentBlock>,
    /// Optional git checkpoint; when its `show` flag is set the Markdown
    /// heading reads `## User (checkpoint)`.
    pub checkpoint: Option<Checkpoint>,
    /// Reported by `AgentThreadEntry::is_indented`.
    pub indented: bool,
}
75
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; `UserMessage::to_markdown` uses it
    /// to label the heading `## User (checkpoint)`.
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// A message produced by the agent, made of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in order; `to_markdown` separates them with blank lines.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Reported by `AgentThreadEntry::is_indented`.
    pub indented: bool,
    // NOTE(review): presumably marks output relayed from a subagent — confirm with the writer of this flag.
    pub is_subagent_output: bool,
}
107
108impl AssistantMessage {
109 pub fn to_markdown(&self, cx: &App) -> String {
110 format!(
111 "## Assistant\n\n{}\n\n",
112 self.chunks
113 .iter()
114 .map(|chunk| chunk.to_markdown(cx))
115 .join("\n\n")
116 )
117 }
118}
119
/// One piece of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular output shown to the user.
    Message { block: ContentBlock },
    /// Reasoning output; rendered inside `<thinking>` tags by `to_markdown`.
    Thought { block: ContentBlock },
}
125
126impl AssistantMessageChunk {
127 pub fn from_str(
128 chunk: &str,
129 language_registry: &Arc<LanguageRegistry>,
130 path_style: PathStyle,
131 cx: &mut App,
132 ) -> Self {
133 Self::Message {
134 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
135 }
136 }
137
138 fn to_markdown(&self, cx: &App) -> String {
139 match self {
140 Self::Message { block } => block.to_markdown(cx).to_string(),
141 Self::Thought { block } => {
142 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
143 }
144 }
145 }
146}
147
/// An entry in the thread's conversation log.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation requested by the agent.
    ToolCall(ToolCall),
}
154
155impl AgentThreadEntry {
156 pub fn is_indented(&self) -> bool {
157 match self {
158 Self::UserMessage(message) => message.indented,
159 Self::AssistantMessage(message) => message.indented,
160 Self::ToolCall(_) => false,
161 }
162 }
163
164 pub fn to_markdown(&self, cx: &App) -> String {
165 match self {
166 Self::UserMessage(message) => message.to_markdown(cx),
167 Self::AssistantMessage(message) => message.to_markdown(cx),
168 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
169 }
170 }
171
172 pub fn user_message(&self) -> Option<&UserMessage> {
173 if let AgentThreadEntry::UserMessage(message) = self {
174 Some(message)
175 } else {
176 None
177 }
178 }
179
180 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
181 if let AgentThreadEntry::ToolCall(call) = self {
182 itertools::Either::Left(call.diffs())
183 } else {
184 itertools::Either::Right(std::iter::empty())
185 }
186 }
187
188 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
189 if let AgentThreadEntry::ToolCall(call) = self {
190 itertools::Either::Left(call.terminals())
191 } else {
192 itertools::Either::Right(std::iter::empty())
193 }
194 }
195
196 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
197 if let AgentThreadEntry::ToolCall(ToolCall {
198 locations,
199 resolved_locations,
200 ..
201 }) = self
202 {
203 Some((
204 locations.get(ix)?.clone(),
205 resolved_locations.get(ix)?.clone()?,
206 ))
207 } else {
208 None
209 }
210 }
211}
212
/// A tool call made by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// The tool call's title rendered as Markdown. For kinds other than
    /// `Execute`, a multi-line title is cut to its first line followed by `…`.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content converted from ACP; content kinds this client does not handle are dropped.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index
    /// (see `AgentThreadEntry::location`).
    // NOTE(review): presumably `None` where resolution failed — confirm with the code that fills this.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input sent by the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, when one could be produced.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output reported by the agent.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from the ACP meta (`TOOL_NAME_META_KEY`).
    pub tool_name: Option<SharedString>,
    /// Session id of a spawned subagent, read from the ACP meta.
    pub subagent_session_id: Option<acp::SessionId>,
}
228
229impl ToolCall {
230 fn from_acp(
231 tool_call: acp::ToolCall,
232 status: ToolCallStatus,
233 language_registry: Arc<LanguageRegistry>,
234 path_style: PathStyle,
235 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
236 cx: &mut App,
237 ) -> Result<Self> {
238 let title = if tool_call.kind == acp::ToolKind::Execute {
239 tool_call.title
240 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
241 first_line.to_owned() + "…"
242 } else {
243 tool_call.title
244 };
245 let mut content = Vec::with_capacity(tool_call.content.len());
246 for item in tool_call.content {
247 if let Some(item) = ToolCallContent::from_acp(
248 item,
249 language_registry.clone(),
250 path_style,
251 terminals,
252 cx,
253 )? {
254 content.push(item);
255 }
256 }
257
258 let raw_input_markdown = tool_call
259 .raw_input
260 .as_ref()
261 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
262
263 let tool_name = tool_name_from_meta(&tool_call.meta);
264
265 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
266
267 let result = Self {
268 id: tool_call.tool_call_id,
269 label: cx
270 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
271 kind: tool_call.kind,
272 content,
273 locations: tool_call.locations,
274 resolved_locations: Vec::default(),
275 status,
276 raw_input: tool_call.raw_input,
277 raw_input_markdown,
278 raw_output: tool_call.raw_output,
279 tool_name,
280 subagent_session_id: subagent_session,
281 };
282 Ok(result)
283 }
284
285 fn update_fields(
286 &mut self,
287 fields: acp::ToolCallUpdateFields,
288 meta: Option<acp::Meta>,
289 language_registry: Arc<LanguageRegistry>,
290 path_style: PathStyle,
291 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
292 cx: &mut App,
293 ) -> Result<()> {
294 let acp::ToolCallUpdateFields {
295 kind,
296 status,
297 title,
298 content,
299 locations,
300 raw_input,
301 raw_output,
302 ..
303 } = fields;
304
305 if let Some(kind) = kind {
306 self.kind = kind;
307 }
308
309 if let Some(status) = status {
310 self.status = status.into();
311 }
312
313 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
314 self.subagent_session_id = Some(subagent_session_id);
315 }
316
317 if let Some(title) = title {
318 if self.kind == acp::ToolKind::Execute {
319 for terminal in self.terminals() {
320 terminal.update(cx, |terminal, cx| {
321 terminal.update_command_label(&title, cx);
322 });
323 }
324 }
325 self.label.update(cx, |label, cx| {
326 if self.kind == acp::ToolKind::Execute {
327 label.replace(title, cx);
328 } else if let Some((first_line, _)) = title.split_once("\n") {
329 label.replace(first_line.to_owned() + "…", cx);
330 } else {
331 label.replace(title, cx);
332 }
333 });
334 }
335
336 if let Some(content) = content {
337 let mut new_content_len = content.len();
338 let mut content = content.into_iter();
339
340 // Reuse existing content if we can
341 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
342 let valid_content =
343 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
344 if !valid_content {
345 new_content_len -= 1;
346 }
347 }
348 for new in content {
349 if let Some(new) = ToolCallContent::from_acp(
350 new,
351 language_registry.clone(),
352 path_style,
353 terminals,
354 cx,
355 )? {
356 self.content.push(new);
357 } else {
358 new_content_len -= 1;
359 }
360 }
361 self.content.truncate(new_content_len);
362 }
363
364 if let Some(locations) = locations {
365 self.locations = locations;
366 }
367
368 if let Some(raw_input) = raw_input {
369 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
370 self.raw_input = Some(raw_input);
371 }
372
373 if let Some(raw_output) = raw_output {
374 if self.content.is_empty()
375 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
376 {
377 self.content
378 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
379 markdown,
380 }));
381 }
382 self.raw_output = Some(raw_output);
383 }
384 Ok(())
385 }
386
387 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
388 self.content.iter().filter_map(|content| match content {
389 ToolCallContent::Diff(diff) => Some(diff),
390 ToolCallContent::ContentBlock(_) => None,
391 ToolCallContent::Terminal(_) => None,
392 })
393 }
394
395 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
396 self.content.iter().filter_map(|content| match content {
397 ToolCallContent::Terminal(terminal) => Some(terminal),
398 ToolCallContent::ContentBlock(_) => None,
399 ToolCallContent::Diff(_) => None,
400 })
401 }
402
403 pub fn is_subagent(&self) -> bool {
404 self.tool_name.as_ref().is_some_and(|s| s == "spawn_agent")
405 || self.subagent_session_id.is_some()
406 }
407
408 pub fn to_markdown(&self, cx: &App) -> String {
409 let mut markdown = format!(
410 "**Tool Call: {}**\nStatus: {}\n\n",
411 self.label.read(cx).source(),
412 self.status
413 );
414 for content in &self.content {
415 markdown.push_str(content.to_markdown(cx).as_str());
416 markdown.push_str("\n\n");
417 }
418 markdown
419 }
420
421 async fn resolve_location(
422 location: acp::ToolCallLocation,
423 project: WeakEntity<Project>,
424 cx: &mut AsyncApp,
425 ) -> Option<ResolvedLocation> {
426 let buffer = project
427 .update(cx, |project, cx| {
428 project
429 .project_path_for_absolute_path(&location.path, cx)
430 .map(|path| project.open_buffer(path, cx))
431 })
432 .ok()??;
433 let buffer = buffer.await.log_err()?;
434 let position = buffer.update(cx, |buffer, _| {
435 let snapshot = buffer.snapshot();
436 if let Some(row) = location.line {
437 let column = snapshot.indent_size_for_line(row).len;
438 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
439 snapshot.anchor_before(point)
440 } else {
441 Anchor::min_for_buffer(snapshot.remote_id())
442 }
443 });
444
445 Some(ResolvedLocation { buffer, position })
446 }
447
448 fn resolve_locations(
449 &self,
450 project: Entity<Project>,
451 cx: &mut App,
452 ) -> Task<Vec<Option<ResolvedLocation>>> {
453 let locations = self.locations.clone();
454 project.update(cx, |_, cx| {
455 cx.spawn(async move |project, cx| {
456 let mut new_locations = Vec::new();
457 for location in locations {
458 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
459 }
460 new_locations
461 })
462 })
463 }
464}
465
// Kept separate from `AgentLocation` so that it can hold a strong reference to
// the buffer (the `From` impl below downgrades it to a weak one). The original
// motivation noted here: holding the buffer strongly for saving on the thread.
/// A tool-call location resolved to an open buffer and an anchor within it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    buffer: Entity<Buffer>,
    position: Anchor,
}
473
474impl From<&ResolvedLocation> for AgentLocation {
475 fn from(value: &ResolvedLocation) -> Self {
476 Self {
477 buffer: value.buffer.downgrade(),
478 position: value.position,
479 }
480 }
481}
482
/// Local status of a tool call.
///
/// `WaitingForConfirmation`, `Rejected` and `Canceled` are never produced by the
/// `From<acp::ToolCallStatus>` conversion; they only arise locally.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        options: PermissionOptions,
        /// Channel on which the chosen permission option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
504
505impl From<acp::ToolCallStatus> for ToolCallStatus {
506 fn from(status: acp::ToolCallStatus) -> Self {
507 match status {
508 acp::ToolCallStatus::Pending => Self::Pending,
509 acp::ToolCallStatus::InProgress => Self::InProgress,
510 acp::ToolCallStatus::Completed => Self::Completed,
511 acp::ToolCallStatus::Failed => Self::Failed,
512 _ => Self::Pending,
513 }
514 }
515}
516
517impl Display for ToolCallStatus {
518 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
519 write!(
520 f,
521 "{}",
522 match self {
523 ToolCallStatus::Pending => "Pending",
524 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
525 ToolCallStatus::InProgress => "In Progress",
526 ToolCallStatus::Completed => "Completed",
527 ToolCallStatus::Failed => "Failed",
528 ToolCallStatus::Rejected => "Rejected",
529 ToolCallStatus::Canceled => "Canceled",
530 }
531 )
532 }
533}
534
/// Renderable content assembled from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content rendered as Markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link (only while nothing else has been appended).
    ResourceLink { resource_link: acp::ResourceLink },
    /// A lone decoded image (only while nothing else has been appended).
    Image { image: Arc<gpui::Image> },
}
542
543impl ContentBlock {
544 pub fn new(
545 block: acp::ContentBlock,
546 language_registry: &Arc<LanguageRegistry>,
547 path_style: PathStyle,
548 cx: &mut App,
549 ) -> Self {
550 let mut this = Self::Empty;
551 this.append(block, language_registry, path_style, cx);
552 this
553 }
554
555 pub fn new_combined(
556 blocks: impl IntoIterator<Item = acp::ContentBlock>,
557 language_registry: Arc<LanguageRegistry>,
558 path_style: PathStyle,
559 cx: &mut App,
560 ) -> Self {
561 let mut this = Self::Empty;
562 for block in blocks {
563 this.append(block, &language_registry, path_style, cx);
564 }
565 this
566 }
567
568 pub fn append(
569 &mut self,
570 block: acp::ContentBlock,
571 language_registry: &Arc<LanguageRegistry>,
572 path_style: PathStyle,
573 cx: &mut App,
574 ) {
575 match (&mut *self, &block) {
576 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
577 *self = ContentBlock::ResourceLink {
578 resource_link: resource_link.clone(),
579 };
580 }
581 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
582 if let Some(image) = Self::decode_image(image_content) {
583 *self = ContentBlock::Image { image };
584 } else {
585 let new_content = Self::image_md(image_content);
586 *self = Self::create_markdown_block(new_content, language_registry, cx);
587 }
588 }
589 (ContentBlock::Empty, _) => {
590 let new_content = Self::block_string_contents(&block, path_style);
591 *self = Self::create_markdown_block(new_content, language_registry, cx);
592 }
593 (ContentBlock::Markdown { markdown }, _) => {
594 let new_content = Self::block_string_contents(&block, path_style);
595 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
596 }
597 (ContentBlock::ResourceLink { resource_link }, _) => {
598 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
599 let new_content = Self::block_string_contents(&block, path_style);
600 let combined = format!("{}\n{}", existing_content, new_content);
601 *self = Self::create_markdown_block(combined, language_registry, cx);
602 }
603 (ContentBlock::Image { .. }, _) => {
604 let new_content = Self::block_string_contents(&block, path_style);
605 let combined = format!("`Image`\n{}", new_content);
606 *self = Self::create_markdown_block(combined, language_registry, cx);
607 }
608 }
609 }
610
611 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
612 use base64::Engine as _;
613
614 let bytes = base64::engine::general_purpose::STANDARD
615 .decode(image_content.data.as_bytes())
616 .ok()?;
617 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
618 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
619 }
620
621 fn create_markdown_block(
622 content: String,
623 language_registry: &Arc<LanguageRegistry>,
624 cx: &mut App,
625 ) -> ContentBlock {
626 ContentBlock::Markdown {
627 markdown: cx
628 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
629 }
630 }
631
632 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
633 match block {
634 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
635 acp::ContentBlock::ResourceLink(resource_link) => {
636 Self::resource_link_md(&resource_link.uri, path_style)
637 }
638 acp::ContentBlock::Resource(acp::EmbeddedResource {
639 resource:
640 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
641 uri,
642 ..
643 }),
644 ..
645 }) => Self::resource_link_md(uri, path_style),
646 acp::ContentBlock::Image(image) => Self::image_md(image),
647 _ => String::new(),
648 }
649 }
650
651 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
652 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
653 uri.as_link().to_string()
654 } else {
655 uri.to_string()
656 }
657 }
658
659 fn image_md(_image: &acp::ImageContent) -> String {
660 "`Image`".into()
661 }
662
663 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
664 match self {
665 ContentBlock::Empty => "",
666 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
667 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
668 ContentBlock::Image { .. } => "`Image`",
669 }
670 }
671
672 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
673 match self {
674 ContentBlock::Empty => None,
675 ContentBlock::Markdown { markdown } => Some(markdown),
676 ContentBlock::ResourceLink { .. } => None,
677 ContentBlock::Image { .. } => None,
678 }
679 }
680
681 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
682 match self {
683 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
684 _ => None,
685 }
686 }
687
688 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
689 match self {
690 ContentBlock::Image { image } => Some(image),
691 _ => None,
692 }
693 }
694}
695
/// One item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, link or image content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal known to the thread, looked up by its ACP terminal id.
    Terminal(Entity<Terminal>),
}
702
703impl ToolCallContent {
704 pub fn from_acp(
705 content: acp::ToolCallContent,
706 language_registry: Arc<LanguageRegistry>,
707 path_style: PathStyle,
708 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
709 cx: &mut App,
710 ) -> Result<Option<Self>> {
711 match content {
712 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
713 Ok(Some(Self::ContentBlock(ContentBlock::new(
714 content,
715 &language_registry,
716 path_style,
717 cx,
718 ))))
719 }
720 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
721 Diff::finalized(
722 diff.path.to_string_lossy().into_owned(),
723 diff.old_text,
724 diff.new_text,
725 language_registry,
726 cx,
727 )
728 })))),
729 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
730 .get(&terminal_id)
731 .cloned()
732 .map(|terminal| Some(Self::Terminal(terminal)))
733 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
734 _ => Ok(None),
735 }
736 }
737
738 pub fn update_from_acp(
739 &mut self,
740 new: acp::ToolCallContent,
741 language_registry: Arc<LanguageRegistry>,
742 path_style: PathStyle,
743 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
744 cx: &mut App,
745 ) -> Result<bool> {
746 let needs_update = match (&self, &new) {
747 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
748 old_diff.read(cx).needs_update(
749 new_diff.old_text.as_deref().unwrap_or(""),
750 &new_diff.new_text,
751 cx,
752 )
753 }
754 _ => true,
755 };
756
757 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
758 if needs_update {
759 *self = update;
760 }
761 Ok(true)
762 } else {
763 Ok(false)
764 }
765 }
766
767 pub fn to_markdown(&self, cx: &App) -> String {
768 match self {
769 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
770 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
771 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
772 }
773 }
774
775 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
776 match self {
777 Self::ContentBlock(content) => content.image(),
778 _ => None,
779 }
780 }
781}
782
/// An update targeting an existing tool call (see `ToolCallUpdate::id`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level field update.
    UpdateFields(acp::ToolCallUpdate),
    /// A diff entity produced locally.
    UpdateDiff(ToolCallUpdateDiff),
    /// A terminal entity produced locally.
    UpdateTerminal(ToolCallUpdateTerminal),
}
789
790impl ToolCallUpdate {
791 fn id(&self) -> &acp::ToolCallId {
792 match self {
793 Self::UpdateFields(update) => &update.tool_call_id,
794 Self::UpdateDiff(diff) => &diff.id,
795 Self::UpdateTerminal(terminal) => &terminal.id,
796 }
797 }
798}
799
800impl From<acp::ToolCallUpdate> for ToolCallUpdate {
801 fn from(update: acp::ToolCallUpdate) -> Self {
802 Self::UpdateFields(update)
803 }
804}
805
806impl From<ToolCallUpdateDiff> for ToolCallUpdate {
807 fn from(diff: ToolCallUpdateDiff) -> Self {
808 Self::UpdateDiff(diff)
809 }
810}
811
/// Attaches a diff entity to the tool call identified by `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
817
818impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
819 fn from(terminal: ToolCallUpdateTerminal) -> Self {
820 Self::UpdateTerminal(terminal)
821 }
822}
823
/// Attaches a terminal entity to the tool call identified by `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
829
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}

/// Summary counts computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
841
842impl Plan {
843 pub fn is_empty(&self) -> bool {
844 self.entries.is_empty()
845 }
846
847 pub fn stats(&self) -> PlanStats<'_> {
848 let mut stats = PlanStats {
849 in_progress_entry: None,
850 pending: 0,
851 completed: 0,
852 };
853
854 for entry in &self.entries {
855 match &entry.status {
856 acp::PlanEntryStatus::Pending => {
857 stats.pending += 1;
858 }
859 acp::PlanEntryStatus::InProgress => {
860 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
861 }
862 acp::PlanEntryStatus::Completed => {
863 stats.completed += 1;
864 }
865 _ => {}
866 }
867 }
868
869 stats
870 }
871}
872
/// One step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// The step's text rendered as Markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
879
880impl PlanEntry {
881 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
882 Self {
883 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
884 priority: entry.priority,
885 status: entry.status,
886 }
887 }
888}
889
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum tokens allowed; 0 means the maximum is unknown (see `ratio`).
    pub max_tokens: u64,
    /// Tokens used so far; compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub max_output_tokens: Option<u64>,
}

/// Fraction of `max_tokens` at which `TokenUsage::ratio` starts reporting `Warning`.
pub const TOKEN_USAGE_WARNING_THRESHOLD: f32 = 0.8;
900
901impl TokenUsage {
902 pub fn ratio(&self) -> TokenUsageRatio {
903 #[cfg(debug_assertions)]
904 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
905 .unwrap_or(TOKEN_USAGE_WARNING_THRESHOLD.to_string())
906 .parse()
907 .unwrap();
908 #[cfg(not(debug_assertions))]
909 let warning_threshold: f32 = TOKEN_USAGE_WARNING_THRESHOLD;
910
911 // When the maximum is unknown because there is no selected model,
912 // avoid showing the token limit warning.
913 if self.max_tokens == 0 {
914 TokenUsageRatio::Normal
915 } else if self.used_tokens >= self.max_tokens {
916 TokenUsageRatio::Exceeded
917 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
918 TokenUsageRatio::Warning
919 } else {
920 TokenUsageRatio::Normal
921 }
922 }
923}
924
/// Severity of token usage returned by `TokenUsage::ratio`.
///
/// The derived `Ord` follows declaration order: `Normal < Warning < Exceeded`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}

/// Information about an in-progress retry.
// NOTE(review): field meanings below are inferred from their names — confirm with the emitter of `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry (presumably).
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}

/// The turn currently being generated; while one exists the thread reports
/// `ThreadStatus::Generating`.
struct RunningTurn {
    id: u32,
    send_task: Task<()>,
}
945
/// A conversation with an agent over ACP: its entries, plan, terminals, and
/// the connection used to talk to the agent.
pub struct AcpThread {
    /// Session id this thread was created with as its parent, if any.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): presumably snapshots of buffers shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    turn_id: u32,
    /// The turn in flight; while `Some`, `status()` reports `Generating`.
    running_turn: Option<RunningTurn>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    prompt_capabilities: acp::PromptCapabilities,
    /// Task started in `new` that copies updated prompt capabilities into
    /// `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to this thread, keyed by ACP terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids not yet in `terminals`; replayed when
    /// the terminal is created.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids not yet in `terminals`.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    had_error: bool,
}
966
967impl From<&AcpThread> for ActionLogTelemetry {
968 fn from(value: &AcpThread) -> Self {
969 Self {
970 agent_telemetry_id: value.connection().telemetry_id(),
971 session_id: value.session_id.0.clone(),
972 }
973 }
974}
975
/// Events emitted by an [`AcpThread`].
// NOTE(review): most variant meanings are inferred from their names; only
// `PromptCapabilitiesUpdated` is emitted in the visible code (from `AcpThread::new`).
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequested(acp::ToolCallId),
    ToolAuthorizationReceived(acp::ToolCallId),
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped(acp::StopReason),
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is replaced with a new value.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
998
/// Events from a terminal provider, applied with `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created. The thread registers it and replays any output
    /// that arrived earlier for the same id.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered if the terminal is not known yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title for a terminal, shown as its breadcrumb text. Ignored for
    /// terminals that are not known yet.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// A terminal exited; buffered if the terminal is not known yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

/// Commands addressed to a terminal provider.
// NOTE(review): not consumed in the visible code; variant meanings inferred from names.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Write input bytes to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to the given size in character cells.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1037
1038impl AcpThread {
1039 pub fn on_terminal_provider_event(
1040 &mut self,
1041 event: TerminalProviderEvent,
1042 cx: &mut Context<Self>,
1043 ) {
1044 match event {
1045 TerminalProviderEvent::Created {
1046 terminal_id,
1047 label,
1048 cwd,
1049 output_byte_limit,
1050 terminal,
1051 } => {
1052 let entity = self.register_terminal_created(
1053 terminal_id.clone(),
1054 label,
1055 cwd,
1056 output_byte_limit,
1057 terminal,
1058 cx,
1059 );
1060
1061 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1062 for data in chunks.drain(..) {
1063 entity.update(cx, |term, cx| {
1064 term.inner().update(cx, |inner, cx| {
1065 inner.write_output(&data, cx);
1066 })
1067 });
1068 }
1069 }
1070
1071 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1072 entity.update(cx, |_term, cx| {
1073 cx.notify();
1074 });
1075 }
1076
1077 cx.notify();
1078 }
1079 TerminalProviderEvent::Output { terminal_id, data } => {
1080 if let Some(entity) = self.terminals.get(&terminal_id) {
1081 entity.update(cx, |term, cx| {
1082 term.inner().update(cx, |inner, cx| {
1083 inner.write_output(&data, cx);
1084 })
1085 });
1086 } else {
1087 self.pending_terminal_output
1088 .entry(terminal_id)
1089 .or_default()
1090 .push(data);
1091 }
1092 }
1093 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1094 if let Some(entity) = self.terminals.get(&terminal_id) {
1095 entity.update(cx, |term, cx| {
1096 term.inner().update(cx, |inner, cx| {
1097 inner.breadcrumb_text = title;
1098 cx.emit(::terminal::Event::BreadcrumbsChanged);
1099 })
1100 });
1101 }
1102 }
1103 TerminalProviderEvent::Exit {
1104 terminal_id,
1105 status,
1106 } => {
1107 if let Some(entity) = self.terminals.get(&terminal_id) {
1108 entity.update(cx, |_term, cx| {
1109 cx.notify();
1110 });
1111 } else {
1112 self.pending_terminal_exit.insert(terminal_id, status);
1113 }
1114 }
1115 }
1116 }
1117}
1118
/// Whether the thread is currently generating a turn.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running.
    Generating,
}

/// Why an agent could not be loaded; see the `Display` impl for the messages.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed version of `command` is older than `minimum_version`.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The agent server process exited with `status`.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, displayed verbatim.
    Other(SharedString),
}
1138
1139impl Display for LoadError {
1140 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1141 match self {
1142 LoadError::Unsupported {
1143 command: path,
1144 current_version,
1145 minimum_version,
1146 } => {
1147 write!(
1148 f,
1149 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1150 )
1151 }
1152 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1153 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1154 LoadError::Other(msg) => write!(f, "{msg}"),
1155 }
1156 }
1157}
1158
1159impl Error for LoadError {}
1160
1161impl AcpThread {
1162 pub fn new(
1163 parent_session_id: Option<acp::SessionId>,
1164 title: impl Into<SharedString>,
1165 connection: Rc<dyn AgentConnection>,
1166 project: Entity<Project>,
1167 action_log: Entity<ActionLog>,
1168 session_id: acp::SessionId,
1169 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1170 cx: &mut Context<Self>,
1171 ) -> Self {
1172 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1173 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1174 loop {
1175 let caps = prompt_capabilities_rx.recv().await?;
1176 this.update(cx, |this, cx| {
1177 this.prompt_capabilities = caps;
1178 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1179 })?;
1180 }
1181 });
1182
1183 Self {
1184 parent_session_id,
1185 action_log,
1186 shared_buffers: Default::default(),
1187 entries: Default::default(),
1188 plan: Default::default(),
1189 title: title.into(),
1190 project,
1191 running_turn: None,
1192 turn_id: 0,
1193 connection,
1194 session_id,
1195 token_usage: None,
1196 prompt_capabilities,
1197 _observe_prompt_capabilities: task,
1198 terminals: HashMap::default(),
1199 pending_terminal_output: HashMap::default(),
1200 pending_terminal_exit: HashMap::default(),
1201 had_error: false,
1202 }
1203 }
1204
1205 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1206 self.parent_session_id.as_ref()
1207 }
1208
1209 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1210 self.prompt_capabilities.clone()
1211 }
1212
1213 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1214 &self.connection
1215 }
1216
1217 pub fn action_log(&self) -> &Entity<ActionLog> {
1218 &self.action_log
1219 }
1220
1221 pub fn project(&self) -> &Entity<Project> {
1222 &self.project
1223 }
1224
1225 pub fn title(&self) -> SharedString {
1226 self.title.clone()
1227 }
1228
1229 pub fn entries(&self) -> &[AgentThreadEntry] {
1230 &self.entries
1231 }
1232
1233 pub fn session_id(&self) -> &acp::SessionId {
1234 &self.session_id
1235 }
1236
1237 pub fn status(&self) -> ThreadStatus {
1238 if self.running_turn.is_some() {
1239 ThreadStatus::Generating
1240 } else {
1241 ThreadStatus::Idle
1242 }
1243 }
1244
1245 pub fn had_error(&self) -> bool {
1246 self.had_error
1247 }
1248
1249 pub fn is_waiting_for_confirmation(&self) -> bool {
1250 for entry in self.entries.iter().rev() {
1251 match entry {
1252 AgentThreadEntry::UserMessage(_) => return false,
1253 AgentThreadEntry::ToolCall(ToolCall {
1254 status: ToolCallStatus::WaitingForConfirmation { .. },
1255 ..
1256 }) => return true,
1257 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1258 }
1259 }
1260 false
1261 }
1262
1263 pub fn token_usage(&self) -> Option<&TokenUsage> {
1264 self.token_usage.as_ref()
1265 }
1266
1267 pub fn has_pending_edit_tool_calls(&self) -> bool {
1268 for entry in self.entries.iter().rev() {
1269 match entry {
1270 AgentThreadEntry::UserMessage(_) => return false,
1271 AgentThreadEntry::ToolCall(
1272 call @ ToolCall {
1273 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1274 ..
1275 },
1276 ) if call.diffs().next().is_some() => {
1277 return true;
1278 }
1279 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1280 }
1281 }
1282
1283 false
1284 }
1285
1286 pub fn has_in_progress_tool_calls(&self) -> bool {
1287 for entry in self.entries.iter().rev() {
1288 match entry {
1289 AgentThreadEntry::UserMessage(_) => return false,
1290 AgentThreadEntry::ToolCall(ToolCall {
1291 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1292 ..
1293 }) => {
1294 return true;
1295 }
1296 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1297 }
1298 }
1299
1300 false
1301 }
1302
1303 pub fn used_tools_since_last_user_message(&self) -> bool {
1304 for entry in self.entries.iter().rev() {
1305 match entry {
1306 AgentThreadEntry::UserMessage(..) => return false,
1307 AgentThreadEntry::AssistantMessage(..) => continue,
1308 AgentThreadEntry::ToolCall(..) => return true,
1309 }
1310 }
1311
1312 false
1313 }
1314
    /// Applies one `acp::SessionUpdate` notification from the agent to this thread.
    ///
    /// - User message chunks are appended as user content, without a message id.
    /// - Agent message chunks and thought chunks are appended to the current assistant
    ///   message (as message or thought content respectively).
    /// - Tool call notifications go through `upsert_tool_call`; tool call updates go
    ///   through `update_tool_call`.
    /// - Plan updates are applied via `update_plan`.
    /// - Available-command, current-mode, and config-option updates are re-emitted as the
    ///   corresponding `AcpThreadEvent`s.
    /// - Every other update variant is ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors returned by `upsert_tool_call` and `update_tool_call`.
    pub fn handle_session_update(
        &mut self,
        update: acp::SessionUpdate,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        match update {
            acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
                self.push_user_content_block(None, content, cx);
            }
            acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
                self.push_assistant_content_block(content, false, cx);
            }
            acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
                self.push_assistant_content_block(content, true, cx);
            }
            acp::SessionUpdate::ToolCall(tool_call) => {
                self.upsert_tool_call(tool_call, cx)?;
            }
            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
                // `update_tool_call` returns an anyhow error; `?` converts it to `acp::Error`.
                self.update_tool_call(tool_call_update, cx)?;
            }
            acp::SessionUpdate::Plan(plan) => {
                self.update_plan(plan, cx);
            }
            acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
                available_commands,
                ..
            }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
            acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
                current_mode_id,
                ..
            }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
            acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
                config_options,
                ..
            }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
            // Update kinds this thread does not handle are intentionally dropped.
            _ => {}
        }
        Ok(())
    }
1355
1356 pub fn push_user_content_block(
1357 &mut self,
1358 message_id: Option<UserMessageId>,
1359 chunk: acp::ContentBlock,
1360 cx: &mut Context<Self>,
1361 ) {
1362 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1363 }
1364
    /// Appends a user content chunk to the thread.
    ///
    /// If the last entry is a user message with the same `indented` flag, the chunk is
    /// merged into that message and `EntryUpdated` is emitted for it. When `message_id` is
    /// `Some`, it replaces the message's existing id; otherwise the existing id is kept.
    /// Otherwise a new user message entry (with no checkpoint) is pushed via `push_entry`.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `last_mut` below holds a mutable borrow of `entries`.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Prefer the newly supplied id; fall back to the one already recorded.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1405
1406 pub fn push_assistant_content_block(
1407 &mut self,
1408 chunk: acp::ContentBlock,
1409 is_thought: bool,
1410 cx: &mut Context<Self>,
1411 ) {
1412 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1413 }
1414
    /// Appends an assistant message chunk, or a thought chunk when `is_thought`, to the thread.
    ///
    /// If the last entry is an assistant message with the same `indented` flag:
    /// - the chunk is appended to that message's final chunk when it has the same kind
    ///   (message vs. thought);
    /// - otherwise it is added as a new chunk of the matching kind.
    ///
    /// `EntryUpdated` is emitted for that entry. Otherwise a new assistant message entry
    /// containing just this chunk (with `is_subagent_output: false`) is pushed via `push_entry`.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `last_mut` below holds a mutable borrow of `entries`.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
                is_subagent_output: _,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): the event is emitted before the chunk is appended; this relies on
            // emitted events being delivered after this method returns — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Different kind (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                    is_subagent_output: false,
                }),
                cx,
            );
        }
    }
1467
1468 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1469 self.entries.push(entry);
1470 cx.emit(AcpThreadEvent::NewEntry);
1471 }
1472
1473 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1474 self.connection.set_title(&self.session_id, cx).is_some()
1475 }
1476
1477 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1478 if title != self.title {
1479 self.title = title.clone();
1480 cx.emit(AcpThreadEvent::TitleUpdated);
1481 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1482 return set_title.run(title, cx);
1483 }
1484 }
1485 Task::ready(Ok(()))
1486 }
1487
1488 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1489 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1490 }
1491
1492 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1493 self.token_usage = usage;
1494 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1495 }
1496
1497 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1498 cx.emit(AcpThreadEvent::Retry(status));
1499 }
1500
1501 pub fn update_tool_call(
1502 &mut self,
1503 update: impl Into<ToolCallUpdate>,
1504 cx: &mut Context<Self>,
1505 ) -> Result<()> {
1506 let update = update.into();
1507 let languages = self.project.read(cx).languages().clone();
1508 let path_style = self.project.read(cx).path_style(cx);
1509
1510 let ix = match self.index_for_tool_call(update.id()) {
1511 Some(ix) => ix,
1512 None => {
1513 // Tool call not found - create a failed tool call entry
1514 let failed_tool_call = ToolCall {
1515 id: update.id().clone(),
1516 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1517 kind: acp::ToolKind::Fetch,
1518 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1519 "Tool call not found".into(),
1520 &languages,
1521 path_style,
1522 cx,
1523 ))],
1524 status: ToolCallStatus::Failed,
1525 locations: Vec::new(),
1526 resolved_locations: Vec::new(),
1527 raw_input: None,
1528 raw_input_markdown: None,
1529 raw_output: None,
1530 tool_name: None,
1531 subagent_session_id: None,
1532 };
1533 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1534 return Ok(());
1535 }
1536 };
1537 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1538 unreachable!()
1539 };
1540
1541 match update {
1542 ToolCallUpdate::UpdateFields(update) => {
1543 let location_updated = update.fields.locations.is_some();
1544 call.update_fields(
1545 update.fields,
1546 update.meta,
1547 languages,
1548 path_style,
1549 &self.terminals,
1550 cx,
1551 )?;
1552 if location_updated {
1553 self.resolve_locations(update.tool_call_id, cx);
1554 }
1555 }
1556 ToolCallUpdate::UpdateDiff(update) => {
1557 call.content.clear();
1558 call.content.push(ToolCallContent::Diff(update.diff));
1559 }
1560 ToolCallUpdate::UpdateTerminal(update) => {
1561 call.content.clear();
1562 call.content
1563 .push(ToolCallContent::Terminal(update.terminal));
1564 }
1565 }
1566
1567 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1568
1569 Ok(())
1570 }
1571
1572 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1573 pub fn upsert_tool_call(
1574 &mut self,
1575 tool_call: acp::ToolCall,
1576 cx: &mut Context<Self>,
1577 ) -> Result<(), acp::Error> {
1578 let status = tool_call.status.into();
1579 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1580 }
1581
1582 /// Fails if id does not match an existing entry.
1583 pub fn upsert_tool_call_inner(
1584 &mut self,
1585 update: acp::ToolCallUpdate,
1586 status: ToolCallStatus,
1587 cx: &mut Context<Self>,
1588 ) -> Result<(), acp::Error> {
1589 let language_registry = self.project.read(cx).languages().clone();
1590 let path_style = self.project.read(cx).path_style(cx);
1591 let id = update.tool_call_id.clone();
1592
1593 let agent_telemetry_id = self.connection().telemetry_id();
1594 let session = self.session_id();
1595 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1596 let status = if matches!(status, ToolCallStatus::Completed) {
1597 "completed"
1598 } else {
1599 "failed"
1600 };
1601 telemetry::event!(
1602 "Agent Tool Call Completed",
1603 agent_telemetry_id,
1604 session,
1605 status
1606 );
1607 }
1608
1609 if let Some(ix) = self.index_for_tool_call(&id) {
1610 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1611 unreachable!()
1612 };
1613
1614 call.update_fields(
1615 update.fields,
1616 update.meta,
1617 language_registry,
1618 path_style,
1619 &self.terminals,
1620 cx,
1621 )?;
1622 call.status = status;
1623
1624 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1625 } else {
1626 let call = ToolCall::from_acp(
1627 update.try_into()?,
1628 status,
1629 language_registry,
1630 self.project.read(cx).path_style(cx),
1631 &self.terminals,
1632 cx,
1633 )?;
1634 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1635 };
1636
1637 self.resolve_locations(id, cx);
1638 Ok(())
1639 }
1640
1641 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1642 self.entries
1643 .iter()
1644 .enumerate()
1645 .rev()
1646 .find_map(|(index, entry)| {
1647 if let AgentThreadEntry::ToolCall(tool_call) = entry
1648 && &tool_call.id == id
1649 {
1650 Some(index)
1651 } else {
1652 None
1653 }
1654 })
1655 }
1656
1657 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1658 // The tool call we are looking for is typically the last one, or very close to the end.
1659 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1660 self.entries
1661 .iter_mut()
1662 .enumerate()
1663 .rev()
1664 .find_map(|(index, tool_call)| {
1665 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1666 && &tool_call.id == id
1667 {
1668 Some((index, tool_call))
1669 } else {
1670 None
1671 }
1672 })
1673 }
1674
1675 pub fn tool_call(&self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1676 self.entries
1677 .iter()
1678 .enumerate()
1679 .rev()
1680 .find_map(|(index, tool_call)| {
1681 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1682 && &tool_call.id == id
1683 {
1684 Some((index, tool_call))
1685 } else {
1686 None
1687 }
1688 })
1689 }
1690
1691 pub fn tool_call_for_subagent(&self, session_id: &acp::SessionId) -> Option<&ToolCall> {
1692 self.entries.iter().find_map(|entry| match entry {
1693 AgentThreadEntry::ToolCall(tool_call)
1694 if tool_call.subagent_session_id.as_ref() == Some(session_id) =>
1695 {
1696 Some(tool_call)
1697 }
1698 _ => None,
1699 })
1700 }
1701
1702 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1703 let project = self.project.clone();
1704 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1705 return;
1706 };
1707 let task = tool_call.resolve_locations(project, cx);
1708 cx.spawn(async move |this, cx| {
1709 let resolved_locations = task.await;
1710
1711 this.update(cx, |this, cx| {
1712 let project = this.project.clone();
1713
1714 for location in resolved_locations.iter().flatten() {
1715 this.shared_buffers
1716 .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
1717 }
1718 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1719 return;
1720 };
1721
1722 if let Some(Some(location)) = resolved_locations.last() {
1723 project.update(cx, |project, cx| {
1724 let should_ignore = if let Some(agent_location) = project
1725 .agent_location()
1726 .filter(|agent_location| agent_location.buffer == location.buffer)
1727 {
1728 let snapshot = location.buffer.read(cx).snapshot();
1729 let old_position = agent_location.position.to_point(&snapshot);
1730 let new_position = location.position.to_point(&snapshot);
1731
1732 // ignore this so that when we get updates from the edit tool
1733 // the position doesn't reset to the startof line
1734 old_position.row == new_position.row
1735 && old_position.column > new_position.column
1736 } else {
1737 false
1738 };
1739 if !should_ignore {
1740 project.set_agent_location(Some(location.into()), cx);
1741 }
1742 });
1743 }
1744
1745 let resolved_locations = resolved_locations
1746 .iter()
1747 .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
1748 .collect::<Vec<_>>();
1749
1750 if tool_call.resolved_locations != resolved_locations {
1751 tool_call.resolved_locations = resolved_locations;
1752 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1753 }
1754 })
1755 })
1756 .detach();
1757 }
1758
1759 pub fn request_tool_call_authorization(
1760 &mut self,
1761 tool_call: acp::ToolCallUpdate,
1762 options: PermissionOptions,
1763 cx: &mut Context<Self>,
1764 ) -> Result<Task<acp::RequestPermissionOutcome>> {
1765 let (tx, rx) = oneshot::channel();
1766
1767 let status = ToolCallStatus::WaitingForConfirmation {
1768 options,
1769 respond_tx: tx,
1770 };
1771
1772 let tool_call_id = tool_call.tool_call_id.clone();
1773 self.upsert_tool_call_inner(tool_call, status, cx)?;
1774 cx.emit(AcpThreadEvent::ToolAuthorizationRequested(
1775 tool_call_id.clone(),
1776 ));
1777
1778 Ok(cx.spawn(async move |this, cx| {
1779 let outcome = match rx.await {
1780 Ok(option) => acp::RequestPermissionOutcome::Selected(
1781 acp::SelectedPermissionOutcome::new(option),
1782 ),
1783 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1784 };
1785 this.update(cx, |_this, cx| {
1786 cx.emit(AcpThreadEvent::ToolAuthorizationReceived(tool_call_id))
1787 })
1788 .ok();
1789 outcome
1790 }))
1791 }
1792
1793 pub fn authorize_tool_call(
1794 &mut self,
1795 id: acp::ToolCallId,
1796 option_id: acp::PermissionOptionId,
1797 option_kind: acp::PermissionOptionKind,
1798 cx: &mut Context<Self>,
1799 ) {
1800 let Some((ix, call)) = self.tool_call_mut(&id) else {
1801 return;
1802 };
1803
1804 let new_status = match option_kind {
1805 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1806 ToolCallStatus::Rejected
1807 }
1808 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1809 ToolCallStatus::InProgress
1810 }
1811 _ => ToolCallStatus::InProgress,
1812 };
1813
1814 let curr_status = mem::replace(&mut call.status, new_status);
1815
1816 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1817 respond_tx.send(option_id).log_err();
1818 } else if cfg!(debug_assertions) {
1819 panic!("tried to authorize an already authorized tool call");
1820 }
1821
1822 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1823 }
1824
1825 pub fn plan(&self) -> &Plan {
1826 &self.plan
1827 }
1828
1829 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1830 let new_entries_len = request.entries.len();
1831 let mut new_entries = request.entries.into_iter();
1832
1833 // Reuse existing markdown to prevent flickering
1834 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1835 let PlanEntry {
1836 content,
1837 priority,
1838 status,
1839 } = old;
1840 content.update(cx, |old, cx| {
1841 old.replace(new.content, cx);
1842 });
1843 *priority = new.priority;
1844 *status = new.status;
1845 }
1846 for new in new_entries {
1847 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1848 }
1849 self.plan.entries.truncate(new_entries_len);
1850
1851 cx.notify();
1852 }
1853
1854 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1855 self.plan
1856 .entries
1857 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1858 cx.notify();
1859 }
1860
1861 #[cfg(any(test, feature = "test-support"))]
1862 pub fn send_raw(
1863 &mut self,
1864 message: &str,
1865 cx: &mut Context<Self>,
1866 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1867 self.send(vec![message.into()], cx)
1868 }
1869
    /// Sends `message` to the agent as a new user turn (run via `run_turn`).
    ///
    /// Inside the turn:
    /// 1. The message is pushed as a user message entry. It gets a fresh `UserMessageId`
    ///    only when the connection supports truncation.
    /// 2. A git checkpoint of the project is taken and attached to the last user message,
    ///    with `show: false`.
    /// 3. The prompt request is forwarded to the connection.
    ///
    /// Resolves to whatever `run_turn` yields for the prompt's response.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only minted when the connection supports truncation;
        // `rewind` and `restore_checkpoint` look user messages up by this id.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Failure here (thread gone) is ignored; the later `this.update` reports it.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint is logged and leaves the message without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1922
1923 pub fn can_retry(&self, cx: &App) -> bool {
1924 self.connection.retry(&self.session_id, cx).is_some()
1925 }
1926
1927 pub fn retry(
1928 &mut self,
1929 cx: &mut Context<Self>,
1930 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1931 self.run_turn(cx, async move |this, cx| {
1932 this.update(cx, |this, cx| {
1933 this.connection
1934 .retry(&this.session_id, cx)
1935 .map(|retry| retry.run(cx))
1936 })?
1937 .context("retrying a session is not supported")?
1938 .await
1939 })
1940 }
1941
    /// Starts a new turn that runs `f`, cancelling any turn that is already running.
    ///
    /// Before the turn starts, completed plan entries are cleared and `had_error` is reset.
    /// A new `RunningTurn` is installed whose task first waits for the previous turn's
    /// cancellation and then runs `f`.
    ///
    /// The returned future waits for `f`'s result and then:
    /// 1. refreshes the last user message's checkpoint;
    /// 2. clears the project's agent location;
    /// 3. clears `running_turn` if it still refers to this turn;
    /// 4. translates the outcome:
    ///    - `StopReason::MaxTokens`: sets `had_error`, emits `Error`, returns an error.
    ///    - `StopReason::Cancelled`: marks still-pending tool calls as canceled.
    ///    - `StopReason::Refusal`: sets `had_error` and emits `Refusal`. If no completed tool
    ///      call with output follows the last user message, that message and everything
    ///      after it are removed first.
    ///    - Any non-`MaxTokens` response then emits `Stopped(stop_reason)` and is returned.
    ///    - An error from `f`: sets `had_error`, emits `Error`, and is returned.
    ///
    /// Resolves to `Ok(None)` if `f`'s result never arrives because its sender was dropped.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        self.clear_completed_plan_entries(cx);
        self.had_error = false;

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.turn_id += 1;
        let turn_id = self.turn_id;
        self.running_turn = Some(RunningTurn {
            id: turn_id,
            send_task: cx.spawn(async move |this, cx| {
                // Only start the new turn once the previous one has finished cancelling.
                cancel_task.await;
                tx.send(f(this, cx).await).ok();
            }),
        });

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                let Ok(response) = response else {
                    // tx dropped, just return
                    return Ok(None);
                };

                let is_same_turn = this
                    .running_turn
                    .as_ref()
                    .is_some_and(|turn| turn_id == turn.id);

                // If the user submitted a follow up message, running_turn might
                // already point to a different turn. Therefore we only want to
                // take the task if it's the same turn.
                if is_same_turn {
                    this.running_turn.take();
                }

                match response {
                    Ok(r) => {
                        if r.stop_reason == acp::StopReason::MaxTokens {
                            this.had_error = true;
                            cx.emit(AcpThreadEvent::Error);
                            log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
                            return Err(anyhow!("Max tokens reached"));
                        }

                        let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
                        if canceled {
                            this.mark_pending_tools_as_canceled();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let acp::StopReason::Refusal = r.stop_reason {
                            this.had_error = true;
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped(r.stop_reason));
                        Ok(Some(r))
                    }
                    Err(e) => {
                        this.had_error = true;
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Error in run turn: {:?}", e);
                        Err(e)
                    }
                }
            })?
        })
        .boxed()
    }
2053
2054 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2055 let Some(turn) = self.running_turn.take() else {
2056 return Task::ready(());
2057 };
2058 self.connection.cancel(&self.session_id, cx);
2059
2060 self.mark_pending_tools_as_canceled();
2061
2062 // Wait for the send task to complete
2063 cx.background_spawn(turn.send_task)
2064 }
2065
2066 fn mark_pending_tools_as_canceled(&mut self) {
2067 for entry in self.entries.iter_mut() {
2068 if let AgentThreadEntry::ToolCall(call) = entry {
2069 let cancel = matches!(
2070 call.status,
2071 ToolCallStatus::Pending
2072 | ToolCallStatus::WaitingForConfirmation { .. }
2073 | ToolCallStatus::InProgress
2074 );
2075
2076 if cancel {
2077 call.status = ToolCallStatus::Canceled;
2078 }
2079 }
2080 }
2081 }
2082
2083 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2084 pub fn restore_checkpoint(
2085 &mut self,
2086 id: UserMessageId,
2087 cx: &mut Context<Self>,
2088 ) -> Task<Result<()>> {
2089 let Some((_, message)) = self.user_message_mut(&id) else {
2090 return Task::ready(Err(anyhow!("message not found")));
2091 };
2092
2093 let checkpoint = message
2094 .checkpoint
2095 .as_ref()
2096 .map(|c| c.git_checkpoint.clone());
2097
2098 // Cancel any in-progress generation before restoring
2099 let cancel_task = self.cancel(cx);
2100 let rewind = self.rewind(id.clone(), cx);
2101 let git_store = self.project.read(cx).git_store().clone();
2102
2103 cx.spawn(async move |_, cx| {
2104 cancel_task.await;
2105 rewind.await?;
2106 if let Some(checkpoint) = checkpoint {
2107 git_store
2108 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2109 .await?;
2110 }
2111
2112 Ok(())
2113 })
2114 }
2115
2116 /// Rewinds this thread to before the entry at `index`, removing it and all
2117 /// subsequent entries while rejecting any action_log changes made from that point.
2118 /// Unlike `restore_checkpoint`, this method does not restore from git.
2119 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2120 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2121 return Task::ready(Err(anyhow!("not supported")));
2122 };
2123
2124 let telemetry = ActionLogTelemetry::from(&*self);
2125 cx.spawn(async move |this, cx| {
2126 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2127 this.update(cx, |this, cx| {
2128 if let Some((ix, _)) = this.user_message_mut(&id) {
2129 // Collect all terminals from entries that will be removed
2130 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2131 .iter()
2132 .flat_map(|entry| entry.terminals())
2133 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2134 .collect();
2135
2136 let range = ix..this.entries.len();
2137 this.entries.truncate(ix);
2138 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2139
2140 // Kill and remove the terminals
2141 for terminal_id in terminals_to_remove {
2142 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2143 terminal.update(cx, |terminal, cx| {
2144 terminal.kill(cx);
2145 });
2146 }
2147 }
2148 }
2149 this.action_log().update(cx, |action_log, cx| {
2150 action_log.reject_all_edits(Some(telemetry), cx)
2151 })
2152 })?
2153 .await;
2154 Ok(())
2155 })
2156 }
2157
2158 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2159 let git_store = self.project.read(cx).git_store().clone();
2160
2161 let Some((_, message)) = self.last_user_message() else {
2162 return Task::ready(Ok(()));
2163 };
2164 let Some(user_message_id) = message.id.clone() else {
2165 return Task::ready(Ok(()));
2166 };
2167 let Some(checkpoint) = message.checkpoint.as_ref() else {
2168 return Task::ready(Ok(()));
2169 };
2170 let old_checkpoint = checkpoint.git_checkpoint.clone();
2171
2172 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2173 cx.spawn(async move |this, cx| {
2174 let Some(new_checkpoint) = new_checkpoint
2175 .await
2176 .context("failed to get new checkpoint")
2177 .log_err()
2178 else {
2179 return Ok(());
2180 };
2181
2182 let equal = git_store
2183 .update(cx, |git, cx| {
2184 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2185 })
2186 .await
2187 .unwrap_or(true);
2188
2189 this.update(cx, |this, cx| {
2190 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2191 if let Some(checkpoint) = message.checkpoint.as_mut() {
2192 checkpoint.show = !equal;
2193 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2194 }
2195 }
2196 })?;
2197
2198 Ok(())
2199 })
2200 }
2201
2202 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2203 self.entries
2204 .iter_mut()
2205 .enumerate()
2206 .rev()
2207 .find_map(|(ix, entry)| {
2208 if let AgentThreadEntry::UserMessage(message) = entry {
2209 Some((ix, message))
2210 } else {
2211 None
2212 }
2213 })
2214 }
2215
2216 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2217 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2218 if let AgentThreadEntry::UserMessage(message) = entry {
2219 if message.id.as_ref() == Some(id) {
2220 Some((ix, message))
2221 } else {
2222 None
2223 }
2224 } else {
2225 None
2226 }
2227 })
2228 }
2229
2230 pub fn read_text_file(
2231 &self,
2232 path: PathBuf,
2233 line: Option<u32>,
2234 limit: Option<u32>,
2235 reuse_shared_snapshot: bool,
2236 cx: &mut Context<Self>,
2237 ) -> Task<Result<String, acp::Error>> {
2238 // Args are 1-based, move to 0-based
2239 let line = line.unwrap_or_default().saturating_sub(1);
2240 let limit = limit.unwrap_or(u32::MAX);
2241 let project = self.project.clone();
2242 let action_log = self.action_log.clone();
2243 cx.spawn(async move |this, cx| {
2244 let load = project.update(cx, |project, cx| {
2245 let path = project
2246 .project_path_for_absolute_path(&path, cx)
2247 .ok_or_else(|| {
2248 acp::Error::resource_not_found(Some(path.display().to_string()))
2249 })?;
2250 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2251 })?;
2252
2253 let buffer = load.await?;
2254
2255 let snapshot = if reuse_shared_snapshot {
2256 this.read_with(cx, |this, _| {
2257 this.shared_buffers.get(&buffer.clone()).cloned()
2258 })
2259 .log_err()
2260 .flatten()
2261 } else {
2262 None
2263 };
2264
2265 let snapshot = if let Some(snapshot) = snapshot {
2266 snapshot
2267 } else {
2268 action_log.update(cx, |action_log, cx| {
2269 action_log.buffer_read(buffer.clone(), cx);
2270 });
2271
2272 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2273 this.update(cx, |this, _| {
2274 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2275 })?;
2276 snapshot
2277 };
2278
2279 let max_point = snapshot.max_point();
2280 let start_position = Point::new(line, 0);
2281
2282 if start_position > max_point {
2283 return Err(acp::Error::invalid_params().data(format!(
2284 "Attempting to read beyond the end of the file, line {}:{}",
2285 max_point.row + 1,
2286 max_point.column
2287 )));
2288 }
2289
2290 let start = snapshot.anchor_before(start_position);
2291 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2292
2293 project.update(cx, |project, cx| {
2294 project.set_agent_location(
2295 Some(AgentLocation {
2296 buffer: buffer.downgrade(),
2297 position: start,
2298 }),
2299 cx,
2300 );
2301 });
2302
2303 Ok(snapshot.text_for_range(start..end).collect::<String>())
2304 })
2305 }
2306
    /// Replaces the contents of the file at the absolute `path` with `content`
    /// on behalf of the agent, then saves the buffer.
    ///
    /// The new text is diffed against the snapshot recorded in `shared_buffers`
    /// for this buffer (falling back to the buffer's current snapshot). Each
    /// resulting edit range is expressed as anchors into that snapshot, so it is
    /// resolved against the buffer's contents at the moment the edit is applied.
    ///
    /// When the buffer's language settings have format-on-save enabled, the
    /// buffer is formatted before it is saved. The action log is notified with
    /// `buffer_read` before the edit, and with `buffer_edited` after the edit
    /// and again after formatting. The agent's location is moved to the end of
    /// the last edit, or to the start of the buffer if there were no edits.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" when `path` cannot be resolved to a project
    /// path. Also fails when the buffer cannot be opened, when the thread has
    /// been dropped, or when saving fails. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Prefer the snapshot the agent last read so the diff reflects the
            // agent's view of the file rather than the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor. Ranges are anchored in `snapshot`
            // and resolved against the buffer's contents when the edits are applied.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (snapshot.anchor_range_around(range), replacement)
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent at the end of the last edit, or at the start of the
            // buffer when the diff produced no edits.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            let format_on_save = cx.update(|cx| {
                // Report the read to the action log before the agent's edits land.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: a failure is logged and the buffer
                // is still saved below.
                format_task.await.log_err();

                // Formatting may have changed the buffer again; report that too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2399
2400 pub fn create_terminal(
2401 &self,
2402 command: String,
2403 args: Vec<String>,
2404 extra_env: Vec<acp::EnvVariable>,
2405 cwd: Option<PathBuf>,
2406 output_byte_limit: Option<u64>,
2407 cx: &mut Context<Self>,
2408 ) -> Task<Result<Entity<Terminal>>> {
2409 let env = match &cwd {
2410 Some(dir) => self.project.update(cx, |project, cx| {
2411 project.environment().update(cx, |env, cx| {
2412 env.directory_environment(dir.as_path().into(), cx)
2413 })
2414 }),
2415 None => Task::ready(None).shared(),
2416 };
2417 let env = cx.spawn(async move |_, _| {
2418 let mut env = env.await.unwrap_or_default();
2419 // Disables paging for `git` and hopefully other commands
2420 env.insert("PAGER".into(), "".into());
2421 for var in extra_env {
2422 env.insert(var.name, var.value);
2423 }
2424 env
2425 });
2426
2427 let project = self.project.clone();
2428 let language_registry = project.read(cx).languages().clone();
2429 let is_windows = project.read(cx).path_style(cx).is_windows();
2430
2431 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2432 let terminal_task = cx.spawn({
2433 let terminal_id = terminal_id.clone();
2434 async move |_this, cx| {
2435 let env = env.await;
2436 let shell = project
2437 .update(cx, |project, cx| {
2438 project
2439 .remote_client()
2440 .and_then(|r| r.read(cx).default_system_shell())
2441 })
2442 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2443 let (task_command, task_args) =
2444 ShellBuilder::new(&Shell::Program(shell), is_windows)
2445 .redirect_stdin_to_dev_null()
2446 .build(Some(command.clone()), &args);
2447 let terminal = project
2448 .update(cx, |project, cx| {
2449 project.create_terminal_task(
2450 task::SpawnInTerminal {
2451 command: Some(task_command),
2452 args: task_args,
2453 cwd: cwd.clone(),
2454 env,
2455 ..Default::default()
2456 },
2457 cx,
2458 )
2459 })
2460 .await?;
2461
2462 anyhow::Ok(cx.new(|cx| {
2463 Terminal::new(
2464 terminal_id,
2465 &format!("{} {}", command, args.join(" ")),
2466 cwd,
2467 output_byte_limit.map(|l| l as usize),
2468 terminal,
2469 language_registry,
2470 cx,
2471 )
2472 }))
2473 }
2474 });
2475
2476 cx.spawn(async move |this, cx| {
2477 let terminal = terminal_task.await?;
2478 this.update(cx, |this, _cx| {
2479 this.terminals.insert(terminal_id, terminal.clone());
2480 terminal
2481 })
2482 })
2483 }
2484
2485 pub fn kill_terminal(
2486 &mut self,
2487 terminal_id: acp::TerminalId,
2488 cx: &mut Context<Self>,
2489 ) -> Result<()> {
2490 self.terminals
2491 .get(&terminal_id)
2492 .context("Terminal not found")?
2493 .update(cx, |terminal, cx| {
2494 terminal.kill(cx);
2495 });
2496
2497 Ok(())
2498 }
2499
2500 pub fn release_terminal(
2501 &mut self,
2502 terminal_id: acp::TerminalId,
2503 cx: &mut Context<Self>,
2504 ) -> Result<()> {
2505 self.terminals
2506 .remove(&terminal_id)
2507 .context("Terminal not found")?
2508 .update(cx, |terminal, cx| {
2509 terminal.kill(cx);
2510 });
2511
2512 Ok(())
2513 }
2514
2515 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2516 self.terminals
2517 .get(&terminal_id)
2518 .context("Terminal not found")
2519 .cloned()
2520 }
2521
2522 pub fn to_markdown(&self, cx: &App) -> String {
2523 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2524 }
2525
2526 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2527 cx.emit(AcpThreadEvent::LoadError(error));
2528 }
2529
2530 pub fn register_terminal_created(
2531 &mut self,
2532 terminal_id: acp::TerminalId,
2533 command_label: String,
2534 working_dir: Option<PathBuf>,
2535 output_byte_limit: Option<u64>,
2536 terminal: Entity<::terminal::Terminal>,
2537 cx: &mut Context<Self>,
2538 ) -> Entity<Terminal> {
2539 let language_registry = self.project.read(cx).languages().clone();
2540
2541 let entity = cx.new(|cx| {
2542 Terminal::new(
2543 terminal_id.clone(),
2544 &command_label,
2545 working_dir.clone(),
2546 output_byte_limit.map(|l| l as usize),
2547 terminal,
2548 language_registry,
2549 cx,
2550 )
2551 });
2552 self.terminals.insert(terminal_id.clone(), entity.clone());
2553 entity
2554 }
2555
2556 pub fn mark_as_subagent_output(&mut self, cx: &mut Context<Self>) {
2557 for entry in self.entries.iter_mut().rev() {
2558 if let AgentThreadEntry::AssistantMessage(assistant_message) = entry {
2559 assistant_message.is_subagent_output = true;
2560 cx.notify();
2561 return;
2562 }
2563 }
2564 }
2565}
2566
2567fn markdown_for_raw_output(
2568 raw_output: &serde_json::Value,
2569 language_registry: &Arc<LanguageRegistry>,
2570 cx: &mut App,
2571) -> Option<Entity<Markdown>> {
2572 match raw_output {
2573 serde_json::Value::Null => None,
2574 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2575 Markdown::new(
2576 value.to_string().into(),
2577 Some(language_registry.clone()),
2578 None,
2579 cx,
2580 )
2581 })),
2582 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2583 Markdown::new(
2584 value.to_string().into(),
2585 Some(language_registry.clone()),
2586 None,
2587 cx,
2588 )
2589 })),
2590 serde_json::Value::String(value) => Some(cx.new(|cx| {
2591 Markdown::new(
2592 value.clone().into(),
2593 Some(language_registry.clone()),
2594 None,
2595 cx,
2596 )
2597 })),
2598 value => Some(cx.new(|cx| {
2599 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2600
2601 Markdown::new(
2602 format!("```json\n{}\n```", pretty_json).into(),
2603 Some(language_registry.clone()),
2604 None,
2605 cx,
2606 )
2607 })),
2608 }
2609}
2610
2611#[cfg(test)]
2612mod tests {
2613 use super::*;
2614 use anyhow::anyhow;
2615 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2616 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2617 use indoc::indoc;
2618 use project::{FakeFs, Fs};
2619 use rand::{distr, prelude::*};
2620 use serde_json::json;
2621 use settings::SettingsStore;
2622 use smol::stream::StreamExt as _;
2623 use std::{
2624 any::Any,
2625 cell::RefCell,
2626 path::Path,
2627 rc::Rc,
2628 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2629 time::Duration,
2630 };
2631 use util::path;
2632
2633 fn init_test(cx: &mut TestAppContext) {
2634 env_logger::try_init().ok();
2635 cx.update(|cx| {
2636 let settings_store = SettingsStore::test(cx);
2637 cx.set_global(settings_store);
2638 });
2639 }
2640
    /// Output delivered for a terminal before its `Created` event must be
    /// buffered by the thread and flushed into the terminal once it is created.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2704
    /// Both `Output` and `Exit` may arrive before `Created`. The buffered output
    /// must still be rendered once the terminal is created.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2779
    /// Test that killing a terminal via Terminal::kill properly:
    /// 1. Causes wait_for_exit to complete (doesn't hang forever)
    /// 2. Leaves the output written before the kill in the underlying terminal
    ///
    /// This test verifies that the fix to kill_active_task (which now also kills
    /// the shell process in addition to the foreground process) properly allows
    /// wait_for_exit to complete instead of hanging indefinitely.
    #[cfg(unix)]
    #[gpui::test]
    async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
        use std::collections::HashMap;
        use task::Shell;
        use util::shell_builder::ShellBuilder;

        init_test(cx);
        // A real PTY process is spawned below, so the executor must be allowed
        // to park while the test waits on real time.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Create a real PTY terminal that runs a command which prints output then sleeps
        // We use printf instead of echo and chain with && sleep to ensure proper execution
        let (completion_tx, _completion_rx) = smol::channel::unbounded();
        let (program, args) = ShellBuilder::new(&Shell::System, false).build(
            Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
            &[],
        );

        let builder = cx
            .update(|cx| {
                ::terminal::TerminalBuilder::new(
                    None,
                    None,
                    task::Shell::WithArguments {
                        program,
                        args,
                        title_override: None,
                    },
                    HashMap::default(),
                    ::terminal::terminal_settings::CursorShape::default(),
                    ::terminal::terminal_settings::AlternateScroll::On,
                    None,
                    vec![],
                    0,
                    false,
                    0,
                    Some(completion_tx),
                    cx,
                    vec![],
                    PathStyle::local(),
                )
            })
            .await
            .unwrap();

        let lower_terminal = cx.new(|cx| builder.subscribe(cx));

        // Create the acp_thread Terminal wrapper
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "printf output_before_kill && sleep 60".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower_terminal.clone(),
                },
                cx,
            );
        });

        // Wait for the printf command to execute and produce output
        // Use real time since parking is enabled
        cx.executor().timer(Duration::from_millis(500)).await;

        // Get the acp_thread Terminal and kill it
        let wait_for_exit = thread.update(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            let wait_for_exit = term.read(cx).wait_for_exit();
            term.update(cx, |term, cx| {
                term.kill(cx);
            });
            wait_for_exit
        });

        // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
        // Before the fix to kill_active_task, this would hang forever because
        // only the foreground process was killed, not the shell, so the PTY
        // child never exited and wait_for_completed_task never completed.
        let exit_result = futures::select! {
            result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
            _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
        };

        assert!(
            exit_result.is_some(),
            "wait_for_exit should complete after kill, but it timed out. \
            This indicates kill_active_task is not properly killing the shell process."
        );

        // Give the system a chance to process any pending updates
        cx.run_until_parked();

        // Verify that the underlying terminal still has the output that was
        // written before the kill. This verifies that killing doesn't lose output.
        let inner_content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            term.read(cx).inner().read(cx).get_content()
        });

        assert!(
            inner_content.contains("output_before_kill"),
            "Underlying terminal should contain output from before kill, got: {}",
            inner_content
        );
    }
2903
    /// `push_user_content_block` behavior:
    /// - Content is appended to a trailing user message, which adopts the
    ///   provided id.
    /// - A new user message is started when the last entry is an assistant
    ///   message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2971
    /// Consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` block in the markdown transcript.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
3032
    /// The agent reads a file, then writes an extended version of it.
    ///
    /// Between the read and the write, the user edits the same buffer. Because
    /// the write is diffed against the snapshot the agent read, the user's edit
    /// must survive in both the buffer and the saved file.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test that the agent's read has finished, so the user edit
        // below happens after the read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulate the user editing the buffer after the agent's read.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
3109
    /// Exercises `read_text_file` argument handling on a four-line file:
    /// - `line` is a 1-based start line.
    /// - `limit` caps the number of lines returned.
    /// - A start line past the end of the file produces an `invalid_params`
    ///   error that names the file's end position.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid: line 6 is past the end; the file ends at row 4 (reported as line 5)
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
3185
    /// On an empty file, every read that starts at line 1 returns an empty
    /// string. A later start line produces an `invalid_params` error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    /// Reading a path outside the project must fail with `ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3291
    /// A tool call that was marked `Canceled` by `cancel` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent reports one in-progress tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent overrides the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3381
    /// An edit tool call that arrives already `Completed` must not leave the
    /// thread reporting pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3426
    /// Verifies checkpoint bookkeeping across several turns.
    ///
    /// The fake agent uppercases the prompt text. While `simulate_changes` is set,
    /// it also writes a new file (`/test/file-N`) on every prompt.
    /// - Turns whose prompt produced file changes render as "## User (checkpoint)".
    /// - A turn without file changes keeps no checkpoint and renders as "## User".
    /// - Restoring the checkpoint of the second user message truncates the thread
    ///   back to the first exchange. It also removes the file written during the
    ///   second turn.
    ///
    /// Runs 10 iterations, presumably to exercise different executor schedules.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // A `.git` directory is present so the project can take checkpoints
        // (NOTE(review): inferred from the test's intent; confirm against GitStore).
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // Toggles whether the fake agent touches the filesystem during a prompt.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        // Monotonic counter used to name the files written by the fake agent.
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                // Clone per call: each prompt's future needs its own owned handles.
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in uppercase as the assistant's reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Turn 1: writes file-0, so the user message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Turn 2: writes file-1, so this user message also gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the second user message ("ipsum"); entries[0] and [1] are the
        // first user message and its reply.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        // file-1 was written during the rewound turn and must be gone.
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3604
    /// Verifies the handling of a refusal that follows a tool result.
    ///
    /// The agent returns `StopReason::Refusal` after a tool call has already
    /// produced output. The thread must still emit `AcpThreadEvent::Refusal`, but
    /// it must keep both the user message and the tool call (with its raw output).
    #[gpui::test]
    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
        // NOTE(review): this local import looks redundant; `test_checkpoints` uses
        // `AtomicUsize` without one, so it is already in scope for this module.
        use std::sync::atomic::AtomicUsize;
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None, cx).await;

        // Create a connection that simulates refusal after tool result
        let prompt_count = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let prompt_count = prompt_count.clone();
            move |_request, thread, mut cx| {
                // The count is taken synchronously, when the prompt is dispatched.
                let count = prompt_count.fetch_add(1, SeqCst);
                async move {
                    if count == 0 {
                        // First prompt: Generate a tool call with result
                        thread.update(&mut cx, |thread, cx| {
                            thread
                                .handle_session_update(
                                    acp::SessionUpdate::ToolCall(
                                        acp::ToolCall::new("tool1", "Test Tool")
                                            .kind(acp::ToolKind::Fetch)
                                            .status(acp::ToolCallStatus::Completed)
                                            .raw_input(serde_json::json!({"query": "test"}))
                                            .raw_output(serde_json::json!({"result": "inappropriate content"})),
                                    ),
                                    cx,
                                )
                                .unwrap();
                        })?;

                        // Now return refusal because of the tool result
                        Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
                    } else {
                        Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                    }
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Track if we see a Refusal event
        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
        let saw_refusal_event_captured = saw_refusal_event.clone();
        thread.update(cx, |_thread, cx| {
            cx.subscribe(
                &thread,
                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
                    if matches!(event, AcpThreadEvent::Refusal) {
                        *saw_refusal_event_captured.lock().unwrap() = true;
                    }
                },
            )
            .detach();
        });

        // Send a user message - this will trigger tool call and then refusal.
        // The send future is detached, so its result is not checked here; the
        // test relies on `run_until_parked` to drive the turn to completion.
        let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        cx.background_executor.spawn(send_task).detach();
        cx.run_until_parked();

        // Verify that:
        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
        // 2. The user message was NOT truncated
        assert!(
            *saw_refusal_event.lock().unwrap(),
            "Refusal event should be emitted for tool result refusals"
        );

        thread.read_with(cx, |thread, _| {
            let entries = thread.entries();
            assert!(entries.len() >= 2, "Should have user message and tool call");

            // Verify user message is still there
            assert!(
                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
                "User message should not be truncated"
            );

            // Verify tool call is there with result
            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
                assert!(
                    tool_call.raw_output.is_some(),
                    "Tool call should have output"
                );
            } else {
                panic!("Expected tool call at index 1");
            }
        });
    }
3701
3702 #[gpui::test]
3703 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3704 init_test(cx);
3705
3706 let fs = FakeFs::new(cx.executor());
3707 let project = Project::test(fs, None, cx).await;
3708
3709 let refuse_next = Arc::new(AtomicBool::new(false));
3710 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3711 let refuse_next = refuse_next.clone();
3712 move |_request, _thread, _cx| {
3713 if refuse_next.load(SeqCst) {
3714 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3715 .boxed_local()
3716 } else {
3717 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3718 .boxed_local()
3719 }
3720 }
3721 }));
3722
3723 let thread = cx
3724 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3725 .await
3726 .unwrap();
3727
3728 // Track if we see a Refusal event
3729 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3730 let saw_refusal_event_captured = saw_refusal_event.clone();
3731 thread.update(cx, |_thread, cx| {
3732 cx.subscribe(
3733 &thread,
3734 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3735 if matches!(event, AcpThreadEvent::Refusal) {
3736 *saw_refusal_event_captured.lock().unwrap() = true;
3737 }
3738 },
3739 )
3740 .detach();
3741 });
3742
3743 // Send a message that will be refused
3744 refuse_next.store(true, SeqCst);
3745 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3746 .await
3747 .unwrap();
3748
3749 // Verify that a Refusal event WAS emitted for user prompt refusal
3750 assert!(
3751 *saw_refusal_event.lock().unwrap(),
3752 "Refusal event should be emitted for user prompt refusals"
3753 );
3754
3755 // Verify the message was truncated (user prompt refusal)
3756 thread.read_with(cx, |thread, cx| {
3757 assert_eq!(thread.to_markdown(cx), "");
3758 });
3759 }
3760
    /// Verifies that a refused user prompt is truncated from the thread while the
    /// earlier, successful exchange is left intact.
    ///
    /// The fake agent uppercases the prompt text unless `refuse_next` is set; in
    /// that case it immediately returns `StopReason::Refusal`.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    // Checked inside the future, i.e. when the prompt is polled.
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn succeeds and is echoed back in uppercase.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
3842
3843 async fn run_until_first_tool_call(
3844 thread: &Entity<AcpThread>,
3845 cx: &mut TestAppContext,
3846 ) -> usize {
3847 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3848
3849 let subscription = cx.update(|cx| {
3850 cx.subscribe(thread, move |thread, _, cx| {
3851 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3852 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3853 return tx.try_send(ix).unwrap();
3854 }
3855 }
3856 })
3857 });
3858
3859 select! {
3860 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3861 panic!("Timeout waiting for tool call")
3862 }
3863 ix = rx.next().fuse() => {
3864 drop(subscription);
3865 ix.unwrap()
3866 }
3867 }
3868 }
3869
3870 #[derive(Clone, Default)]
3871 struct FakeAgentConnection {
3872 auth_methods: Vec<acp::AuthMethod>,
3873 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3874 on_user_message: Option<
3875 Rc<
3876 dyn Fn(
3877 acp::PromptRequest,
3878 WeakEntity<AcpThread>,
3879 AsyncApp,
3880 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3881 + 'static,
3882 >,
3883 >,
3884 }
3885
3886 impl FakeAgentConnection {
3887 fn new() -> Self {
3888 Self {
3889 auth_methods: Vec::new(),
3890 on_user_message: None,
3891 sessions: Arc::default(),
3892 }
3893 }
3894
3895 #[expect(unused)]
3896 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3897 self.auth_methods = auth_methods;
3898 self
3899 }
3900
3901 fn on_user_message(
3902 mut self,
3903 handler: impl Fn(
3904 acp::PromptRequest,
3905 WeakEntity<AcpThread>,
3906 AsyncApp,
3907 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3908 + 'static,
3909 ) -> Self {
3910 self.on_user_message.replace(Rc::new(handler));
3911 self
3912 }
3913 }
3914
3915 impl AgentConnection for FakeAgentConnection {
3916 fn telemetry_id(&self) -> SharedString {
3917 "fake".into()
3918 }
3919
3920 fn auth_methods(&self) -> &[acp::AuthMethod] {
3921 &self.auth_methods
3922 }
3923
3924 fn new_session(
3925 self: Rc<Self>,
3926 project: Entity<Project>,
3927 _cwd: &Path,
3928 cx: &mut App,
3929 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3930 let session_id = acp::SessionId::new(
3931 rand::rng()
3932 .sample_iter(&distr::Alphanumeric)
3933 .take(7)
3934 .map(char::from)
3935 .collect::<String>(),
3936 );
3937 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3938 let thread = cx.new(|cx| {
3939 AcpThread::new(
3940 None,
3941 "Test",
3942 self.clone(),
3943 project,
3944 action_log,
3945 session_id.clone(),
3946 watch::Receiver::constant(
3947 acp::PromptCapabilities::new()
3948 .image(true)
3949 .audio(true)
3950 .embedded_context(true),
3951 ),
3952 cx,
3953 )
3954 });
3955 self.sessions.lock().insert(session_id, thread.downgrade());
3956 Task::ready(Ok(thread))
3957 }
3958
3959 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3960 if self.auth_methods().iter().any(|m| m.id == method) {
3961 Task::ready(Ok(()))
3962 } else {
3963 Task::ready(Err(anyhow!("Invalid Auth Method")))
3964 }
3965 }
3966
3967 fn prompt(
3968 &self,
3969 _id: Option<UserMessageId>,
3970 params: acp::PromptRequest,
3971 cx: &mut App,
3972 ) -> Task<gpui::Result<acp::PromptResponse>> {
3973 let sessions = self.sessions.lock();
3974 let thread = sessions.get(¶ms.session_id).unwrap();
3975 if let Some(handler) = &self.on_user_message {
3976 let handler = handler.clone();
3977 let thread = thread.clone();
3978 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3979 } else {
3980 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3981 }
3982 }
3983
3984 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
3985
3986 fn truncate(
3987 &self,
3988 session_id: &acp::SessionId,
3989 _cx: &App,
3990 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3991 Some(Rc::new(FakeAgentSessionEditor {
3992 _session_id: session_id.clone(),
3993 }))
3994 }
3995
3996 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3997 self
3998 }
3999 }
4000
4001 struct FakeAgentSessionEditor {
4002 _session_id: acp::SessionId,
4003 }
4004
4005 impl AgentSessionTruncate for FakeAgentSessionEditor {
4006 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
4007 Task::ready(Ok(()))
4008 }
4009 }
4010
    /// Verifies that a `ToolCallUpdate` for an unknown tool call id is not an error.
    ///
    /// Instead, the thread gains exactly one tool call entry with that id, status
    /// `Failed`, and kind `Fetch`. Its single markdown content block mentions
    /// "Tool call not found".
    #[gpui::test]
    async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Try to update a tool call that doesn't exist
        let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
        thread.update(cx, |thread, cx| {
            let result = thread.handle_session_update(
                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                    nonexistent_id.clone(),
                    acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                )),
                cx,
            );

            // The update should succeed (not return an error)
            assert!(result.is_ok());

            // There should now be exactly one entry in the thread
            assert_eq!(thread.entries.len(), 1);

            // The entry should be a failed tool call
            if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
                assert_eq!(tool_call.id, nonexistent_id);
                assert!(matches!(tool_call.status, ToolCallStatus::Failed));
                assert_eq!(tool_call.kind, acp::ToolKind::Fetch);

                // Check that the content contains the error message
                assert_eq!(tool_call.content.len(), 1);
                if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
                    // Exhaustive match so a new ContentBlock variant forces this test
                    // to be revisited.
                    match content_block {
                        ContentBlock::Markdown { markdown } => {
                            let markdown_text = markdown.read(cx).source();
                            assert!(markdown_text.contains("Tool call not found"));
                        }
                        ContentBlock::Empty => panic!("Expected markdown content, got empty"),
                        ContentBlock::ResourceLink { .. } => {
                            panic!("Expected markdown content, got resource link")
                        }
                        ContentBlock::Image { .. } => {
                            panic!("Expected markdown content, got image")
                        }
                    }
                } else {
                    panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
                }
            } else {
                panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
            }
        });
    }
4070
    /// Tests that restoring a checkpoint removes the terminal referenced by a
    /// truncated, still-running tool call, and cancels any in-progress generation.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// NOTE(review): terminals 1 and 2 are also registered after the second message
    /// was sent, yet they are expected to survive. What distinguishes terminal 3 is
    /// that it is still running and is referenced by a tool call entry that the
    /// restore truncates. Confirm which of these `restore_checkpoint` keys on.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that run to completion. They are registered after the
        // second message was sent, but no thread entry references them; the
        // assertions below expect them to survive the restore.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        // Exit marks terminal 1 as completed (its output becomes Some).
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a terminal that is never exited and is referenced by a tool call
        // entry placed after the message we'll restore to.
        // This simulates the AI agent starting a long-running terminal command.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry that references this terminal
        // This represents the agent requesting a terminal command
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (via the cancel() call)
        // 2. Remove the running terminal referenced by the truncated tool call
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no turn is in progress after restore
        // (cancel() clears `running_turn`)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed, unreferenced terminals still exist
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal (referenced by the truncated tool call) was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify total count is 2 (the two completed, unreferenced terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4372
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // Set synchronously when the fake agent's handler is invoked for a prompt.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        // NOTE(review): this loop has no bound; if the executor runs out of work
        // before the handler is invoked, the test spins forever instead of failing.
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        // The exact number of extra ticks is a heuristic for "update_last_checkpoint
        // has started but not yet finished".
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message without a checkpoint while the comparison is in flight.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4435
4436 /// Tests that when a follow-up message is sent during generation,
4437 /// the first turn completing does NOT clear `running_turn` because
4438 /// it now belongs to the second turn.
4439 #[gpui::test]
4440 async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
4441 init_test(cx);
4442
4443 let fs = FakeFs::new(cx.executor());
4444 let project = Project::test(fs, [], cx).await;
4445
4446 // First handler waits for this signal before completing
4447 let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
4448 let first_complete_rx = RefCell::new(Some(first_complete_rx));
4449
4450 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
4451 move |params, _thread, _cx| {
4452 let first_complete_rx = first_complete_rx.borrow_mut().take();
4453 let is_first = params
4454 .prompt
4455 .iter()
4456 .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
4457
4458 async move {
4459 if is_first {
4460 // First handler waits until signaled
4461 if let Some(rx) = first_complete_rx {
4462 rx.await.ok();
4463 }
4464 }
4465 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
4466 }
4467 .boxed_local()
4468 }
4469 }));
4470
4471 let thread = cx
4472 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4473 .await
4474 .unwrap();
4475
4476 // Send first message (turn_id=1) - handler will block
4477 let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
4478 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
4479
4480 // Send second message (turn_id=2) while first is still blocked
4481 // This calls cancel() which takes turn 1's running_turn and sets turn 2's
4482 let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
4483 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
4484
4485 let running_turn_after_second_send =
4486 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4487 assert_eq!(
4488 running_turn_after_second_send,
4489 Some(2),
4490 "running_turn should be set to turn 2 after sending second message"
4491 );
4492
4493 // Now signal first handler to complete
4494 first_complete_tx.send(()).ok();
4495
4496 // First request completes - should NOT clear running_turn
4497 // because running_turn now belongs to turn 2
4498 first_request.await.unwrap();
4499
4500 let running_turn_after_first =
4501 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4502 assert_eq!(
4503 running_turn_after_first,
4504 Some(2),
4505 "first turn completing should not clear running_turn (belongs to turn 2)"
4506 );
4507
4508 // Second request completes - SHOULD clear running_turn
4509 second_request.await.unwrap();
4510
4511 let running_turn_after_second =
4512 thread.read_with(cx, |thread, _| thread.running_turn.is_some());
4513 assert!(
4514 !running_turn_after_second,
4515 "second turn completing should clear running_turn"
4516 );
4517 }
4518
4519 #[gpui::test]
4520 async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
4521 cx: &mut TestAppContext,
4522 ) {
4523 init_test(cx);
4524
4525 let fs = FakeFs::new(cx.executor());
4526 let project = Project::test(fs, [], cx).await;
4527
4528 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4529 move |_params, thread, mut cx| {
4530 async move {
4531 thread
4532 .update(&mut cx, |thread, cx| {
4533 thread.handle_session_update(
4534 acp::SessionUpdate::ToolCall(
4535 acp::ToolCall::new(
4536 acp::ToolCallId::new("test-tool"),
4537 "Test Tool",
4538 )
4539 .kind(acp::ToolKind::Fetch)
4540 .status(acp::ToolCallStatus::InProgress),
4541 ),
4542 cx,
4543 )
4544 })
4545 .unwrap()
4546 .unwrap();
4547
4548 Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
4549 }
4550 .boxed_local()
4551 },
4552 ));
4553
4554 let thread = cx
4555 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4556 .await
4557 .unwrap();
4558
4559 let response = thread
4560 .update(cx, |thread, cx| thread.send_raw("test message", cx))
4561 .await;
4562
4563 let response = response
4564 .expect("send should succeed")
4565 .expect("should have response");
4566 assert_eq!(
4567 response.stop_reason,
4568 acp::StopReason::Cancelled,
4569 "response should have Cancelled stop_reason"
4570 );
4571
4572 thread.read_with(cx, |thread, _| {
4573 let tool_entry = thread
4574 .entries
4575 .iter()
4576 .find_map(|e| {
4577 if let AgentThreadEntry::ToolCall(call) = e {
4578 Some(call)
4579 } else {
4580 None
4581 }
4582 })
4583 .expect("should have tool call entry");
4584
4585 assert!(
4586 matches!(tool_entry.status, ToolCallStatus::Canceled),
4587 "tool should be marked as Canceled when response is Cancelled, got {:?}",
4588 tool_entry.status
4589 );
4590 });
4591 }
4592}