1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Key under which a tool's programmatic name is stored in an ACP `ToolCall`'s meta map.
///
/// This is a workaround: ACP's `ToolCall` has no dedicated name field.
pub const TOOL_NAME_META_KEY: &str = "tool_name";
9
/// Key under which the session id of a spawned subagent is stored in an ACP
/// `ToolCall`'s meta map.
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A user-authored entry in an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of this message, if one was assigned.
    pub id: Option<UserMessageId>,
    /// The content that `to_markdown` renders under the `## User` heading.
    pub content: ContentBlock,
    /// ACP content blocks belonging to this message.
    // NOTE(review): presumably the raw blocks `content` was built from — confirm.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
75
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying checkpoint handle from the project's git store.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the message's markdown heading becomes `## User (checkpoint)`.
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// An assistant-authored entry in an agent thread, stored as a list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks making up the message, in the order they are rendered.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content; rendered to markdown as-is.
    Message { block: ContentBlock },
    /// Thought content; rendered to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message produced by the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool call, including its content and status.
    ToolCall(ToolCall),
}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
/// Thread-side state of a tool call reported over ACP.
#[derive(Debug)]
pub struct ToolCall {
    /// Identifier of the tool call as given by ACP.
    pub id: acp::ToolCallId,
    /// Markdown label built from the tool call's title. For kinds other than
    /// `Execute`, only the first line of the title is kept, followed by `…`.
    pub label: Entity<Markdown>,
    /// The ACP tool kind.
    pub kind: acp::ToolKind,
    /// Content items (content blocks, diffs, terminals) attached to the call.
    pub content: Vec<ToolCallContent>,
    /// Current status of the call.
    pub status: ToolCallStatus,
    /// Locations reported by the agent for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`; looked up with the same index by
    /// `AgentThreadEntry::location`. Starts out empty in `from_acp`.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown produced from `raw_input` via `markdown_for_raw_output`, if any.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from the call's meta (`TOOL_NAME_META_KEY`).
    pub tool_name: Option<SharedString>,
    /// Subagent session id read from meta (`SUBAGENT_SESSION_ID_META_KEY`).
    pub subagent_session_id: Option<acp::SessionId>,
}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 if self.kind == acp::ToolKind::Execute {
318 for terminal in self.terminals() {
319 terminal.update(cx, |terminal, cx| {
320 terminal.update_command_label(&title, cx);
321 });
322 }
323 }
324 self.label.update(cx, |label, cx| {
325 if self.kind == acp::ToolKind::Execute {
326 label.replace(title, cx);
327 } else if let Some((first_line, _)) = title.split_once("\n") {
328 label.replace(first_line.to_owned() + "…", cx);
329 } else {
330 label.replace(title, cx);
331 }
332 });
333 }
334
335 if let Some(content) = content {
336 let mut new_content_len = content.len();
337 let mut content = content.into_iter();
338
339 // Reuse existing content if we can
340 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
341 let valid_content =
342 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
343 if !valid_content {
344 new_content_len -= 1;
345 }
346 }
347 for new in content {
348 if let Some(new) = ToolCallContent::from_acp(
349 new,
350 language_registry.clone(),
351 path_style,
352 terminals,
353 cx,
354 )? {
355 self.content.push(new);
356 } else {
357 new_content_len -= 1;
358 }
359 }
360 self.content.truncate(new_content_len);
361 }
362
363 if let Some(locations) = locations {
364 self.locations = locations;
365 }
366
367 if let Some(raw_input) = raw_input {
368 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
369 self.raw_input = Some(raw_input);
370 }
371
372 if let Some(raw_output) = raw_output {
373 if self.content.is_empty()
374 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
375 {
376 self.content
377 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
378 markdown,
379 }));
380 }
381 self.raw_output = Some(raw_output);
382 }
383 Ok(())
384 }
385
386 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
387 self.content.iter().filter_map(|content| match content {
388 ToolCallContent::Diff(diff) => Some(diff),
389 ToolCallContent::ContentBlock(_) => None,
390 ToolCallContent::Terminal(_) => None,
391 })
392 }
393
394 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
395 self.content.iter().filter_map(|content| match content {
396 ToolCallContent::Terminal(terminal) => Some(terminal),
397 ToolCallContent::ContentBlock(_) => None,
398 ToolCallContent::Diff(_) => None,
399 })
400 }
401
402 pub fn is_subagent(&self) -> bool {
403 self.tool_name.as_ref().is_some_and(|s| s == "spawn_agent")
404 || self.subagent_session_id.is_some()
405 }
406
407 pub fn to_markdown(&self, cx: &App) -> String {
408 let mut markdown = format!(
409 "**Tool Call: {}**\nStatus: {}\n\n",
410 self.label.read(cx).source(),
411 self.status
412 );
413 for content in &self.content {
414 markdown.push_str(content.to_markdown(cx).as_str());
415 markdown.push_str("\n\n");
416 }
417 markdown
418 }
419
420 async fn resolve_location(
421 location: acp::ToolCallLocation,
422 project: WeakEntity<Project>,
423 cx: &mut AsyncApp,
424 ) -> Option<ResolvedLocation> {
425 let buffer = project
426 .update(cx, |project, cx| {
427 project
428 .project_path_for_absolute_path(&location.path, cx)
429 .map(|path| project.open_buffer(path, cx))
430 })
431 .ok()??;
432 let buffer = buffer.await.log_err()?;
433 let position = buffer.update(cx, |buffer, _| {
434 let snapshot = buffer.snapshot();
435 if let Some(row) = location.line {
436 let column = snapshot.indent_size_for_line(row).len;
437 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
438 snapshot.anchor_before(point)
439 } else {
440 Anchor::min_for_buffer(snapshot.remote_id())
441 }
442 });
443
444 Some(ResolvedLocation { buffer, position })
445 }
446
447 fn resolve_locations(
448 &self,
449 project: Entity<Project>,
450 cx: &mut App,
451 ) -> Task<Vec<Option<ResolvedLocation>>> {
452 let locations = self.locations.clone();
453 project.update(cx, |_, cx| {
454 cx.spawn(async move |project, cx| {
455 let mut new_locations = Vec::new();
456 for location in locations {
457 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
458 }
459 new_locations
460 })
461 })
462 }
463}
464
// Kept separate from `AgentLocation` so we can hold a strong reference to the
// buffer for saving on the thread (`AgentLocation` is built from a downgraded
// buffer handle).
/// A tool call location resolved to a buffer and an anchor within it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer the location points into.
    buffer: Entity<Buffer>,
    /// Anchor within `buffer`.
    position: Anchor,
}
472
473impl From<&ResolvedLocation> for AgentLocation {
474 fn from(value: &ResolvedLocation) -> Self {
475 Self {
476 buffer: value.buffer.downgrade(),
477 position: value.position,
478 }
479 }
480}
481
/// Status of a [`ToolCall`] as tracked by the thread.
///
/// `Pending`, `InProgress`, `Completed` and `Failed` have direct ACP
/// counterparts (see the `From<acp::ToolCallStatus>` impl); the remaining
/// variants have no counterpart in that conversion.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options associated with the request.
        options: PermissionOptions,
        /// One-shot channel carrying a permission option id.
        // NOTE(review): presumably the option the user picked — confirm at the send site.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
503
504impl From<acp::ToolCallStatus> for ToolCallStatus {
505 fn from(status: acp::ToolCallStatus) -> Self {
506 match status {
507 acp::ToolCallStatus::Pending => Self::Pending,
508 acp::ToolCallStatus::InProgress => Self::InProgress,
509 acp::ToolCallStatus::Completed => Self::Completed,
510 acp::ToolCallStatus::Failed => Self::Failed,
511 _ => Self::Pending,
512 }
513 }
514}
515
516impl Display for ToolCallStatus {
517 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
518 write!(
519 f,
520 "{}",
521 match self {
522 ToolCallStatus::Pending => "Pending",
523 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
524 ToolCallStatus::InProgress => "In Progress",
525 ToolCallStatus::Completed => "Completed",
526 ToolCallStatus::Failed => "Failed",
527 ToolCallStatus::Rejected => "Rejected",
528 ToolCallStatus::Canceled => "Canceled",
529 }
530 )
531 }
532}
533
/// Displayable content accumulated from one or more ACP content blocks.
///
/// A block starts out `Empty` and takes its shape from the first ACP block
/// appended; appending to a `ResourceLink` or `Image` converts it to
/// `Markdown` (see `ContentBlock::append`).
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text held in a `Markdown` entity.
    Markdown { markdown: Entity<Markdown> },
    /// A single ACP resource link.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A single image that was decoded successfully.
    Image { image: Arc<gpui::Image> },
}
541
542impl ContentBlock {
543 pub fn new(
544 block: acp::ContentBlock,
545 language_registry: &Arc<LanguageRegistry>,
546 path_style: PathStyle,
547 cx: &mut App,
548 ) -> Self {
549 let mut this = Self::Empty;
550 this.append(block, language_registry, path_style, cx);
551 this
552 }
553
554 pub fn new_combined(
555 blocks: impl IntoIterator<Item = acp::ContentBlock>,
556 language_registry: Arc<LanguageRegistry>,
557 path_style: PathStyle,
558 cx: &mut App,
559 ) -> Self {
560 let mut this = Self::Empty;
561 for block in blocks {
562 this.append(block, &language_registry, path_style, cx);
563 }
564 this
565 }
566
567 pub fn append(
568 &mut self,
569 block: acp::ContentBlock,
570 language_registry: &Arc<LanguageRegistry>,
571 path_style: PathStyle,
572 cx: &mut App,
573 ) {
574 match (&mut *self, &block) {
575 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
576 *self = ContentBlock::ResourceLink {
577 resource_link: resource_link.clone(),
578 };
579 }
580 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
581 if let Some(image) = Self::decode_image(image_content) {
582 *self = ContentBlock::Image { image };
583 } else {
584 let new_content = Self::image_md(image_content);
585 *self = Self::create_markdown_block(new_content, language_registry, cx);
586 }
587 }
588 (ContentBlock::Empty, _) => {
589 let new_content = Self::block_string_contents(&block, path_style);
590 *self = Self::create_markdown_block(new_content, language_registry, cx);
591 }
592 (ContentBlock::Markdown { markdown }, _) => {
593 let new_content = Self::block_string_contents(&block, path_style);
594 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
595 }
596 (ContentBlock::ResourceLink { resource_link }, _) => {
597 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
598 let new_content = Self::block_string_contents(&block, path_style);
599 let combined = format!("{}\n{}", existing_content, new_content);
600 *self = Self::create_markdown_block(combined, language_registry, cx);
601 }
602 (ContentBlock::Image { .. }, _) => {
603 let new_content = Self::block_string_contents(&block, path_style);
604 let combined = format!("`Image`\n{}", new_content);
605 *self = Self::create_markdown_block(combined, language_registry, cx);
606 }
607 }
608 }
609
610 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
611 use base64::Engine as _;
612
613 let bytes = base64::engine::general_purpose::STANDARD
614 .decode(image_content.data.as_bytes())
615 .ok()?;
616 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
617 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
618 }
619
620 fn create_markdown_block(
621 content: String,
622 language_registry: &Arc<LanguageRegistry>,
623 cx: &mut App,
624 ) -> ContentBlock {
625 ContentBlock::Markdown {
626 markdown: cx
627 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
628 }
629 }
630
631 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
632 match block {
633 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
634 acp::ContentBlock::ResourceLink(resource_link) => {
635 Self::resource_link_md(&resource_link.uri, path_style)
636 }
637 acp::ContentBlock::Resource(acp::EmbeddedResource {
638 resource:
639 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
640 uri,
641 ..
642 }),
643 ..
644 }) => Self::resource_link_md(uri, path_style),
645 acp::ContentBlock::Image(image) => Self::image_md(image),
646 _ => String::new(),
647 }
648 }
649
650 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
651 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
652 uri.as_link().to_string()
653 } else {
654 uri.to_string()
655 }
656 }
657
658 fn image_md(_image: &acp::ImageContent) -> String {
659 "`Image`".into()
660 }
661
662 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
663 match self {
664 ContentBlock::Empty => "",
665 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
666 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
667 ContentBlock::Image { .. } => "`Image`",
668 }
669 }
670
671 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
672 match self {
673 ContentBlock::Empty => None,
674 ContentBlock::Markdown { markdown } => Some(markdown),
675 ContentBlock::ResourceLink { .. } => None,
676 ContentBlock::Image { .. } => None,
677 }
678 }
679
680 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
681 match self {
682 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
683 _ => None,
684 }
685 }
686
687 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
688 match self {
689 ContentBlock::Image { image } => Some(image),
690 _ => None,
691 }
692 }
693}
694
/// One piece of content attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, resource link or image content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity, looked up by id among the known terminals.
    Terminal(Entity<Terminal>),
}
701
702impl ToolCallContent {
703 pub fn from_acp(
704 content: acp::ToolCallContent,
705 language_registry: Arc<LanguageRegistry>,
706 path_style: PathStyle,
707 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
708 cx: &mut App,
709 ) -> Result<Option<Self>> {
710 match content {
711 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
712 Ok(Some(Self::ContentBlock(ContentBlock::new(
713 content,
714 &language_registry,
715 path_style,
716 cx,
717 ))))
718 }
719 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
720 Diff::finalized(
721 diff.path.to_string_lossy().into_owned(),
722 diff.old_text,
723 diff.new_text,
724 language_registry,
725 cx,
726 )
727 })))),
728 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
729 .get(&terminal_id)
730 .cloned()
731 .map(|terminal| Some(Self::Terminal(terminal)))
732 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
733 _ => Ok(None),
734 }
735 }
736
737 pub fn update_from_acp(
738 &mut self,
739 new: acp::ToolCallContent,
740 language_registry: Arc<LanguageRegistry>,
741 path_style: PathStyle,
742 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
743 cx: &mut App,
744 ) -> Result<bool> {
745 let needs_update = match (&self, &new) {
746 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
747 old_diff.read(cx).needs_update(
748 new_diff.old_text.as_deref().unwrap_or(""),
749 &new_diff.new_text,
750 cx,
751 )
752 }
753 _ => true,
754 };
755
756 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
757 if needs_update {
758 *self = update;
759 }
760 Ok(true)
761 } else {
762 Ok(false)
763 }
764 }
765
766 pub fn to_markdown(&self, cx: &App) -> String {
767 match self {
768 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
769 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
770 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
771 }
772 }
773
774 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
775 match self {
776 Self::ContentBlock(content) => content.image(),
777 _ => None,
778 }
779 }
780}
781
/// An update targeting an existing tool call, identified by `ToolCallUpdate::id`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A field update as received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity for the tool call.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity for the tool call.
    UpdateTerminal(ToolCallUpdateTerminal),
}
788
789impl ToolCallUpdate {
790 fn id(&self) -> &acp::ToolCallId {
791 match self {
792 Self::UpdateFields(update) => &update.tool_call_id,
793 Self::UpdateDiff(diff) => &diff.id,
794 Self::UpdateTerminal(terminal) => &terminal.id,
795 }
796 }
797}
798
799impl From<acp::ToolCallUpdate> for ToolCallUpdate {
800 fn from(update: acp::ToolCallUpdate) -> Self {
801 Self::UpdateFields(update)
802 }
803}
804
805impl From<ToolCallUpdateDiff> for ToolCallUpdate {
806 fn from(diff: ToolCallUpdateDiff) -> Self {
807 Self::UpdateDiff(diff)
808 }
809}
810
/// Payload of [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call the diff belongs to.
    pub id: acp::ToolCallId,
    /// The diff entity.
    pub diff: Entity<Diff>,
}
816
817impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
818 fn from(terminal: ToolCallUpdateTerminal) -> Self {
819 Self::UpdateTerminal(terminal)
820 }
821}
822
/// Payload of [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call the terminal belongs to.
    pub id: acp::ToolCallId,
    /// The terminal entity.
    pub terminal: Entity<Terminal>,
}
828
/// The agent's plan: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries, in order.
    pub entries: Vec<PlanEntry>,
}
833
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is `Pending`.
    pub pending: u32,
    /// Number of entries whose status is `Completed`.
    pub completed: u32,
}
840
841impl Plan {
842 pub fn is_empty(&self) -> bool {
843 self.entries.is_empty()
844 }
845
846 pub fn stats(&self) -> PlanStats<'_> {
847 let mut stats = PlanStats {
848 in_progress_entry: None,
849 pending: 0,
850 completed: 0,
851 };
852
853 for entry in &self.entries {
854 match &entry.status {
855 acp::PlanEntryStatus::Pending => {
856 stats.pending += 1;
857 }
858 acp::PlanEntryStatus::InProgress => {
859 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
860 }
861 acp::PlanEntryStatus::Completed => {
862 stats.completed += 1;
863 }
864 _ => {}
865 }
866 }
867
868 stats
869 }
870}
871
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, held in a `Markdown` entity.
    pub content: Entity<Markdown>,
    /// Priority as reported over ACP.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported over ACP.
    pub status: acp::PlanEntryStatus,
}
878
879impl PlanEntry {
880 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
881 Self {
882 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
883 priority: entry.priority,
884 status: entry.status,
885 }
886 }
887}
888
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens. `0` is treated by `ratio` as "unknown", which
    /// suppresses the limit warning.
    pub max_tokens: u64,
    /// Tokens used so far; compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Maximum number of output tokens, if known.
    pub max_output_tokens: Option<u64>,
}
897
898impl TokenUsage {
899 pub fn ratio(&self) -> TokenUsageRatio {
900 #[cfg(debug_assertions)]
901 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
902 .unwrap_or("0.8".to_string())
903 .parse()
904 .unwrap();
905 #[cfg(not(debug_assertions))]
906 let warning_threshold: f32 = 0.8;
907
908 // When the maximum is unknown because there is no selected model,
909 // avoid showing the token limit warning.
910 if self.max_tokens == 0 {
911 TokenUsageRatio::Normal
912 } else if self.used_tokens >= self.max_tokens {
913 TokenUsageRatio::Exceeded
914 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
915 TokenUsageRatio::Warning
916 } else {
917 TokenUsageRatio::Normal
918 }
919 }
920}
921
/// Result of `TokenUsage::ratio`. The derived ordering follows declaration
/// order: `Normal < Warning < Exceeded`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    Normal,
    /// Usage is at or above the warning threshold but below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
928
/// Payload of `AcpThreadEvent::Retry`.
// NOTE(review): field meanings below are inferred from names only — confirm
// against the code that constructs this value.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the most recent error.
    pub last_error: SharedString,
    /// Presumably the number of the current attempt.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    /// Presumably when the current wait/attempt started.
    pub started_at: Instant,
    /// Presumably how long the wait lasts.
    pub duration: Duration,
}
937
/// State held in `AcpThread::running_turn`; while it is `Some`, the thread's
/// status is `Generating`.
struct RunningTurn {
    /// Identifier of the turn.
    // NOTE(review): presumably taken from `AcpThread::turn_id` — confirm.
    id: u32,
    /// Task associated with the turn.
    send_task: Task<()>,
}
942
/// A conversation thread with an agent connected over ACP.
pub struct AcpThread {
    /// Session id of the parent thread, if any.
    parent_session_id: Option<acp::SessionId>,
    /// Title of the thread.
    title: SharedString,
    /// Entries of the thread, in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with this thread.
    action_log: Entity<ActionLog>,
    /// Snapshots keyed by buffer entity.
    // NOTE(review): presumably buffers whose content was shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Turn counter; starts at 0.
    turn_id: u32,
    /// The turn currently running, if any. `Some` means the thread is generating.
    running_turn: Option<RunningTurn>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// ACP session id of this thread.
    session_id: acp::SessionId,
    /// Latest token usage, if known.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities; kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that applies prompt capability changes received from the watch
    /// channel and emits `PromptCapabilitiesUpdated`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output chunks received for terminals that are not registered yet;
    /// replayed when the terminal's `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals that are not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
    /// Error flag exposed through `had_error`; starts out `false`.
    had_error: bool,
}
963
964impl From<&AcpThread> for ActionLogTelemetry {
965 fn from(value: &AcpThread) -> Self {
966 Self {
967 agent_telemetry_id: value.connection().telemetry_id(),
968 session_id: value.session_id.0.clone(),
969 }
970 }
971}
972
/// Events emitted by [`AcpThread`].
// NOTE(review): most variants are emitted outside the code visible here; the
// per-variant notes below restate the variant names and payload types only.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Carries an entry index.
    EntryUpdated(usize),
    /// Carries the range of removed entry indices.
    EntriesRemoved(Range<usize>),
    /// Carries the id of the tool call concerned.
    ToolAuthorizationRequested(acp::ToolCallId),
    /// Carries the id of the tool call concerned.
    ToolAuthorizationReceived(acp::ToolCallId),
    Retry(RetryStatus),
    /// Carries the session id of the subagent.
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after the thread stores new prompt capabilities received from
    /// its watch channel.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}
993
// Allows `AcpThread` to emit `AcpThreadEvent`s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
995
/// Terminal lifecycle events handled by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; the thread registers it and replays any output
    /// that arrived for this id beforehand.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered when the terminal is not
    /// registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// The terminal's title changed; applied as the terminal's breadcrumb text.
    /// Ignored when the terminal is not registered.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; the status is stored when the terminal is not
    /// registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
1018
/// Commands addressed to a terminal, identified by `terminal_id`.
// NOTE(review): no handler for these commands is visible here; the variant
// descriptions restate the names and payloads only — confirm at the consumer.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Carries input bytes for the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Carries new dimensions for the terminal, in columns and rows.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Names the terminal to close.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1034
1035impl AcpThread {
1036 pub fn on_terminal_provider_event(
1037 &mut self,
1038 event: TerminalProviderEvent,
1039 cx: &mut Context<Self>,
1040 ) {
1041 match event {
1042 TerminalProviderEvent::Created {
1043 terminal_id,
1044 label,
1045 cwd,
1046 output_byte_limit,
1047 terminal,
1048 } => {
1049 let entity = self.register_terminal_created(
1050 terminal_id.clone(),
1051 label,
1052 cwd,
1053 output_byte_limit,
1054 terminal,
1055 cx,
1056 );
1057
1058 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1059 for data in chunks.drain(..) {
1060 entity.update(cx, |term, cx| {
1061 term.inner().update(cx, |inner, cx| {
1062 inner.write_output(&data, cx);
1063 })
1064 });
1065 }
1066 }
1067
1068 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1069 entity.update(cx, |_term, cx| {
1070 cx.notify();
1071 });
1072 }
1073
1074 cx.notify();
1075 }
1076 TerminalProviderEvent::Output { terminal_id, data } => {
1077 if let Some(entity) = self.terminals.get(&terminal_id) {
1078 entity.update(cx, |term, cx| {
1079 term.inner().update(cx, |inner, cx| {
1080 inner.write_output(&data, cx);
1081 })
1082 });
1083 } else {
1084 self.pending_terminal_output
1085 .entry(terminal_id)
1086 .or_default()
1087 .push(data);
1088 }
1089 }
1090 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1091 if let Some(entity) = self.terminals.get(&terminal_id) {
1092 entity.update(cx, |term, cx| {
1093 term.inner().update(cx, |inner, cx| {
1094 inner.breadcrumb_text = title;
1095 cx.emit(::terminal::Event::BreadcrumbsChanged);
1096 })
1097 });
1098 }
1099 }
1100 TerminalProviderEvent::Exit {
1101 terminal_id,
1102 status,
1103 } => {
1104 if let Some(entity) = self.terminals.get(&terminal_id) {
1105 entity.update(cx, |_term, cx| {
1106 cx.notify();
1107 });
1108 } else {
1109 self.pending_terminal_exit.insert(terminal_id, status);
1110 }
1111 }
1112 }
1113 }
1114}
1115
/// Coarse status of an [`AcpThread`], as returned by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is currently running.
    Generating,
}
1121
/// Errors carried by `AcpThreadEvent::LoadError`; see the `Display` impl for
/// the rendered messages.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The version found is older than the minimum supported one.
    Unsupported {
        /// Shown in the message as where the version came from.
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; carries the message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure; carries the message.
    Other(SharedString),
}
1135
1136impl Display for LoadError {
1137 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1138 match self {
1139 LoadError::Unsupported {
1140 command: path,
1141 current_version,
1142 minimum_version,
1143 } => {
1144 write!(
1145 f,
1146 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1147 )
1148 }
1149 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1150 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1151 LoadError::Other(msg) => write!(f, "{msg}"),
1152 }
1153 }
1154}
1155
// Uses `Error`'s default methods; messages come from the `Display` impl above.
impl Error for LoadError {}
1157
1158impl AcpThread {
1159 pub fn new(
1160 parent_session_id: Option<acp::SessionId>,
1161 title: impl Into<SharedString>,
1162 connection: Rc<dyn AgentConnection>,
1163 project: Entity<Project>,
1164 action_log: Entity<ActionLog>,
1165 session_id: acp::SessionId,
1166 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1167 cx: &mut Context<Self>,
1168 ) -> Self {
1169 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1170 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1171 loop {
1172 let caps = prompt_capabilities_rx.recv().await?;
1173 this.update(cx, |this, cx| {
1174 this.prompt_capabilities = caps;
1175 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1176 })?;
1177 }
1178 });
1179
1180 Self {
1181 parent_session_id,
1182 action_log,
1183 shared_buffers: Default::default(),
1184 entries: Default::default(),
1185 plan: Default::default(),
1186 title: title.into(),
1187 project,
1188 running_turn: None,
1189 turn_id: 0,
1190 connection,
1191 session_id,
1192 token_usage: None,
1193 prompt_capabilities,
1194 _observe_prompt_capabilities: task,
1195 terminals: HashMap::default(),
1196 pending_terminal_output: HashMap::default(),
1197 pending_terminal_exit: HashMap::default(),
1198 had_error: false,
1199 }
1200 }
1201
1202 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1203 self.parent_session_id.as_ref()
1204 }
1205
1206 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1207 self.prompt_capabilities.clone()
1208 }
1209
1210 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1211 &self.connection
1212 }
1213
1214 pub fn action_log(&self) -> &Entity<ActionLog> {
1215 &self.action_log
1216 }
1217
1218 pub fn project(&self) -> &Entity<Project> {
1219 &self.project
1220 }
1221
1222 pub fn title(&self) -> SharedString {
1223 self.title.clone()
1224 }
1225
1226 pub fn entries(&self) -> &[AgentThreadEntry] {
1227 &self.entries
1228 }
1229
1230 pub fn session_id(&self) -> &acp::SessionId {
1231 &self.session_id
1232 }
1233
1234 pub fn status(&self) -> ThreadStatus {
1235 if self.running_turn.is_some() {
1236 ThreadStatus::Generating
1237 } else {
1238 ThreadStatus::Idle
1239 }
1240 }
1241
1242 pub fn had_error(&self) -> bool {
1243 self.had_error
1244 }
1245
1246 pub fn is_waiting_for_confirmation(&self) -> bool {
1247 for entry in self.entries.iter().rev() {
1248 match entry {
1249 AgentThreadEntry::UserMessage(_) => return false,
1250 AgentThreadEntry::ToolCall(ToolCall {
1251 status: ToolCallStatus::WaitingForConfirmation { .. },
1252 ..
1253 }) => return true,
1254 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1255 }
1256 }
1257 false
1258 }
1259
1260 pub fn token_usage(&self) -> Option<&TokenUsage> {
1261 self.token_usage.as_ref()
1262 }
1263
1264 pub fn has_pending_edit_tool_calls(&self) -> bool {
1265 for entry in self.entries.iter().rev() {
1266 match entry {
1267 AgentThreadEntry::UserMessage(_) => return false,
1268 AgentThreadEntry::ToolCall(
1269 call @ ToolCall {
1270 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1271 ..
1272 },
1273 ) if call.diffs().next().is_some() => {
1274 return true;
1275 }
1276 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1277 }
1278 }
1279
1280 false
1281 }
1282
1283 pub fn has_in_progress_tool_calls(&self) -> bool {
1284 for entry in self.entries.iter().rev() {
1285 match entry {
1286 AgentThreadEntry::UserMessage(_) => return false,
1287 AgentThreadEntry::ToolCall(ToolCall {
1288 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1289 ..
1290 }) => {
1291 return true;
1292 }
1293 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1294 }
1295 }
1296
1297 false
1298 }
1299
1300 pub fn used_tools_since_last_user_message(&self) -> bool {
1301 for entry in self.entries.iter().rev() {
1302 match entry {
1303 AgentThreadEntry::UserMessage(..) => return false,
1304 AgentThreadEntry::AssistantMessage(..) => continue,
1305 AgentThreadEntry::ToolCall(..) => return true,
1306 }
1307 }
1308
1309 false
1310 }
1311
1312 pub fn handle_session_update(
1313 &mut self,
1314 update: acp::SessionUpdate,
1315 cx: &mut Context<Self>,
1316 ) -> Result<(), acp::Error> {
1317 match update {
1318 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1319 self.push_user_content_block(None, content, cx);
1320 }
1321 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1322 self.push_assistant_content_block(content, false, cx);
1323 }
1324 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1325 self.push_assistant_content_block(content, true, cx);
1326 }
1327 acp::SessionUpdate::ToolCall(tool_call) => {
1328 self.upsert_tool_call(tool_call, cx)?;
1329 }
1330 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1331 self.update_tool_call(tool_call_update, cx)?;
1332 }
1333 acp::SessionUpdate::Plan(plan) => {
1334 self.update_plan(plan, cx);
1335 }
1336 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1337 available_commands,
1338 ..
1339 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1340 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1341 current_mode_id,
1342 ..
1343 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1344 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1345 config_options,
1346 ..
1347 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1348 _ => {}
1349 }
1350 Ok(())
1351 }
1352
1353 pub fn push_user_content_block(
1354 &mut self,
1355 message_id: Option<UserMessageId>,
1356 chunk: acp::ContentBlock,
1357 cx: &mut Context<Self>,
1358 ) {
1359 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1360 }
1361
1362 pub fn push_user_content_block_with_indent(
1363 &mut self,
1364 message_id: Option<UserMessageId>,
1365 chunk: acp::ContentBlock,
1366 indented: bool,
1367 cx: &mut Context<Self>,
1368 ) {
1369 let language_registry = self.project.read(cx).languages().clone();
1370 let path_style = self.project.read(cx).path_style(cx);
1371 let entries_len = self.entries.len();
1372
1373 if let Some(last_entry) = self.entries.last_mut()
1374 && let AgentThreadEntry::UserMessage(UserMessage {
1375 id,
1376 content,
1377 chunks,
1378 indented: existing_indented,
1379 ..
1380 }) = last_entry
1381 && *existing_indented == indented
1382 {
1383 *id = message_id.or(id.take());
1384 content.append(chunk.clone(), &language_registry, path_style, cx);
1385 chunks.push(chunk);
1386 let idx = entries_len - 1;
1387 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1388 } else {
1389 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1390 self.push_entry(
1391 AgentThreadEntry::UserMessage(UserMessage {
1392 id: message_id,
1393 content,
1394 chunks: vec![chunk],
1395 checkpoint: None,
1396 indented,
1397 }),
1398 cx,
1399 );
1400 }
1401 }
1402
1403 pub fn push_assistant_content_block(
1404 &mut self,
1405 chunk: acp::ContentBlock,
1406 is_thought: bool,
1407 cx: &mut Context<Self>,
1408 ) {
1409 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1410 }
1411
1412 pub fn push_assistant_content_block_with_indent(
1413 &mut self,
1414 chunk: acp::ContentBlock,
1415 is_thought: bool,
1416 indented: bool,
1417 cx: &mut Context<Self>,
1418 ) {
1419 let language_registry = self.project.read(cx).languages().clone();
1420 let path_style = self.project.read(cx).path_style(cx);
1421 let entries_len = self.entries.len();
1422 if let Some(last_entry) = self.entries.last_mut()
1423 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1424 chunks,
1425 indented: existing_indented,
1426 }) = last_entry
1427 && *existing_indented == indented
1428 {
1429 let idx = entries_len - 1;
1430 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1431 match (chunks.last_mut(), is_thought) {
1432 (Some(AssistantMessageChunk::Message { block }), false)
1433 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1434 block.append(chunk, &language_registry, path_style, cx)
1435 }
1436 _ => {
1437 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1438 if is_thought {
1439 chunks.push(AssistantMessageChunk::Thought { block })
1440 } else {
1441 chunks.push(AssistantMessageChunk::Message { block })
1442 }
1443 }
1444 }
1445 } else {
1446 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1447 let chunk = if is_thought {
1448 AssistantMessageChunk::Thought { block }
1449 } else {
1450 AssistantMessageChunk::Message { block }
1451 };
1452
1453 self.push_entry(
1454 AgentThreadEntry::AssistantMessage(AssistantMessage {
1455 chunks: vec![chunk],
1456 indented,
1457 }),
1458 cx,
1459 );
1460 }
1461 }
1462
1463 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1464 self.entries.push(entry);
1465 cx.emit(AcpThreadEvent::NewEntry);
1466 }
1467
1468 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1469 self.connection.set_title(&self.session_id, cx).is_some()
1470 }
1471
1472 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1473 if title != self.title {
1474 self.title = title.clone();
1475 cx.emit(AcpThreadEvent::TitleUpdated);
1476 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1477 return set_title.run(title, cx);
1478 }
1479 }
1480 Task::ready(Ok(()))
1481 }
1482
1483 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1484 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1485 }
1486
1487 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1488 self.token_usage = usage;
1489 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1490 }
1491
1492 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1493 cx.emit(AcpThreadEvent::Retry(status));
1494 }
1495
1496 pub fn update_tool_call(
1497 &mut self,
1498 update: impl Into<ToolCallUpdate>,
1499 cx: &mut Context<Self>,
1500 ) -> Result<()> {
1501 let update = update.into();
1502 let languages = self.project.read(cx).languages().clone();
1503 let path_style = self.project.read(cx).path_style(cx);
1504
1505 let ix = match self.index_for_tool_call(update.id()) {
1506 Some(ix) => ix,
1507 None => {
1508 // Tool call not found - create a failed tool call entry
1509 let failed_tool_call = ToolCall {
1510 id: update.id().clone(),
1511 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1512 kind: acp::ToolKind::Fetch,
1513 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1514 "Tool call not found".into(),
1515 &languages,
1516 path_style,
1517 cx,
1518 ))],
1519 status: ToolCallStatus::Failed,
1520 locations: Vec::new(),
1521 resolved_locations: Vec::new(),
1522 raw_input: None,
1523 raw_input_markdown: None,
1524 raw_output: None,
1525 tool_name: None,
1526 subagent_session_id: None,
1527 };
1528 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1529 return Ok(());
1530 }
1531 };
1532 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1533 unreachable!()
1534 };
1535
1536 match update {
1537 ToolCallUpdate::UpdateFields(update) => {
1538 let location_updated = update.fields.locations.is_some();
1539 call.update_fields(
1540 update.fields,
1541 update.meta,
1542 languages,
1543 path_style,
1544 &self.terminals,
1545 cx,
1546 )?;
1547 if location_updated {
1548 self.resolve_locations(update.tool_call_id, cx);
1549 }
1550 }
1551 ToolCallUpdate::UpdateDiff(update) => {
1552 call.content.clear();
1553 call.content.push(ToolCallContent::Diff(update.diff));
1554 }
1555 ToolCallUpdate::UpdateTerminal(update) => {
1556 call.content.clear();
1557 call.content
1558 .push(ToolCallContent::Terminal(update.terminal));
1559 }
1560 }
1561
1562 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1563
1564 Ok(())
1565 }
1566
1567 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1568 pub fn upsert_tool_call(
1569 &mut self,
1570 tool_call: acp::ToolCall,
1571 cx: &mut Context<Self>,
1572 ) -> Result<(), acp::Error> {
1573 let status = tool_call.status.into();
1574 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1575 }
1576
1577 /// Fails if id does not match an existing entry.
1578 pub fn upsert_tool_call_inner(
1579 &mut self,
1580 update: acp::ToolCallUpdate,
1581 status: ToolCallStatus,
1582 cx: &mut Context<Self>,
1583 ) -> Result<(), acp::Error> {
1584 let language_registry = self.project.read(cx).languages().clone();
1585 let path_style = self.project.read(cx).path_style(cx);
1586 let id = update.tool_call_id.clone();
1587
1588 let agent_telemetry_id = self.connection().telemetry_id();
1589 let session = self.session_id();
1590 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1591 let status = if matches!(status, ToolCallStatus::Completed) {
1592 "completed"
1593 } else {
1594 "failed"
1595 };
1596 telemetry::event!(
1597 "Agent Tool Call Completed",
1598 agent_telemetry_id,
1599 session,
1600 status
1601 );
1602 }
1603
1604 if let Some(ix) = self.index_for_tool_call(&id) {
1605 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1606 unreachable!()
1607 };
1608
1609 call.update_fields(
1610 update.fields,
1611 update.meta,
1612 language_registry,
1613 path_style,
1614 &self.terminals,
1615 cx,
1616 )?;
1617 call.status = status;
1618
1619 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1620 } else {
1621 let call = ToolCall::from_acp(
1622 update.try_into()?,
1623 status,
1624 language_registry,
1625 self.project.read(cx).path_style(cx),
1626 &self.terminals,
1627 cx,
1628 )?;
1629 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1630 };
1631
1632 self.resolve_locations(id, cx);
1633 Ok(())
1634 }
1635
1636 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1637 self.entries
1638 .iter()
1639 .enumerate()
1640 .rev()
1641 .find_map(|(index, entry)| {
1642 if let AgentThreadEntry::ToolCall(tool_call) = entry
1643 && &tool_call.id == id
1644 {
1645 Some(index)
1646 } else {
1647 None
1648 }
1649 })
1650 }
1651
1652 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1653 // The tool call we are looking for is typically the last one, or very close to the end.
1654 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1655 self.entries
1656 .iter_mut()
1657 .enumerate()
1658 .rev()
1659 .find_map(|(index, tool_call)| {
1660 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1661 && &tool_call.id == id
1662 {
1663 Some((index, tool_call))
1664 } else {
1665 None
1666 }
1667 })
1668 }
1669
1670 pub fn tool_call(&self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1671 self.entries
1672 .iter()
1673 .enumerate()
1674 .rev()
1675 .find_map(|(index, tool_call)| {
1676 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1677 && &tool_call.id == id
1678 {
1679 Some((index, tool_call))
1680 } else {
1681 None
1682 }
1683 })
1684 }
1685
1686 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1687 let project = self.project.clone();
1688 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1689 return;
1690 };
1691 let task = tool_call.resolve_locations(project, cx);
1692 cx.spawn(async move |this, cx| {
1693 let resolved_locations = task.await;
1694
1695 this.update(cx, |this, cx| {
1696 let project = this.project.clone();
1697
1698 for location in resolved_locations.iter().flatten() {
1699 this.shared_buffers
1700 .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
1701 }
1702 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1703 return;
1704 };
1705
1706 if let Some(Some(location)) = resolved_locations.last() {
1707 project.update(cx, |project, cx| {
1708 let should_ignore = if let Some(agent_location) = project
1709 .agent_location()
1710 .filter(|agent_location| agent_location.buffer == location.buffer)
1711 {
1712 let snapshot = location.buffer.read(cx).snapshot();
1713 let old_position = agent_location.position.to_point(&snapshot);
1714 let new_position = location.position.to_point(&snapshot);
1715
1716 // ignore this so that when we get updates from the edit tool
1717 // the position doesn't reset to the startof line
1718 old_position.row == new_position.row
1719 && old_position.column > new_position.column
1720 } else {
1721 false
1722 };
1723 if !should_ignore {
1724 project.set_agent_location(Some(location.into()), cx);
1725 }
1726 });
1727 }
1728
1729 let resolved_locations = resolved_locations
1730 .iter()
1731 .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
1732 .collect::<Vec<_>>();
1733
1734 if tool_call.resolved_locations != resolved_locations {
1735 tool_call.resolved_locations = resolved_locations;
1736 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1737 }
1738 })
1739 })
1740 .detach();
1741 }
1742
1743 pub fn request_tool_call_authorization(
1744 &mut self,
1745 tool_call: acp::ToolCallUpdate,
1746 options: PermissionOptions,
1747 cx: &mut Context<Self>,
1748 ) -> Result<Task<acp::RequestPermissionOutcome>> {
1749 let (tx, rx) = oneshot::channel();
1750
1751 let status = ToolCallStatus::WaitingForConfirmation {
1752 options,
1753 respond_tx: tx,
1754 };
1755
1756 let tool_call_id = tool_call.tool_call_id.clone();
1757 self.upsert_tool_call_inner(tool_call, status, cx)?;
1758 cx.emit(AcpThreadEvent::ToolAuthorizationRequested(
1759 tool_call_id.clone(),
1760 ));
1761
1762 Ok(cx.spawn(async move |this, cx| {
1763 let outcome = match rx.await {
1764 Ok(option) => acp::RequestPermissionOutcome::Selected(
1765 acp::SelectedPermissionOutcome::new(option),
1766 ),
1767 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1768 };
1769 this.update(cx, |_this, cx| {
1770 cx.emit(AcpThreadEvent::ToolAuthorizationReceived(tool_call_id))
1771 })
1772 .ok();
1773 outcome
1774 }))
1775 }
1776
1777 pub fn authorize_tool_call(
1778 &mut self,
1779 id: acp::ToolCallId,
1780 option_id: acp::PermissionOptionId,
1781 option_kind: acp::PermissionOptionKind,
1782 cx: &mut Context<Self>,
1783 ) {
1784 let Some((ix, call)) = self.tool_call_mut(&id) else {
1785 return;
1786 };
1787
1788 let new_status = match option_kind {
1789 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1790 ToolCallStatus::Rejected
1791 }
1792 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1793 ToolCallStatus::InProgress
1794 }
1795 _ => ToolCallStatus::InProgress,
1796 };
1797
1798 let curr_status = mem::replace(&mut call.status, new_status);
1799
1800 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1801 respond_tx.send(option_id).log_err();
1802 } else if cfg!(debug_assertions) {
1803 panic!("tried to authorize an already authorized tool call");
1804 }
1805
1806 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1807 }
1808
1809 pub fn plan(&self) -> &Plan {
1810 &self.plan
1811 }
1812
1813 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1814 let new_entries_len = request.entries.len();
1815 let mut new_entries = request.entries.into_iter();
1816
1817 // Reuse existing markdown to prevent flickering
1818 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1819 let PlanEntry {
1820 content,
1821 priority,
1822 status,
1823 } = old;
1824 content.update(cx, |old, cx| {
1825 old.replace(new.content, cx);
1826 });
1827 *priority = new.priority;
1828 *status = new.status;
1829 }
1830 for new in new_entries {
1831 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1832 }
1833 self.plan.entries.truncate(new_entries_len);
1834
1835 cx.notify();
1836 }
1837
1838 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1839 self.plan
1840 .entries
1841 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1842 cx.notify();
1843 }
1844
1845 #[cfg(any(test, feature = "test-support"))]
1846 pub fn send_raw(
1847 &mut self,
1848 message: &str,
1849 cx: &mut Context<Self>,
1850 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1851 self.send(vec![message.into()], cx)
1852 }
1853
1854 pub fn send(
1855 &mut self,
1856 message: Vec<acp::ContentBlock>,
1857 cx: &mut Context<Self>,
1858 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1859 let block = ContentBlock::new_combined(
1860 message.clone(),
1861 self.project.read(cx).languages().clone(),
1862 self.project.read(cx).path_style(cx),
1863 cx,
1864 );
1865 let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
1866 let git_store = self.project.read(cx).git_store().clone();
1867
1868 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1869 Some(UserMessageId::new())
1870 } else {
1871 None
1872 };
1873
1874 self.run_turn(cx, async move |this, cx| {
1875 this.update(cx, |this, cx| {
1876 this.push_entry(
1877 AgentThreadEntry::UserMessage(UserMessage {
1878 id: message_id.clone(),
1879 content: block,
1880 chunks: message,
1881 checkpoint: None,
1882 indented: false,
1883 }),
1884 cx,
1885 );
1886 })
1887 .ok();
1888
1889 let old_checkpoint = git_store
1890 .update(cx, |git, cx| git.checkpoint(cx))
1891 .await
1892 .context("failed to get old checkpoint")
1893 .log_err();
1894 this.update(cx, |this, cx| {
1895 if let Some((_ix, message)) = this.last_user_message() {
1896 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1897 git_checkpoint,
1898 show: false,
1899 });
1900 }
1901 this.connection.prompt(message_id, request, cx)
1902 })?
1903 .await
1904 })
1905 }
1906
1907 pub fn can_retry(&self, cx: &App) -> bool {
1908 self.connection.retry(&self.session_id, cx).is_some()
1909 }
1910
1911 pub fn retry(
1912 &mut self,
1913 cx: &mut Context<Self>,
1914 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1915 self.run_turn(cx, async move |this, cx| {
1916 this.update(cx, |this, cx| {
1917 this.connection
1918 .retry(&this.session_id, cx)
1919 .map(|retry| retry.run(cx))
1920 })?
1921 .context("retrying a session is not supported")?
1922 .await
1923 })
1924 }
1925
1926 fn run_turn(
1927 &mut self,
1928 cx: &mut Context<Self>,
1929 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1930 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1931 self.clear_completed_plan_entries(cx);
1932 self.had_error = false;
1933
1934 let (tx, rx) = oneshot::channel();
1935 let cancel_task = self.cancel(cx);
1936
1937 self.turn_id += 1;
1938 let turn_id = self.turn_id;
1939 self.running_turn = Some(RunningTurn {
1940 id: turn_id,
1941 send_task: cx.spawn(async move |this, cx| {
1942 cancel_task.await;
1943 tx.send(f(this, cx).await).ok();
1944 }),
1945 });
1946
1947 cx.spawn(async move |this, cx| {
1948 let response = rx.await;
1949
1950 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1951 .await?;
1952
1953 this.update(cx, |this, cx| {
1954 this.project
1955 .update(cx, |project, cx| project.set_agent_location(None, cx));
1956 let Ok(response) = response else {
1957 // tx dropped, just return
1958 return Ok(None);
1959 };
1960
1961 let is_same_turn = this
1962 .running_turn
1963 .as_ref()
1964 .is_some_and(|turn| turn_id == turn.id);
1965
1966 // If the user submitted a follow up message, running_turn might
1967 // already point to a different turn. Therefore we only want to
1968 // take the task if it's the same turn.
1969 if is_same_turn {
1970 this.running_turn.take();
1971 }
1972
1973 match response {
1974 Ok(r) => {
1975 if r.stop_reason == acp::StopReason::MaxTokens {
1976 this.had_error = true;
1977 cx.emit(AcpThreadEvent::Error);
1978 log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
1979 return Err(anyhow!("Max tokens reached"));
1980 }
1981
1982 let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
1983 if canceled {
1984 this.mark_pending_tools_as_canceled();
1985 }
1986
1987 // Handle refusal - distinguish between user prompt and tool call refusals
1988 if let acp::StopReason::Refusal = r.stop_reason {
1989 this.had_error = true;
1990 if let Some((user_msg_ix, _)) = this.last_user_message() {
1991 // Check if there's a completed tool call with results after the last user message
1992 // This indicates the refusal is in response to tool output, not the user's prompt
1993 let has_completed_tool_call_after_user_msg =
1994 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1995 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1996 // Check if the tool call has completed and has output
1997 matches!(tool_call.status, ToolCallStatus::Completed)
1998 && tool_call.raw_output.is_some()
1999 } else {
2000 false
2001 }
2002 });
2003
2004 if has_completed_tool_call_after_user_msg {
2005 // Refusal is due to tool output - don't truncate, just notify
2006 // The model refused based on what the tool returned
2007 cx.emit(AcpThreadEvent::Refusal);
2008 } else {
2009 // User prompt was refused - truncate back to before the user message
2010 let range = user_msg_ix..this.entries.len();
2011 if range.start < range.end {
2012 this.entries.truncate(user_msg_ix);
2013 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2014 }
2015 cx.emit(AcpThreadEvent::Refusal);
2016 }
2017 } else {
2018 // No user message found, treat as general refusal
2019 cx.emit(AcpThreadEvent::Refusal);
2020 }
2021 }
2022
2023 cx.emit(AcpThreadEvent::Stopped);
2024 Ok(Some(r))
2025 }
2026 Err(e) => {
2027 this.had_error = true;
2028 cx.emit(AcpThreadEvent::Error);
2029 log::error!("Error in run turn: {:?}", e);
2030 Err(e)
2031 }
2032 }
2033 })?
2034 })
2035 .boxed()
2036 }
2037
2038 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2039 let Some(turn) = self.running_turn.take() else {
2040 return Task::ready(());
2041 };
2042 self.connection.cancel(&self.session_id, cx);
2043
2044 self.mark_pending_tools_as_canceled();
2045
2046 // Wait for the send task to complete
2047 cx.background_spawn(turn.send_task)
2048 }
2049
2050 fn mark_pending_tools_as_canceled(&mut self) {
2051 for entry in self.entries.iter_mut() {
2052 if let AgentThreadEntry::ToolCall(call) = entry {
2053 let cancel = matches!(
2054 call.status,
2055 ToolCallStatus::Pending
2056 | ToolCallStatus::WaitingForConfirmation { .. }
2057 | ToolCallStatus::InProgress
2058 );
2059
2060 if cancel {
2061 call.status = ToolCallStatus::Canceled;
2062 }
2063 }
2064 }
2065 }
2066
2067 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2068 pub fn restore_checkpoint(
2069 &mut self,
2070 id: UserMessageId,
2071 cx: &mut Context<Self>,
2072 ) -> Task<Result<()>> {
2073 let Some((_, message)) = self.user_message_mut(&id) else {
2074 return Task::ready(Err(anyhow!("message not found")));
2075 };
2076
2077 let checkpoint = message
2078 .checkpoint
2079 .as_ref()
2080 .map(|c| c.git_checkpoint.clone());
2081
2082 // Cancel any in-progress generation before restoring
2083 let cancel_task = self.cancel(cx);
2084 let rewind = self.rewind(id.clone(), cx);
2085 let git_store = self.project.read(cx).git_store().clone();
2086
2087 cx.spawn(async move |_, cx| {
2088 cancel_task.await;
2089 rewind.await?;
2090 if let Some(checkpoint) = checkpoint {
2091 git_store
2092 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2093 .await?;
2094 }
2095
2096 Ok(())
2097 })
2098 }
2099
2100 /// Rewinds this thread to before the entry at `index`, removing it and all
2101 /// subsequent entries while rejecting any action_log changes made from that point.
2102 /// Unlike `restore_checkpoint`, this method does not restore from git.
2103 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2104 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2105 return Task::ready(Err(anyhow!("not supported")));
2106 };
2107
2108 let telemetry = ActionLogTelemetry::from(&*self);
2109 cx.spawn(async move |this, cx| {
2110 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2111 this.update(cx, |this, cx| {
2112 if let Some((ix, _)) = this.user_message_mut(&id) {
2113 // Collect all terminals from entries that will be removed
2114 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2115 .iter()
2116 .flat_map(|entry| entry.terminals())
2117 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2118 .collect();
2119
2120 let range = ix..this.entries.len();
2121 this.entries.truncate(ix);
2122 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2123
2124 // Kill and remove the terminals
2125 for terminal_id in terminals_to_remove {
2126 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2127 terminal.update(cx, |terminal, cx| {
2128 terminal.kill(cx);
2129 });
2130 }
2131 }
2132 }
2133 this.action_log().update(cx, |action_log, cx| {
2134 action_log.reject_all_edits(Some(telemetry), cx)
2135 })
2136 })?
2137 .await;
2138 Ok(())
2139 })
2140 }
2141
2142 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2143 let git_store = self.project.read(cx).git_store().clone();
2144
2145 let Some((_, message)) = self.last_user_message() else {
2146 return Task::ready(Ok(()));
2147 };
2148 let Some(user_message_id) = message.id.clone() else {
2149 return Task::ready(Ok(()));
2150 };
2151 let Some(checkpoint) = message.checkpoint.as_ref() else {
2152 return Task::ready(Ok(()));
2153 };
2154 let old_checkpoint = checkpoint.git_checkpoint.clone();
2155
2156 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2157 cx.spawn(async move |this, cx| {
2158 let Some(new_checkpoint) = new_checkpoint
2159 .await
2160 .context("failed to get new checkpoint")
2161 .log_err()
2162 else {
2163 return Ok(());
2164 };
2165
2166 let equal = git_store
2167 .update(cx, |git, cx| {
2168 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2169 })
2170 .await
2171 .unwrap_or(true);
2172
2173 this.update(cx, |this, cx| {
2174 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2175 if let Some(checkpoint) = message.checkpoint.as_mut() {
2176 checkpoint.show = !equal;
2177 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2178 }
2179 }
2180 })?;
2181
2182 Ok(())
2183 })
2184 }
2185
2186 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2187 self.entries
2188 .iter_mut()
2189 .enumerate()
2190 .rev()
2191 .find_map(|(ix, entry)| {
2192 if let AgentThreadEntry::UserMessage(message) = entry {
2193 Some((ix, message))
2194 } else {
2195 None
2196 }
2197 })
2198 }
2199
2200 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2201 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2202 if let AgentThreadEntry::UserMessage(message) = entry {
2203 if message.id.as_ref() == Some(id) {
2204 Some((ix, message))
2205 } else {
2206 None
2207 }
2208 } else {
2209 None
2210 }
2211 })
2212 }
2213
2214 pub fn read_text_file(
2215 &self,
2216 path: PathBuf,
2217 line: Option<u32>,
2218 limit: Option<u32>,
2219 reuse_shared_snapshot: bool,
2220 cx: &mut Context<Self>,
2221 ) -> Task<Result<String, acp::Error>> {
2222 // Args are 1-based, move to 0-based
2223 let line = line.unwrap_or_default().saturating_sub(1);
2224 let limit = limit.unwrap_or(u32::MAX);
2225 let project = self.project.clone();
2226 let action_log = self.action_log.clone();
2227 cx.spawn(async move |this, cx| {
2228 let load = project.update(cx, |project, cx| {
2229 let path = project
2230 .project_path_for_absolute_path(&path, cx)
2231 .ok_or_else(|| {
2232 acp::Error::resource_not_found(Some(path.display().to_string()))
2233 })?;
2234 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2235 })?;
2236
2237 let buffer = load.await?;
2238
2239 let snapshot = if reuse_shared_snapshot {
2240 this.read_with(cx, |this, _| {
2241 this.shared_buffers.get(&buffer.clone()).cloned()
2242 })
2243 .log_err()
2244 .flatten()
2245 } else {
2246 None
2247 };
2248
2249 let snapshot = if let Some(snapshot) = snapshot {
2250 snapshot
2251 } else {
2252 action_log.update(cx, |action_log, cx| {
2253 action_log.buffer_read(buffer.clone(), cx);
2254 });
2255
2256 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2257 this.update(cx, |this, _| {
2258 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2259 })?;
2260 snapshot
2261 };
2262
2263 let max_point = snapshot.max_point();
2264 let start_position = Point::new(line, 0);
2265
2266 if start_position > max_point {
2267 return Err(acp::Error::invalid_params().data(format!(
2268 "Attempting to read beyond the end of the file, line {}:{}",
2269 max_point.row + 1,
2270 max_point.column
2271 )));
2272 }
2273
2274 let start = snapshot.anchor_before(start_position);
2275 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2276
2277 project.update(cx, |project, cx| {
2278 project.set_agent_location(
2279 Some(AgentLocation {
2280 buffer: buffer.downgrade(),
2281 position: start,
2282 }),
2283 cx,
2284 );
2285 });
2286
2287 Ok(snapshot.text_for_range(start..end).collect::<String>())
2288 })
2289 }
2290
    /// Replaces the contents of the project file at the absolute path `path`
    /// with `content` and saves the buffer.
    ///
    /// The new content is diffed against the snapshot stored in
    /// `shared_buffers` for this buffer when there is one, otherwise against
    /// the buffer's current snapshot. The resulting edits are applied to the
    /// live buffer, recorded in the action log, optionally formatted, and then
    /// saved.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" if `path` does not resolve to a project path,
    /// and propagates failures from opening the buffer, reaching this thread
    /// entity, or saving. Formatting failures are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Base snapshot for the diff: the shared one if present, else the current buffer state.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor and express each edit
            // as an anchor range in the base snapshot.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (snapshot.anchor_range_between(range), replacement)
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at the
            // buffer's minimum anchor when the diff produced no edits.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            let format_on_save = cx.update(|cx| {
                // Record a read before the edit, then the edit itself.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // A formatting failure is logged and does not abort the write.
                format_task.await.log_err();

                // Record a second edit after the formatting pass.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2383
2384 pub fn create_terminal(
2385 &self,
2386 command: String,
2387 args: Vec<String>,
2388 extra_env: Vec<acp::EnvVariable>,
2389 cwd: Option<PathBuf>,
2390 output_byte_limit: Option<u64>,
2391 cx: &mut Context<Self>,
2392 ) -> Task<Result<Entity<Terminal>>> {
2393 let env = match &cwd {
2394 Some(dir) => self.project.update(cx, |project, cx| {
2395 project.environment().update(cx, |env, cx| {
2396 env.directory_environment(dir.as_path().into(), cx)
2397 })
2398 }),
2399 None => Task::ready(None).shared(),
2400 };
2401 let env = cx.spawn(async move |_, _| {
2402 let mut env = env.await.unwrap_or_default();
2403 // Disables paging for `git` and hopefully other commands
2404 env.insert("PAGER".into(), "".into());
2405 for var in extra_env {
2406 env.insert(var.name, var.value);
2407 }
2408 env
2409 });
2410
2411 let project = self.project.clone();
2412 let language_registry = project.read(cx).languages().clone();
2413 let is_windows = project.read(cx).path_style(cx).is_windows();
2414
2415 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2416 let terminal_task = cx.spawn({
2417 let terminal_id = terminal_id.clone();
2418 async move |_this, cx| {
2419 let env = env.await;
2420 let shell = project
2421 .update(cx, |project, cx| {
2422 project
2423 .remote_client()
2424 .and_then(|r| r.read(cx).default_system_shell())
2425 })
2426 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2427 let (task_command, task_args) =
2428 ShellBuilder::new(&Shell::Program(shell), is_windows)
2429 .redirect_stdin_to_dev_null()
2430 .build(Some(command.clone()), &args);
2431 let terminal = project
2432 .update(cx, |project, cx| {
2433 project.create_terminal_task(
2434 task::SpawnInTerminal {
2435 command: Some(task_command),
2436 args: task_args,
2437 cwd: cwd.clone(),
2438 env,
2439 ..Default::default()
2440 },
2441 cx,
2442 )
2443 })
2444 .await?;
2445
2446 anyhow::Ok(cx.new(|cx| {
2447 Terminal::new(
2448 terminal_id,
2449 &format!("{} {}", command, args.join(" ")),
2450 cwd,
2451 output_byte_limit.map(|l| l as usize),
2452 terminal,
2453 language_registry,
2454 cx,
2455 )
2456 }))
2457 }
2458 });
2459
2460 cx.spawn(async move |this, cx| {
2461 let terminal = terminal_task.await?;
2462 this.update(cx, |this, _cx| {
2463 this.terminals.insert(terminal_id, terminal.clone());
2464 terminal
2465 })
2466 })
2467 }
2468
2469 pub fn kill_terminal(
2470 &mut self,
2471 terminal_id: acp::TerminalId,
2472 cx: &mut Context<Self>,
2473 ) -> Result<()> {
2474 self.terminals
2475 .get(&terminal_id)
2476 .context("Terminal not found")?
2477 .update(cx, |terminal, cx| {
2478 terminal.kill(cx);
2479 });
2480
2481 Ok(())
2482 }
2483
2484 pub fn release_terminal(
2485 &mut self,
2486 terminal_id: acp::TerminalId,
2487 cx: &mut Context<Self>,
2488 ) -> Result<()> {
2489 self.terminals
2490 .remove(&terminal_id)
2491 .context("Terminal not found")?
2492 .update(cx, |terminal, cx| {
2493 terminal.kill(cx);
2494 });
2495
2496 Ok(())
2497 }
2498
2499 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2500 self.terminals
2501 .get(&terminal_id)
2502 .context("Terminal not found")
2503 .cloned()
2504 }
2505
2506 pub fn to_markdown(&self, cx: &App) -> String {
2507 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2508 }
2509
2510 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2511 cx.emit(AcpThreadEvent::LoadError(error));
2512 }
2513
2514 pub fn register_terminal_created(
2515 &mut self,
2516 terminal_id: acp::TerminalId,
2517 command_label: String,
2518 working_dir: Option<PathBuf>,
2519 output_byte_limit: Option<u64>,
2520 terminal: Entity<::terminal::Terminal>,
2521 cx: &mut Context<Self>,
2522 ) -> Entity<Terminal> {
2523 let language_registry = self.project.read(cx).languages().clone();
2524
2525 let entity = cx.new(|cx| {
2526 Terminal::new(
2527 terminal_id.clone(),
2528 &command_label,
2529 working_dir.clone(),
2530 output_byte_limit.map(|l| l as usize),
2531 terminal,
2532 language_registry,
2533 cx,
2534 )
2535 });
2536 self.terminals.insert(terminal_id.clone(), entity.clone());
2537 entity
2538 }
2539}
2540
2541fn markdown_for_raw_output(
2542 raw_output: &serde_json::Value,
2543 language_registry: &Arc<LanguageRegistry>,
2544 cx: &mut App,
2545) -> Option<Entity<Markdown>> {
2546 match raw_output {
2547 serde_json::Value::Null => None,
2548 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2549 Markdown::new(
2550 value.to_string().into(),
2551 Some(language_registry.clone()),
2552 None,
2553 cx,
2554 )
2555 })),
2556 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2557 Markdown::new(
2558 value.to_string().into(),
2559 Some(language_registry.clone()),
2560 None,
2561 cx,
2562 )
2563 })),
2564 serde_json::Value::String(value) => Some(cx.new(|cx| {
2565 Markdown::new(
2566 value.clone().into(),
2567 Some(language_registry.clone()),
2568 None,
2569 cx,
2570 )
2571 })),
2572 value => Some(cx.new(|cx| {
2573 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2574
2575 Markdown::new(
2576 format!("```json\n{}\n```", pretty_json).into(),
2577 Some(language_registry.clone()),
2578 None,
2579 cx,
2580 )
2581 })),
2582 }
2583}
2584
2585#[cfg(test)]
2586mod tests {
2587 use super::*;
2588 use anyhow::anyhow;
2589 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2590 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2591 use indoc::indoc;
2592 use project::{FakeFs, Fs};
2593 use rand::{distr, prelude::*};
2594 use serde_json::json;
2595 use settings::SettingsStore;
2596 use smol::stream::StreamExt as _;
2597 use std::{
2598 any::Any,
2599 cell::RefCell,
2600 path::Path,
2601 rc::Rc,
2602 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2603 time::Duration,
2604 };
2605 use util::path;
2606
2607 fn init_test(cx: &mut TestAppContext) {
2608 env_logger::try_init().ok();
2609 cx.update(|cx| {
2610 let settings_store = SettingsStore::test(cx);
2611 cx.set_global(settings_store);
2612 });
2613 }
2614
    /// Checks that terminal output delivered before the matching `Created`
    /// event is present in the terminal's content once `Created` arrives.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2678
    /// Checks that output delivered before `Created` is still present in the
    /// terminal's content when an `Exit` event was also delivered before
    /// `Created`.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2753
    /// Test that killing a terminal via Terminal::kill properly:
    /// 1. Causes wait_for_exit to complete (doesn't hang forever)
    /// 2. The underlying terminal still has the output that was written before the kill
    ///
    /// This test verifies that the fix to kill_active_task (which now also kills
    /// the shell process in addition to the foreground process) properly allows
    /// wait_for_exit to complete instead of hanging indefinitely.
    ///
    /// Only built on Unix; it spawns a real PTY-backed shell command and
    /// therefore enables executor parking.
    #[cfg(unix)]
    #[gpui::test]
    async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
        use std::collections::HashMap;
        use task::Shell;
        use util::shell_builder::ShellBuilder;

        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Create a real PTY terminal that runs a command which prints output then sleeps
        // We use printf instead of echo and chain with && sleep to ensure proper execution
        let (completion_tx, _completion_rx) = smol::channel::unbounded();
        let (program, args) = ShellBuilder::new(&Shell::System, false).build(
            Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
            &[],
        );

        let builder = cx
            .update(|cx| {
                ::terminal::TerminalBuilder::new(
                    None,
                    None,
                    task::Shell::WithArguments {
                        program,
                        args,
                        title_override: None,
                    },
                    HashMap::default(),
                    ::terminal::terminal_settings::CursorShape::default(),
                    ::terminal::terminal_settings::AlternateScroll::On,
                    None,
                    vec![],
                    0,
                    false,
                    0,
                    Some(completion_tx),
                    cx,
                    vec![],
                    PathStyle::local(),
                )
            })
            .await
            .unwrap();

        let lower_terminal = cx.new(|cx| builder.subscribe(cx));

        // Create the acp_thread Terminal wrapper
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "printf output_before_kill && sleep 60".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower_terminal.clone(),
                },
                cx,
            );
        });

        // Wait for the printf command to execute and produce output
        // Use real time since parking is enabled
        cx.executor().timer(Duration::from_millis(500)).await;

        // Get the acp_thread Terminal and kill it
        // The wait_for_exit future is obtained before the kill is issued.
        let wait_for_exit = thread.update(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            let wait_for_exit = term.read(cx).wait_for_exit();
            term.update(cx, |term, cx| {
                term.kill(cx);
            });
            wait_for_exit
        });

        // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
        // Before the fix to kill_active_task, this would hang forever because
        // only the foreground process was killed, not the shell, so the PTY
        // child never exited and wait_for_completed_task never completed.
        // Race the exit future against a 5 second timer; `None` means the timer won.
        let exit_result = futures::select! {
            result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
            _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
        };

        assert!(
            exit_result.is_some(),
            "wait_for_exit should complete after kill, but it timed out. \
            This indicates kill_active_task is not properly killing the shell process."
        );

        // Give the system a chance to process any pending updates
        cx.run_until_parked();

        // Verify that the underlying terminal still has the output that was
        // written before the kill. This verifies that killing doesn't lose output.
        let inner_content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            term.read(cx).inner().read(cx).get_content()
        });

        assert!(
            inner_content.contains("output_before_kill"),
            "Underlying terminal should contain output from before kill, got: {}",
            inner_content
        );
    }
2877
    /// Exercises `push_user_content_block`:
    /// 1. The first block creates a user message entry with no id.
    /// 2. A second block pushed with an id is appended to that same entry,
    ///    which then carries the id.
    /// 3. After an assistant block, a further user block starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        // Still a single entry: the text was concatenated and the id was set.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2945
    /// Checks that two consecutive `AgentThoughtChunk` updates are rendered as
    /// a single `<thinking>` block in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every user message with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
3006
    /// Checks that an agent write based on an earlier read does not discard an
    /// edit made to the same buffer in between: the agent reads the file, the
    /// test inserts a line at the start of the buffer, the agent then writes
    /// its new content, and both changes end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        // Open the buffer up front so the test can edit it directly.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then edit the buffer before its write lands.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the inserted "zero" line and the agent's appended lines are present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
3083
    /// Exercises the `line` / `limit` arguments of `read_text_file` on a
    /// four-line file: whole file, start line only, limit only, a start line
    /// with a limit, and a start line past the end of the file (an error).
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        // Line 6 is past the end; the error reports the file's last position as 5:0.
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
3159
    /// Exercises `read_text_file` on an empty file: every in-range combination
    /// of `line` and `limit` yields an empty string, and a start line past the
    /// end of the file is reported as an error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        // Line 5 is past the end; the error reports the file's last position as 1:0.
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    /// Checks that `read_text_file` fails with `ResourceNotFound` for a path
    /// that lies outside the project's worktree.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3265
    /// Checks the status transitions of a tool call around cancellation:
    /// it starts `InProgress`, becomes `Canceled` when the thread is canceled,
    /// and a later `ToolCallUpdate` carrying `Completed` still moves it to
    /// `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent answers the user message by reporting an in-progress tool call.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // Deliver a completion update for the tool call after it was canceled.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3355
    /// Checks that an `Edit` tool call which is reported already `Completed`
    /// (with a diff as content) does not make
    /// `has_pending_edit_tool_calls` return `true`.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent answers the user message with a completed edit tool call.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3400
3401 #[gpui::test(iterations = 10)]
3402 async fn test_checkpoints(cx: &mut TestAppContext) {
3403 init_test(cx);
3404 let fs = FakeFs::new(cx.background_executor.clone());
3405 fs.insert_tree(
3406 path!("/test"),
3407 json!({
3408 ".git": {}
3409 }),
3410 )
3411 .await;
3412 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3413
3414 let simulate_changes = Arc::new(AtomicBool::new(true));
3415 let next_filename = Arc::new(AtomicUsize::new(0));
3416 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3417 let simulate_changes = simulate_changes.clone();
3418 let next_filename = next_filename.clone();
3419 let fs = fs.clone();
3420 move |request, thread, mut cx| {
3421 let fs = fs.clone();
3422 let simulate_changes = simulate_changes.clone();
3423 let next_filename = next_filename.clone();
3424 async move {
3425 if simulate_changes.load(SeqCst) {
3426 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3427 fs.write(Path::new(&filename), b"").await?;
3428 }
3429
3430 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3431 panic!("expected text content block");
3432 };
3433 thread.update(&mut cx, |thread, cx| {
3434 thread
3435 .handle_session_update(
3436 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3437 content.text.to_uppercase().into(),
3438 )),
3439 cx,
3440 )
3441 .unwrap();
3442 })?;
3443 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3444 }
3445 .boxed_local()
3446 }
3447 }));
3448 let thread = cx
3449 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3450 .await
3451 .unwrap();
3452
3453 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3454 .await
3455 .unwrap();
3456 thread.read_with(cx, |thread, cx| {
3457 assert_eq!(
3458 thread.to_markdown(cx),
3459 indoc! {"
3460 ## User (checkpoint)
3461
3462 Lorem
3463
3464 ## Assistant
3465
3466 LOREM
3467
3468 "}
3469 );
3470 });
3471 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3472
3473 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3474 .await
3475 .unwrap();
3476 thread.read_with(cx, |thread, cx| {
3477 assert_eq!(
3478 thread.to_markdown(cx),
3479 indoc! {"
3480 ## User (checkpoint)
3481
3482 Lorem
3483
3484 ## Assistant
3485
3486 LOREM
3487
3488 ## User (checkpoint)
3489
3490 ipsum
3491
3492 ## Assistant
3493
3494 IPSUM
3495
3496 "}
3497 );
3498 });
3499 assert_eq!(
3500 fs.files(),
3501 vec![
3502 Path::new(path!("/test/file-0")),
3503 Path::new(path!("/test/file-1"))
3504 ]
3505 );
3506
3507 // Checkpoint isn't stored when there are no changes.
3508 simulate_changes.store(false, SeqCst);
3509 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3510 .await
3511 .unwrap();
3512 thread.read_with(cx, |thread, cx| {
3513 assert_eq!(
3514 thread.to_markdown(cx),
3515 indoc! {"
3516 ## User (checkpoint)
3517
3518 Lorem
3519
3520 ## Assistant
3521
3522 LOREM
3523
3524 ## User (checkpoint)
3525
3526 ipsum
3527
3528 ## Assistant
3529
3530 IPSUM
3531
3532 ## User
3533
3534 dolor
3535
3536 ## Assistant
3537
3538 DOLOR
3539
3540 "}
3541 );
3542 });
3543 assert_eq!(
3544 fs.files(),
3545 vec![
3546 Path::new(path!("/test/file-0")),
3547 Path::new(path!("/test/file-1"))
3548 ]
3549 );
3550
3551 // Rewinding the conversation truncates the history and restores the checkpoint.
3552 thread
3553 .update(cx, |thread, cx| {
3554 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3555 panic!("unexpected entries {:?}", thread.entries)
3556 };
3557 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3558 })
3559 .await
3560 .unwrap();
3561 thread.read_with(cx, |thread, cx| {
3562 assert_eq!(
3563 thread.to_markdown(cx),
3564 indoc! {"
3565 ## User (checkpoint)
3566
3567 Lorem
3568
3569 ## Assistant
3570
3571 LOREM
3572
3573 "}
3574 );
3575 });
3576 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3577 }
3578
3579 #[gpui::test]
3580 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3581 use std::sync::atomic::AtomicUsize;
3582 init_test(cx);
3583
3584 let fs = FakeFs::new(cx.executor());
3585 let project = Project::test(fs, None, cx).await;
3586
3587 // Create a connection that simulates refusal after tool result
3588 let prompt_count = Arc::new(AtomicUsize::new(0));
3589 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3590 let prompt_count = prompt_count.clone();
3591 move |_request, thread, mut cx| {
3592 let count = prompt_count.fetch_add(1, SeqCst);
3593 async move {
3594 if count == 0 {
3595 // First prompt: Generate a tool call with result
3596 thread.update(&mut cx, |thread, cx| {
3597 thread
3598 .handle_session_update(
3599 acp::SessionUpdate::ToolCall(
3600 acp::ToolCall::new("tool1", "Test Tool")
3601 .kind(acp::ToolKind::Fetch)
3602 .status(acp::ToolCallStatus::Completed)
3603 .raw_input(serde_json::json!({"query": "test"}))
3604 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3605 ),
3606 cx,
3607 )
3608 .unwrap();
3609 })?;
3610
3611 // Now return refusal because of the tool result
3612 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3613 } else {
3614 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3615 }
3616 }
3617 .boxed_local()
3618 }
3619 }));
3620
3621 let thread = cx
3622 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3623 .await
3624 .unwrap();
3625
3626 // Track if we see a Refusal event
3627 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3628 let saw_refusal_event_captured = saw_refusal_event.clone();
3629 thread.update(cx, |_thread, cx| {
3630 cx.subscribe(
3631 &thread,
3632 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3633 if matches!(event, AcpThreadEvent::Refusal) {
3634 *saw_refusal_event_captured.lock().unwrap() = true;
3635 }
3636 },
3637 )
3638 .detach();
3639 });
3640
3641 // Send a user message - this will trigger tool call and then refusal
3642 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3643 cx.background_executor.spawn(send_task).detach();
3644 cx.run_until_parked();
3645
3646 // Verify that:
3647 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3648 // 2. The user message was NOT truncated
3649 assert!(
3650 *saw_refusal_event.lock().unwrap(),
3651 "Refusal event should be emitted for tool result refusals"
3652 );
3653
3654 thread.read_with(cx, |thread, _| {
3655 let entries = thread.entries();
3656 assert!(entries.len() >= 2, "Should have user message and tool call");
3657
3658 // Verify user message is still there
3659 assert!(
3660 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3661 "User message should not be truncated"
3662 );
3663
3664 // Verify tool call is there with result
3665 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3666 assert!(
3667 tool_call.raw_output.is_some(),
3668 "Tool call should have output"
3669 );
3670 } else {
3671 panic!("Expected tool call at index 1");
3672 }
3673 });
3674 }
3675
3676 #[gpui::test]
3677 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3678 init_test(cx);
3679
3680 let fs = FakeFs::new(cx.executor());
3681 let project = Project::test(fs, None, cx).await;
3682
3683 let refuse_next = Arc::new(AtomicBool::new(false));
3684 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3685 let refuse_next = refuse_next.clone();
3686 move |_request, _thread, _cx| {
3687 if refuse_next.load(SeqCst) {
3688 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3689 .boxed_local()
3690 } else {
3691 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3692 .boxed_local()
3693 }
3694 }
3695 }));
3696
3697 let thread = cx
3698 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3699 .await
3700 .unwrap();
3701
3702 // Track if we see a Refusal event
3703 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3704 let saw_refusal_event_captured = saw_refusal_event.clone();
3705 thread.update(cx, |_thread, cx| {
3706 cx.subscribe(
3707 &thread,
3708 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3709 if matches!(event, AcpThreadEvent::Refusal) {
3710 *saw_refusal_event_captured.lock().unwrap() = true;
3711 }
3712 },
3713 )
3714 .detach();
3715 });
3716
3717 // Send a message that will be refused
3718 refuse_next.store(true, SeqCst);
3719 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3720 .await
3721 .unwrap();
3722
3723 // Verify that a Refusal event WAS emitted for user prompt refusal
3724 assert!(
3725 *saw_refusal_event.lock().unwrap(),
3726 "Refusal event should be emitted for user prompt refusals"
3727 );
3728
3729 // Verify the message was truncated (user prompt refusal)
3730 thread.read_with(cx, |thread, cx| {
3731 assert_eq!(thread.to_markdown(cx), "");
3732 });
3733 }
3734
3735 #[gpui::test]
3736 async fn test_refusal(cx: &mut TestAppContext) {
3737 init_test(cx);
3738 let fs = FakeFs::new(cx.background_executor.clone());
3739 fs.insert_tree(path!("/"), json!({})).await;
3740 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3741
3742 let refuse_next = Arc::new(AtomicBool::new(false));
3743 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3744 let refuse_next = refuse_next.clone();
3745 move |request, thread, mut cx| {
3746 let refuse_next = refuse_next.clone();
3747 async move {
3748 if refuse_next.load(SeqCst) {
3749 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3750 }
3751
3752 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3753 panic!("expected text content block");
3754 };
3755 thread.update(&mut cx, |thread, cx| {
3756 thread
3757 .handle_session_update(
3758 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3759 content.text.to_uppercase().into(),
3760 )),
3761 cx,
3762 )
3763 .unwrap();
3764 })?;
3765 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3766 }
3767 .boxed_local()
3768 }
3769 }));
3770 let thread = cx
3771 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3772 .await
3773 .unwrap();
3774
3775 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3776 .await
3777 .unwrap();
3778 thread.read_with(cx, |thread, cx| {
3779 assert_eq!(
3780 thread.to_markdown(cx),
3781 indoc! {"
3782 ## User
3783
3784 hello
3785
3786 ## Assistant
3787
3788 HELLO
3789
3790 "}
3791 );
3792 });
3793
3794 // Simulate refusing the second message. The message should be truncated
3795 // when a user prompt is refused.
3796 refuse_next.store(true, SeqCst);
3797 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3798 .await
3799 .unwrap();
3800 thread.read_with(cx, |thread, cx| {
3801 assert_eq!(
3802 thread.to_markdown(cx),
3803 indoc! {"
3804 ## User
3805
3806 hello
3807
3808 ## Assistant
3809
3810 HELLO
3811
3812 "}
3813 );
3814 });
3815 }
3816
3817 async fn run_until_first_tool_call(
3818 thread: &Entity<AcpThread>,
3819 cx: &mut TestAppContext,
3820 ) -> usize {
3821 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3822
3823 let subscription = cx.update(|cx| {
3824 cx.subscribe(thread, move |thread, _, cx| {
3825 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3826 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3827 return tx.try_send(ix).unwrap();
3828 }
3829 }
3830 })
3831 });
3832
3833 select! {
3834 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3835 panic!("Timeout waiting for tool call")
3836 }
3837 ix = rx.next().fuse() => {
3838 drop(subscription);
3839 ix.unwrap()
3840 }
3841 }
3842 }
3843
/// In-memory stand-in for an agent connection, used by the tests in this
/// module to script how an agent responds to prompts.
#[derive(Clone, Default)]
struct FakeAgentConnection {
    /// Authentication methods reported by `auth_methods` and accepted by
    /// `authenticate`.
    auth_methods: Vec<acp::AuthMethod>,
    /// Threads created through `new_session`, keyed by session id and held
    /// weakly; `prompt` looks the target thread up here.
    sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
    /// Optional scripted prompt handler. It receives the prompt request, the
    /// thread the prompt belongs to, and an async app handle, and resolves to
    /// the prompt response. When unset, `prompt` answers with `EndTurn`.
    on_user_message: Option<
        Rc<
            dyn Fn(
                    acp::PromptRequest,
                    WeakEntity<AcpThread>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                + 'static,
        >,
    >,
}
3859
3860 impl FakeAgentConnection {
3861 fn new() -> Self {
3862 Self {
3863 auth_methods: Vec::new(),
3864 on_user_message: None,
3865 sessions: Arc::default(),
3866 }
3867 }
3868
3869 #[expect(unused)]
3870 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3871 self.auth_methods = auth_methods;
3872 self
3873 }
3874
3875 fn on_user_message(
3876 mut self,
3877 handler: impl Fn(
3878 acp::PromptRequest,
3879 WeakEntity<AcpThread>,
3880 AsyncApp,
3881 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3882 + 'static,
3883 ) -> Self {
3884 self.on_user_message.replace(Rc::new(handler));
3885 self
3886 }
3887 }
3888
3889 impl AgentConnection for FakeAgentConnection {
3890 fn telemetry_id(&self) -> SharedString {
3891 "fake".into()
3892 }
3893
3894 fn auth_methods(&self) -> &[acp::AuthMethod] {
3895 &self.auth_methods
3896 }
3897
3898 fn new_session(
3899 self: Rc<Self>,
3900 project: Entity<Project>,
3901 _cwd: &Path,
3902 cx: &mut App,
3903 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3904 let session_id = acp::SessionId::new(
3905 rand::rng()
3906 .sample_iter(&distr::Alphanumeric)
3907 .take(7)
3908 .map(char::from)
3909 .collect::<String>(),
3910 );
3911 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3912 let thread = cx.new(|cx| {
3913 AcpThread::new(
3914 None,
3915 "Test",
3916 self.clone(),
3917 project,
3918 action_log,
3919 session_id.clone(),
3920 watch::Receiver::constant(
3921 acp::PromptCapabilities::new()
3922 .image(true)
3923 .audio(true)
3924 .embedded_context(true),
3925 ),
3926 cx,
3927 )
3928 });
3929 self.sessions.lock().insert(session_id, thread.downgrade());
3930 Task::ready(Ok(thread))
3931 }
3932
3933 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3934 if self.auth_methods().iter().any(|m| m.id == method) {
3935 Task::ready(Ok(()))
3936 } else {
3937 Task::ready(Err(anyhow!("Invalid Auth Method")))
3938 }
3939 }
3940
3941 fn prompt(
3942 &self,
3943 _id: Option<UserMessageId>,
3944 params: acp::PromptRequest,
3945 cx: &mut App,
3946 ) -> Task<gpui::Result<acp::PromptResponse>> {
3947 let sessions = self.sessions.lock();
3948 let thread = sessions.get(¶ms.session_id).unwrap();
3949 if let Some(handler) = &self.on_user_message {
3950 let handler = handler.clone();
3951 let thread = thread.clone();
3952 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3953 } else {
3954 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3955 }
3956 }
3957
3958 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
3959
3960 fn truncate(
3961 &self,
3962 session_id: &acp::SessionId,
3963 _cx: &App,
3964 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3965 Some(Rc::new(FakeAgentSessionEditor {
3966 _session_id: session_id.clone(),
3967 }))
3968 }
3969
3970 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3971 self
3972 }
3973 }
3974
/// Truncation handle handed out by `FakeAgentConnection::truncate`.
struct FakeAgentSessionEditor {
    /// Session this editor was created for; stored but never read.
    _session_id: acp::SessionId,
}
3978
impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// Reports success immediately without doing any work; both arguments
    /// are ignored.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
3984
3985 #[gpui::test]
3986 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3987 init_test(cx);
3988
3989 let fs = FakeFs::new(cx.executor());
3990 let project = Project::test(fs, [], cx).await;
3991 let connection = Rc::new(FakeAgentConnection::new());
3992 let thread = cx
3993 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3994 .await
3995 .unwrap();
3996
3997 // Try to update a tool call that doesn't exist
3998 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3999 thread.update(cx, |thread, cx| {
4000 let result = thread.handle_session_update(
4001 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4002 nonexistent_id.clone(),
4003 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4004 )),
4005 cx,
4006 );
4007
4008 // The update should succeed (not return an error)
4009 assert!(result.is_ok());
4010
4011 // There should now be exactly one entry in the thread
4012 assert_eq!(thread.entries.len(), 1);
4013
4014 // The entry should be a failed tool call
4015 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4016 assert_eq!(tool_call.id, nonexistent_id);
4017 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4018 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4019
4020 // Check that the content contains the error message
4021 assert_eq!(tool_call.content.len(), 1);
4022 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4023 match content_block {
4024 ContentBlock::Markdown { markdown } => {
4025 let markdown_text = markdown.read(cx).source();
4026 assert!(markdown_text.contains("Tool call not found"));
4027 }
4028 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4029 ContentBlock::ResourceLink { .. } => {
4030 panic!("Expected markdown content, got resource link")
4031 }
4032 ContentBlock::Image { .. } => {
4033 panic!("Expected markdown content, got image")
4034 }
4035 }
4036 } else {
4037 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4038 }
4039 } else {
4040 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4041 }
4042 });
4043 }
4044
/// Tests that restoring a checkpoint properly cleans up terminals that were
/// created after that checkpoint, and cancels any in-progress generation.
///
/// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
/// that were started after that checkpoint should be terminated, and any in-progress
/// AI generation should be canceled.
///
/// Outline of the scenario:
/// 1. Two user messages are sent (the fake connection answers with `EndTurn`).
/// 2. Two terminals are registered, given output, and marked as exited.
/// 3. A third terminal is registered with output but no exit event, and an
///    in-progress tool call entry referencing it is added to the thread.
/// 4. The thread is restored to the second user message, after which only the
///    two exited terminals and the first user message are expected to remain.
#[gpui::test]
async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // Send the first user message.
    cx.update(|cx| {
        thread.update(cx, |thread, cx| {
            thread.send(vec!["first message".into()], cx)
        })
    })
    .await
    .unwrap();

    // Send the second user message; this is the one we will restore to.
    cx.update(|cx| {
        thread.update(cx, |thread, cx| {
            thread.send(vec!["second message".into()], cx)
        })
    })
    .await
    .unwrap();

    // Register two terminals that have already finished running.
    // NOTE(review): both are registered after the two messages above and are
    // not referenced by any thread entry; presumably that is why they are
    // expected to survive the restore. Confirm against `restore_checkpoint`.
    let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal_1 = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
            cx.background_executor(),
            PathStyle::local(),
        )
        .unwrap();
        builder.subscribe(cx)
    });

    // Terminal 1 lifecycle: created -> output -> exit (code 0).
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id_1.clone(),
                label: "echo 'first'".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal_1.clone(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id_1.clone(),
                data: b"first\n".to_vec(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Exit {
                terminal_id: terminal_id_1.clone(),
                status: acp::TerminalExitStatus::new().exit_code(0),
            },
            cx,
        );
    });

    let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal_2 = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
            cx.background_executor(),
            PathStyle::local(),
        )
        .unwrap();
        builder.subscribe(cx)
    });

    // Terminal 2 lifecycle: created -> output -> exit (code 0).
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id_2.clone(),
                label: "echo 'second'".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal_2.clone(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id_2.clone(),
                data: b"second\n".to_vec(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Exit {
                terminal_id: terminal_id_2.clone(),
                status: acp::TerminalExitStatus::new().exit_code(0),
            },
            cx,
        );
    });

    // Get the second message ID to restore to
    let second_message_id = thread.read_with(cx, |thread, _| {
        // At this point we have:
        // - Index 0: First user message
        // - Index 1: Second user message
        // No assistant responses because FakeAgentConnection just returns EndTurn
        let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
            panic!("expected user message at index 1");
        };
        message.id.clone().unwrap()
    });

    // Create a terminal that will be referenced by an entry added after the
    // message we restore to.
    // This simulates the AI agent starting a long-running terminal command.
    let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
            cx.background_executor(),
            PathStyle::local(),
        )
        .unwrap();
        builder.subscribe(cx)
    });

    // Register the terminal as created
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id.clone(),
                label: "sleep 1000".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal.clone(),
            },
            cx,
        );
    });

    // Simulate the terminal producing output; no Exit event is sent, so it
    // counts as still running.
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id.clone(),
                data: b"terminal is running...\n".to_vec(),
            },
            cx,
        );
    });

    // Create a tool call entry that references this terminal
    // This represents the agent requesting a terminal command
    thread.update(cx, |thread, cx| {
        thread
            .handle_session_update(
                acp::SessionUpdate::ToolCall(
                    acp::ToolCall::new("terminal-tool-1", "Running command")
                        .kind(acp::ToolKind::Execute)
                        .status(acp::ToolCallStatus::InProgress)
                        .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                            terminal_id.clone(),
                        ))])
                        .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                ),
                cx,
            )
            .unwrap();
    });

    // Verify terminal exists and is in the thread
    let terminal_exists_before =
        thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
    assert!(
        terminal_exists_before,
        "Terminal should exist before checkpoint restore"
    );

    // Verify the terminal has not reported any final output yet, i.e. it is
    // treated as still running.
    let terminal_running_before = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
        terminal_entity.read_with(cx, |term, _cx| {
            term.output().is_none() // output is None means it's still running
        })
    });
    assert!(
        terminal_running_before,
        "Terminal should be running before checkpoint restore"
    );

    // Verify we have the expected entries before restore
    let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
    assert!(
        entry_count_before > 1,
        "Should have multiple entries before restore"
    );

    // Restore the checkpoint to the second message.
    // This should:
    // 1. Cancel any in-progress generation
    // 2. Remove the terminal that was created after that point
    thread
        .update(cx, |thread, cx| {
            thread.restore_checkpoint(second_message_id, cx)
        })
        .await
        .unwrap();

    // Verify that no turn is running after the restore. The check is on
    // `running_turn`; the variable name and assertion message still use the
    // older "send_task" wording.
    let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
    assert!(
        !has_send_task_after,
        "Should not have a send_task after restore (cancel should have cleared it)"
    );

    // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
    let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
    assert_eq!(
        entry_count, 1,
        "Should have 1 entry after restore (only the first user message)"
    );

    // Verify the 2 completed terminals still exist
    let terminal_1_exists = thread.read_with(cx, |thread, _| {
        thread.terminals.contains_key(&terminal_id_1)
    });
    assert!(
        terminal_1_exists,
        "Terminal 1 (from before checkpoint) should still exist"
    );

    let terminal_2_exists = thread.read_with(cx, |thread, _| {
        thread.terminals.contains_key(&terminal_id_2)
    });
    assert!(
        terminal_2_exists,
        "Terminal 2 (from before checkpoint) should still exist"
    );

    // Verify they're still in completed state (their output is available)
    let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
        terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
    });
    assert!(terminal_1_completed, "Terminal 1 should still be completed");

    let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
        terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
    });
    assert!(terminal_2_completed, "Terminal 2 should still be completed");

    // Verify the running terminal (referenced by the truncated tool call) was removed
    let terminal_3_exists =
        thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
    assert!(
        !terminal_3_exists,
        "Terminal 3 (created after checkpoint) should have been removed"
    );

    // Verify total count is 2 (the two completed terminals)
    let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
    assert_eq!(
        terminal_count, 2,
        "Should have exactly 2 terminals (the completed ones from before checkpoint)"
    );
}
4346
4347 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4348 /// even when a new user message is added while the async checkpoint comparison is in progress.
4349 ///
4350 /// This is a regression test for a bug where update_last_checkpoint would fail with
4351 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4352 /// update_last_checkpoint started and when its async closure ran.
4353 #[gpui::test]
4354 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4355 init_test(cx);
4356
4357 let fs = FakeFs::new(cx.executor());
4358 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4359 .await;
4360 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4361
4362 let handler_done = Arc::new(AtomicBool::new(false));
4363 let handler_done_clone = handler_done.clone();
4364 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4365 move |_, _thread, _cx| {
4366 handler_done_clone.store(true, SeqCst);
4367 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4368 },
4369 ));
4370
4371 let thread = cx
4372 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4373 .await
4374 .unwrap();
4375
4376 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4377 let send_task = cx.background_executor.spawn(send_future);
4378
4379 // Tick until handler completes, then a few more to let update_last_checkpoint start
4380 while !handler_done.load(SeqCst) {
4381 cx.executor().tick();
4382 }
4383 for _ in 0..5 {
4384 cx.executor().tick();
4385 }
4386
4387 thread.update(cx, |thread, cx| {
4388 thread.push_entry(
4389 AgentThreadEntry::UserMessage(UserMessage {
4390 id: Some(UserMessageId::new()),
4391 content: ContentBlock::Empty,
4392 chunks: vec!["Injected message (no checkpoint)".into()],
4393 checkpoint: None,
4394 indented: false,
4395 }),
4396 cx,
4397 );
4398 });
4399
4400 cx.run_until_parked();
4401 let result = send_task.await;
4402
4403 assert!(
4404 result.is_ok(),
4405 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4406 result.err()
4407 );
4408 }
4409
4410 /// Tests that when a follow-up message is sent during generation,
4411 /// the first turn completing does NOT clear `running_turn` because
4412 /// it now belongs to the second turn.
4413 #[gpui::test]
4414 async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
4415 init_test(cx);
4416
4417 let fs = FakeFs::new(cx.executor());
4418 let project = Project::test(fs, [], cx).await;
4419
4420 // First handler waits for this signal before completing
4421 let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
4422 let first_complete_rx = RefCell::new(Some(first_complete_rx));
4423
4424 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
4425 move |params, _thread, _cx| {
4426 let first_complete_rx = first_complete_rx.borrow_mut().take();
4427 let is_first = params
4428 .prompt
4429 .iter()
4430 .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
4431
4432 async move {
4433 if is_first {
4434 // First handler waits until signaled
4435 if let Some(rx) = first_complete_rx {
4436 rx.await.ok();
4437 }
4438 }
4439 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
4440 }
4441 .boxed_local()
4442 }
4443 }));
4444
4445 let thread = cx
4446 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4447 .await
4448 .unwrap();
4449
4450 // Send first message (turn_id=1) - handler will block
4451 let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
4452 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
4453
4454 // Send second message (turn_id=2) while first is still blocked
4455 // This calls cancel() which takes turn 1's running_turn and sets turn 2's
4456 let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
4457 assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
4458
4459 let running_turn_after_second_send =
4460 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4461 assert_eq!(
4462 running_turn_after_second_send,
4463 Some(2),
4464 "running_turn should be set to turn 2 after sending second message"
4465 );
4466
4467 // Now signal first handler to complete
4468 first_complete_tx.send(()).ok();
4469
4470 // First request completes - should NOT clear running_turn
4471 // because running_turn now belongs to turn 2
4472 first_request.await.unwrap();
4473
4474 let running_turn_after_first =
4475 thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
4476 assert_eq!(
4477 running_turn_after_first,
4478 Some(2),
4479 "first turn completing should not clear running_turn (belongs to turn 2)"
4480 );
4481
4482 // Second request completes - SHOULD clear running_turn
4483 second_request.await.unwrap();
4484
4485 let running_turn_after_second =
4486 thread.read_with(cx, |thread, _| thread.running_turn.is_some());
4487 assert!(
4488 !running_turn_after_second,
4489 "second turn completing should clear running_turn"
4490 );
4491 }
4492
4493 #[gpui::test]
4494 async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
4495 cx: &mut TestAppContext,
4496 ) {
4497 init_test(cx);
4498
4499 let fs = FakeFs::new(cx.executor());
4500 let project = Project::test(fs, [], cx).await;
4501
4502 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4503 move |_params, thread, mut cx| {
4504 async move {
4505 thread
4506 .update(&mut cx, |thread, cx| {
4507 thread.handle_session_update(
4508 acp::SessionUpdate::ToolCall(
4509 acp::ToolCall::new(
4510 acp::ToolCallId::new("test-tool"),
4511 "Test Tool",
4512 )
4513 .kind(acp::ToolKind::Fetch)
4514 .status(acp::ToolCallStatus::InProgress),
4515 ),
4516 cx,
4517 )
4518 })
4519 .unwrap()
4520 .unwrap();
4521
4522 Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
4523 }
4524 .boxed_local()
4525 },
4526 ));
4527
4528 let thread = cx
4529 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4530 .await
4531 .unwrap();
4532
4533 let response = thread
4534 .update(cx, |thread, cx| thread.send_raw("test message", cx))
4535 .await;
4536
4537 let response = response
4538 .expect("send should succeed")
4539 .expect("should have response");
4540 assert_eq!(
4541 response.stop_reason,
4542 acp::StopReason::Cancelled,
4543 "response should have Cancelled stop_reason"
4544 );
4545
4546 thread.read_with(cx, |thread, _| {
4547 let tool_entry = thread
4548 .entries
4549 .iter()
4550 .find_map(|e| {
4551 if let AgentThreadEntry::ToolCall(call) = e {
4552 Some(call)
4553 } else {
4554 None
4555 }
4556 })
4557 .expect("should have tool call entry");
4558
4559 assert!(
4560 matches!(tool_entry.status, ToolCallStatus::Canceled),
4561 "tool should be marked as Canceled when response is Cancelled, got {:?}",
4562 tool_entry.status
4563 );
4564 });
4565 }
4566}