1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
/// Meta key under which an ACP `ToolCall` carries the tool's programmatic name.
///
/// ACP's `ToolCall` has no dedicated field for the tool name, so the name is
/// passed through the call's meta map instead.
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Meta key under which an ACP `ToolCall` carries the session id of a spawned subagent.
pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id";
12
13/// Helper to extract tool name from ACP meta
14pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
15 meta.as_ref()
16 .and_then(|m| m.get(TOOL_NAME_META_KEY))
17 .and_then(|v| v.as_str())
18 .map(|s| SharedString::from(s.to_owned()))
19}
20
21/// Helper to extract subagent session id from ACP meta
22pub fn subagent_session_id_from_meta(meta: &Option<acp::Meta>) -> Option<acp::SessionId> {
23 meta.as_ref()
24 .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY))
25 .and_then(|v| v.as_str())
26 .map(|s| acp::SessionId::from(s.to_string()))
27}
28
29/// Helper to create meta with tool name
30pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
31 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
32}
33use collections::HashSet;
34pub use connection::*;
35pub use diff::*;
36use language::language_settings::FormatOnSave;
37pub use mention::*;
38use project::lsp_store::{FormatTrigger, LspFormatTarget};
39use serde::{Deserialize, Serialize};
40use serde_json::to_string_pretty;
41
42use task::{Shell, ShellBuilder};
43pub use terminal::*;
44
45use action_log::{ActionLog, ActionLogTelemetry};
46use agent_client_protocol::{self as acp};
47use anyhow::{Context as _, Result, anyhow};
48use futures::{FutureExt, channel::oneshot, future::BoxFuture};
49use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
50use itertools::Itertools;
51use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
52use markdown::Markdown;
53use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
54use std::collections::HashMap;
55use std::error::Error;
56use std::fmt::{Formatter, Write};
57use std::ops::Range;
58use std::process::ExitStatus;
59use std::rc::Rc;
60use std::time::{Duration, Instant};
61use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
62use text::Bias;
63use ui::App;
64use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
65use uuid::Uuid;
66
/// A message sent by the user, as stored in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one is available.
    pub id: Option<UserMessageId>,
    /// Renderable content of the message; this is what markdown export emits.
    pub content: ContentBlock,
    /// ACP content blocks for this message (presumably the raw chunks that
    /// `content` was built from — confirm against the constructing code).
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether this entry is rendered indented.
    pub indented: bool,
}
75
/// A git checkpoint associated with a user message.
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git-store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is shown; when set, markdown export labels the
    /// message as `## User (checkpoint)`.
    pub show: bool,
}
81
82impl UserMessage {
83 fn to_markdown(&self, cx: &App) -> String {
84 let mut markdown = String::new();
85 if self
86 .checkpoint
87 .as_ref()
88 .is_some_and(|checkpoint| checkpoint.show)
89 {
90 writeln!(markdown, "## User (checkpoint)").unwrap();
91 } else {
92 writeln!(markdown, "## User").unwrap();
93 }
94 writeln!(markdown).unwrap();
95 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
96 writeln!(markdown).unwrap();
97 markdown
98 }
99}
100
/// A message produced by the assistant, made of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The chunks of the message, in order.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Whether this entry is rendered indented.
    pub indented: bool,
}
106
107impl AssistantMessage {
108 pub fn to_markdown(&self, cx: &App) -> String {
109 format!(
110 "## Assistant\n\n{}\n\n",
111 self.chunks
112 .iter()
113 .map(|chunk| chunk.to_markdown(cx))
114 .join("\n\n")
115 )
116 }
117}
118
/// One chunk of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Visible message content.
    Message { block: ContentBlock },
    /// Reasoning content; markdown export wraps it in `<thinking>` tags.
    Thought { block: ContentBlock },
}
124
125impl AssistantMessageChunk {
126 pub fn from_str(
127 chunk: &str,
128 language_registry: &Arc<LanguageRegistry>,
129 path_style: PathStyle,
130 cx: &mut App,
131 ) -> Self {
132 Self::Message {
133 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
134 }
135 }
136
137 fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::Message { block } => block.to_markdown(cx).to_string(),
140 Self::Thought { block } => {
141 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
142 }
143 }
144 }
145}
146
/// An entry in an agent thread: a user message, an assistant message, or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
153
154impl AgentThreadEntry {
155 pub fn is_indented(&self) -> bool {
156 match self {
157 Self::UserMessage(message) => message.indented,
158 Self::AssistantMessage(message) => message.indented,
159 Self::ToolCall(_) => false,
160 }
161 }
162
163 pub fn to_markdown(&self, cx: &App) -> String {
164 match self {
165 Self::UserMessage(message) => message.to_markdown(cx),
166 Self::AssistantMessage(message) => message.to_markdown(cx),
167 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
168 }
169 }
170
171 pub fn user_message(&self) -> Option<&UserMessage> {
172 if let AgentThreadEntry::UserMessage(message) = self {
173 Some(message)
174 } else {
175 None
176 }
177 }
178
179 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
180 if let AgentThreadEntry::ToolCall(call) = self {
181 itertools::Either::Left(call.diffs())
182 } else {
183 itertools::Either::Right(std::iter::empty())
184 }
185 }
186
187 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
188 if let AgentThreadEntry::ToolCall(call) = self {
189 itertools::Either::Left(call.terminals())
190 } else {
191 itertools::Either::Right(std::iter::empty())
192 }
193 }
194
195 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
196 if let AgentThreadEntry::ToolCall(ToolCall {
197 locations,
198 resolved_locations,
199 ..
200 }) = self
201 {
202 Some((
203 locations.get(ix)?.clone(),
204 resolved_locations.get(ix)?.clone()?,
205 ))
206 } else {
207 None
208 }
209 }
210}
211
/// A tool call made by the agent, as displayed in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol id of the tool call.
    pub id: acp::ToolCallId,
    /// Rendered title. For non-`Execute` tools, a multi-line title is cut to its
    /// first line followed by `…`.
    pub label: Entity<Markdown>,
    /// The kind of tool.
    pub kind: acp::ToolKind,
    /// Content items produced by the call (blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index in
    /// `AgentThreadEntry::location`. Empty when constructed; populated elsewhere.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input sent to the tool, if any.
    pub raw_input: Option<serde_json::Value>,
    /// `raw_input` rendered as markdown, when it could be rendered.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output returned by the tool, if any.
    pub raw_output: Option<serde_json::Value>,
    /// Programmatic tool name read from the call's meta (`TOOL_NAME_META_KEY`).
    pub tool_name: Option<SharedString>,
    /// Session id of a spawned subagent, read from meta (`SUBAGENT_SESSION_ID_META_KEY`).
    pub subagent_session_id: Option<acp::SessionId>,
}
227
228impl ToolCall {
229 fn from_acp(
230 tool_call: acp::ToolCall,
231 status: ToolCallStatus,
232 language_registry: Arc<LanguageRegistry>,
233 path_style: PathStyle,
234 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
235 cx: &mut App,
236 ) -> Result<Self> {
237 let title = if tool_call.kind == acp::ToolKind::Execute {
238 tool_call.title
239 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
240 first_line.to_owned() + "…"
241 } else {
242 tool_call.title
243 };
244 let mut content = Vec::with_capacity(tool_call.content.len());
245 for item in tool_call.content {
246 if let Some(item) = ToolCallContent::from_acp(
247 item,
248 language_registry.clone(),
249 path_style,
250 terminals,
251 cx,
252 )? {
253 content.push(item);
254 }
255 }
256
257 let raw_input_markdown = tool_call
258 .raw_input
259 .as_ref()
260 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
261
262 let tool_name = tool_name_from_meta(&tool_call.meta);
263
264 let subagent_session = subagent_session_id_from_meta(&tool_call.meta);
265
266 let result = Self {
267 id: tool_call.tool_call_id,
268 label: cx
269 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
270 kind: tool_call.kind,
271 content,
272 locations: tool_call.locations,
273 resolved_locations: Vec::default(),
274 status,
275 raw_input: tool_call.raw_input,
276 raw_input_markdown,
277 raw_output: tool_call.raw_output,
278 tool_name,
279 subagent_session_id: subagent_session,
280 };
281 Ok(result)
282 }
283
284 fn update_fields(
285 &mut self,
286 fields: acp::ToolCallUpdateFields,
287 meta: Option<acp::Meta>,
288 language_registry: Arc<LanguageRegistry>,
289 path_style: PathStyle,
290 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
291 cx: &mut App,
292 ) -> Result<()> {
293 let acp::ToolCallUpdateFields {
294 kind,
295 status,
296 title,
297 content,
298 locations,
299 raw_input,
300 raw_output,
301 ..
302 } = fields;
303
304 if let Some(kind) = kind {
305 self.kind = kind;
306 }
307
308 if let Some(status) = status {
309 self.status = status.into();
310 }
311
312 if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) {
313 self.subagent_session_id = Some(subagent_session_id);
314 }
315
316 if let Some(title) = title {
317 if self.kind == acp::ToolKind::Execute {
318 for terminal in self.terminals() {
319 terminal.update(cx, |terminal, cx| {
320 terminal.update_command_label(&title, cx);
321 });
322 }
323 }
324 self.label.update(cx, |label, cx| {
325 if self.kind == acp::ToolKind::Execute {
326 label.replace(title, cx);
327 } else if let Some((first_line, _)) = title.split_once("\n") {
328 label.replace(first_line.to_owned() + "…", cx);
329 } else {
330 label.replace(title, cx);
331 }
332 });
333 }
334
335 if let Some(content) = content {
336 let mut new_content_len = content.len();
337 let mut content = content.into_iter();
338
339 // Reuse existing content if we can
340 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
341 let valid_content =
342 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
343 if !valid_content {
344 new_content_len -= 1;
345 }
346 }
347 for new in content {
348 if let Some(new) = ToolCallContent::from_acp(
349 new,
350 language_registry.clone(),
351 path_style,
352 terminals,
353 cx,
354 )? {
355 self.content.push(new);
356 } else {
357 new_content_len -= 1;
358 }
359 }
360 self.content.truncate(new_content_len);
361 }
362
363 if let Some(locations) = locations {
364 self.locations = locations;
365 }
366
367 if let Some(raw_input) = raw_input {
368 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
369 self.raw_input = Some(raw_input);
370 }
371
372 if let Some(raw_output) = raw_output {
373 if self.content.is_empty()
374 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
375 {
376 self.content
377 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
378 markdown,
379 }));
380 }
381 self.raw_output = Some(raw_output);
382 }
383 Ok(())
384 }
385
386 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
387 self.content.iter().filter_map(|content| match content {
388 ToolCallContent::Diff(diff) => Some(diff),
389 ToolCallContent::ContentBlock(_) => None,
390 ToolCallContent::Terminal(_) => None,
391 })
392 }
393
394 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
395 self.content.iter().filter_map(|content| match content {
396 ToolCallContent::Terminal(terminal) => Some(terminal),
397 ToolCallContent::ContentBlock(_) => None,
398 ToolCallContent::Diff(_) => None,
399 })
400 }
401
402 pub fn is_subagent(&self) -> bool {
403 self.tool_name.as_ref().is_some_and(|s| s == "subagent")
404 || self.subagent_session_id.is_some()
405 }
406
407 pub fn to_markdown(&self, cx: &App) -> String {
408 let mut markdown = format!(
409 "**Tool Call: {}**\nStatus: {}\n\n",
410 self.label.read(cx).source(),
411 self.status
412 );
413 for content in &self.content {
414 markdown.push_str(content.to_markdown(cx).as_str());
415 markdown.push_str("\n\n");
416 }
417 markdown
418 }
419
420 async fn resolve_location(
421 location: acp::ToolCallLocation,
422 project: WeakEntity<Project>,
423 cx: &mut AsyncApp,
424 ) -> Option<ResolvedLocation> {
425 let buffer = project
426 .update(cx, |project, cx| {
427 project
428 .project_path_for_absolute_path(&location.path, cx)
429 .map(|path| project.open_buffer(path, cx))
430 })
431 .ok()??;
432 let buffer = buffer.await.log_err()?;
433 let position = buffer.update(cx, |buffer, _| {
434 let snapshot = buffer.snapshot();
435 if let Some(row) = location.line {
436 let column = snapshot.indent_size_for_line(row).len;
437 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
438 snapshot.anchor_before(point)
439 } else {
440 Anchor::min_for_buffer(snapshot.remote_id())
441 }
442 });
443
444 Some(ResolvedLocation { buffer, position })
445 }
446
447 fn resolve_locations(
448 &self,
449 project: Entity<Project>,
450 cx: &mut App,
451 ) -> Task<Vec<Option<ResolvedLocation>>> {
452 let locations = self.locations.clone();
453 project.update(cx, |_, cx| {
454 cx.spawn(async move |project, cx| {
455 let mut new_locations = Vec::new();
456 for location in locations {
457 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
458 }
459 new_locations
460 })
461 })
462 }
463}
464
/// A tool-call location resolved to an open buffer and an anchor within it.
///
/// This is kept separate from `AgentLocation`, which holds only a weak buffer
/// handle (see the `From` impl). That lets the thread keep a strong reference to
/// the buffer, e.g. for saving it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    buffer: Entity<Buffer>,
    position: Anchor,
}
472
473impl From<&ResolvedLocation> for AgentLocation {
474 fn from(value: &ResolvedLocation) -> Self {
475 Self {
476 buffer: value.buffer.downgrade(),
477 position: value.position,
478 }
479 }
480}
481
/// Lifecycle state of a tool call as shown in the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered to the user.
        options: PermissionOptions,
        /// Channel on which the chosen option's id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
503
504impl From<acp::ToolCallStatus> for ToolCallStatus {
505 fn from(status: acp::ToolCallStatus) -> Self {
506 match status {
507 acp::ToolCallStatus::Pending => Self::Pending,
508 acp::ToolCallStatus::InProgress => Self::InProgress,
509 acp::ToolCallStatus::Completed => Self::Completed,
510 acp::ToolCallStatus::Failed => Self::Failed,
511 _ => Self::Pending,
512 }
513 }
514}
515
516impl Display for ToolCallStatus {
517 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
518 write!(
519 f,
520 "{}",
521 match self {
522 ToolCallStatus::Pending => "Pending",
523 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
524 ToolCallStatus::InProgress => "In Progress",
525 ToolCallStatus::Completed => "Completed",
526 ToolCallStatus::Failed => "Failed",
527 ToolCallStatus::Rejected => "Rejected",
528 ToolCallStatus::Canceled => "Canceled",
529 }
530 )
531 }
532}
533
/// Renderable content built from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text; blocks that cannot be kept in another form are combined here.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A single successfully decoded image.
    Image { image: Arc<gpui::Image> },
}
541
542impl ContentBlock {
543 pub fn new(
544 block: acp::ContentBlock,
545 language_registry: &Arc<LanguageRegistry>,
546 path_style: PathStyle,
547 cx: &mut App,
548 ) -> Self {
549 let mut this = Self::Empty;
550 this.append(block, language_registry, path_style, cx);
551 this
552 }
553
554 pub fn new_combined(
555 blocks: impl IntoIterator<Item = acp::ContentBlock>,
556 language_registry: Arc<LanguageRegistry>,
557 path_style: PathStyle,
558 cx: &mut App,
559 ) -> Self {
560 let mut this = Self::Empty;
561 for block in blocks {
562 this.append(block, &language_registry, path_style, cx);
563 }
564 this
565 }
566
567 pub fn append(
568 &mut self,
569 block: acp::ContentBlock,
570 language_registry: &Arc<LanguageRegistry>,
571 path_style: PathStyle,
572 cx: &mut App,
573 ) {
574 match (&mut *self, &block) {
575 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
576 *self = ContentBlock::ResourceLink {
577 resource_link: resource_link.clone(),
578 };
579 }
580 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
581 if let Some(image) = Self::decode_image(image_content) {
582 *self = ContentBlock::Image { image };
583 } else {
584 let new_content = Self::image_md(image_content);
585 *self = Self::create_markdown_block(new_content, language_registry, cx);
586 }
587 }
588 (ContentBlock::Empty, _) => {
589 let new_content = Self::block_string_contents(&block, path_style);
590 *self = Self::create_markdown_block(new_content, language_registry, cx);
591 }
592 (ContentBlock::Markdown { markdown }, _) => {
593 let new_content = Self::block_string_contents(&block, path_style);
594 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
595 }
596 (ContentBlock::ResourceLink { resource_link }, _) => {
597 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
598 let new_content = Self::block_string_contents(&block, path_style);
599 let combined = format!("{}\n{}", existing_content, new_content);
600 *self = Self::create_markdown_block(combined, language_registry, cx);
601 }
602 (ContentBlock::Image { .. }, _) => {
603 let new_content = Self::block_string_contents(&block, path_style);
604 let combined = format!("`Image`\n{}", new_content);
605 *self = Self::create_markdown_block(combined, language_registry, cx);
606 }
607 }
608 }
609
610 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
611 use base64::Engine as _;
612
613 let bytes = base64::engine::general_purpose::STANDARD
614 .decode(image_content.data.as_bytes())
615 .ok()?;
616 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
617 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
618 }
619
620 fn create_markdown_block(
621 content: String,
622 language_registry: &Arc<LanguageRegistry>,
623 cx: &mut App,
624 ) -> ContentBlock {
625 ContentBlock::Markdown {
626 markdown: cx
627 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
628 }
629 }
630
631 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
632 match block {
633 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
634 acp::ContentBlock::ResourceLink(resource_link) => {
635 Self::resource_link_md(&resource_link.uri, path_style)
636 }
637 acp::ContentBlock::Resource(acp::EmbeddedResource {
638 resource:
639 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
640 uri,
641 ..
642 }),
643 ..
644 }) => Self::resource_link_md(uri, path_style),
645 acp::ContentBlock::Image(image) => Self::image_md(image),
646 _ => String::new(),
647 }
648 }
649
650 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
651 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
652 uri.as_link().to_string()
653 } else {
654 uri.to_string()
655 }
656 }
657
658 fn image_md(_image: &acp::ImageContent) -> String {
659 "`Image`".into()
660 }
661
662 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
663 match self {
664 ContentBlock::Empty => "",
665 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
666 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
667 ContentBlock::Image { .. } => "`Image`",
668 }
669 }
670
671 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
672 match self {
673 ContentBlock::Empty => None,
674 ContentBlock::Markdown { markdown } => Some(markdown),
675 ContentBlock::ResourceLink { .. } => None,
676 ContentBlock::Image { .. } => None,
677 }
678 }
679
680 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
681 match self {
682 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
683 _ => None,
684 }
685 }
686
687 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
688 match self {
689 ContentBlock::Image { image } => Some(image),
690 _ => None,
691 }
692 }
693}
694
/// One item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Regular content (text, links, images).
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal looked up among the thread's registered terminals.
    Terminal(Entity<Terminal>),
}
701
702impl ToolCallContent {
703 pub fn from_acp(
704 content: acp::ToolCallContent,
705 language_registry: Arc<LanguageRegistry>,
706 path_style: PathStyle,
707 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
708 cx: &mut App,
709 ) -> Result<Option<Self>> {
710 match content {
711 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
712 Ok(Some(Self::ContentBlock(ContentBlock::new(
713 content,
714 &language_registry,
715 path_style,
716 cx,
717 ))))
718 }
719 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
720 Diff::finalized(
721 diff.path.to_string_lossy().into_owned(),
722 diff.old_text,
723 diff.new_text,
724 language_registry,
725 cx,
726 )
727 })))),
728 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
729 .get(&terminal_id)
730 .cloned()
731 .map(|terminal| Some(Self::Terminal(terminal)))
732 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
733 _ => Ok(None),
734 }
735 }
736
737 pub fn update_from_acp(
738 &mut self,
739 new: acp::ToolCallContent,
740 language_registry: Arc<LanguageRegistry>,
741 path_style: PathStyle,
742 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
743 cx: &mut App,
744 ) -> Result<bool> {
745 let needs_update = match (&self, &new) {
746 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
747 old_diff.read(cx).needs_update(
748 new_diff.old_text.as_deref().unwrap_or(""),
749 &new_diff.new_text,
750 cx,
751 )
752 }
753 _ => true,
754 };
755
756 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
757 if needs_update {
758 *self = update;
759 }
760 Ok(true)
761 } else {
762 Ok(false)
763 }
764 }
765
766 pub fn to_markdown(&self, cx: &App) -> String {
767 match self {
768 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
769 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
770 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
771 }
772 }
773
774 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
775 match self {
776 Self::ContentBlock(content) => content.image(),
777 _ => None,
778 }
779 }
780}
781
/// An update targeting an existing tool call, identified by `ToolCallUpdate::id`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field updates received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
788
789impl ToolCallUpdate {
790 fn id(&self) -> &acp::ToolCallId {
791 match self {
792 Self::UpdateFields(update) => &update.tool_call_id,
793 Self::UpdateDiff(diff) => &diff.id,
794 Self::UpdateTerminal(terminal) => &terminal.id,
795 }
796 }
797}
798
799impl From<acp::ToolCallUpdate> for ToolCallUpdate {
800 fn from(update: acp::ToolCallUpdate) -> Self {
801 Self::UpdateFields(update)
802 }
803}
804
805impl From<ToolCallUpdateDiff> for ToolCallUpdate {
806 fn from(diff: ToolCallUpdateDiff) -> Self {
807 Self::UpdateDiff(diff)
808 }
809}
810
/// A diff entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The diff entity.
    pub diff: Entity<Diff>,
}
816
817impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
818 fn from(terminal: ToolCallUpdateTerminal) -> Self {
819 Self::UpdateTerminal(terminal)
820 }
821}
822
/// A terminal entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The terminal entity.
    pub terminal: Entity<Terminal>,
}
828
/// The agent's plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
833
/// Summary of a [`Plan`] produced by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
840
841impl Plan {
842 pub fn is_empty(&self) -> bool {
843 self.entries.is_empty()
844 }
845
846 pub fn stats(&self) -> PlanStats<'_> {
847 let mut stats = PlanStats {
848 in_progress_entry: None,
849 pending: 0,
850 completed: 0,
851 };
852
853 for entry in &self.entries {
854 match &entry.status {
855 acp::PlanEntryStatus::Pending => {
856 stats.pending += 1;
857 }
858 acp::PlanEntryStatus::InProgress => {
859 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
860 }
861 acp::PlanEntryStatus::Completed => {
862 stats.completed += 1;
863 }
864 _ => {}
865 }
866 }
867
868 stats
869 }
870}
871
/// One entry of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, rendered as markdown.
    pub content: Entity<Markdown>,
    /// Priority reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Status reported by the agent.
    pub status: acp::PlanEntryStatus,
}
878
879impl PlanEntry {
880 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
881 Self {
882 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
883 priority: entry.priority,
884 status: entry.status,
885 }
886 }
887}
888
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means the limit is unknown (see `TokenUsage::ratio`).
    pub max_tokens: u64,
    /// Tokens used so far; compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    /// Input tokens (presumably for the latest request — confirm with the producer).
    pub input_tokens: u64,
    /// Output tokens (presumably for the latest request — confirm with the producer).
    pub output_tokens: u64,
    /// Output-token limit, if known.
    pub max_output_tokens: Option<u64>,
}
897
898impl TokenUsage {
899 pub fn ratio(&self) -> TokenUsageRatio {
900 #[cfg(debug_assertions)]
901 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
902 .unwrap_or("0.8".to_string())
903 .parse()
904 .unwrap();
905 #[cfg(not(debug_assertions))]
906 let warning_threshold: f32 = 0.8;
907
908 // When the maximum is unknown because there is no selected model,
909 // avoid showing the token limit warning.
910 if self.max_tokens == 0 {
911 TokenUsageRatio::Normal
912 } else if self.used_tokens >= self.max_tokens {
913 TokenUsageRatio::Exceeded
914 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
915 TokenUsageRatio::Warning
916 } else {
917 TokenUsageRatio::Normal
918 }
919 }
920}
921
/// How close token usage is to the limit, as classified by `TokenUsage::ratio`.
///
/// The derived ordering follows declaration order: `Normal < Warning < Exceeded`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}
928
/// Information about a retry in progress, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number (NOTE(review): 0- or 1-based not visible here — confirm).
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    /// Presumably the delay before the next attempt — confirm with the producer.
    pub duration: Duration,
}
937
/// The turn currently being generated; while one is set, the thread reports
/// `ThreadStatus::Generating`.
struct RunningTurn {
    /// Identifier of the turn (presumably taken from `AcpThread::turn_id` — confirm).
    id: u32,
    /// Task driving the turn (NOTE(review): dropping a gpui `Task` cancels it — confirm intent).
    send_task: Task<()>,
}
942
/// A conversation thread with an agent over ACP.
pub struct AcpThread {
    /// Session id of the parent thread, if this thread has one.
    parent_session_id: Option<acp::SessionId>,
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Buffers shared with the agent along with a snapshot of each (usage not visible here).
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    turn_id: u32,
    /// The in-flight turn; `Some` means the thread is generating.
    running_turn: Option<RunningTurn>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the watch channel.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that keeps `prompt_capabilities` up to date.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals registered with this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals that are not registered yet; replayed on creation.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals that are not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
962
963impl From<&AcpThread> for ActionLogTelemetry {
964 fn from(value: &AcpThread) -> Self {
965 Self {
966 agent_telemetry_id: value.connection().telemetry_id(),
967 session_id: value.session_id.0.clone(),
968 }
969 }
970}
971
/// Events emitted by [`AcpThread`] (it implements `EventEmitter` for this type).
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Presumably carries the index of the updated entry — confirm with the emitter.
    EntryUpdated(usize),
    /// Presumably carries the index range of the removed entries — confirm with the emitter.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    SubagentSpawned(acp::SessionId),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is replaced with a newly received value.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
993
/// Events from a terminal provider, applied by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created. It is registered with the thread, and any output
    /// buffered for its id is replayed.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes; buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title, stored as the inner terminal's breadcrumb text; ignored for
    /// unknown terminals.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; recorded if the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
1016
/// Commands addressed to a terminal provider (presumably sent from the client
/// to the provider — not handled in this part of the file; confirm).
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Write raw input bytes to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to `cols` x `rows`.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
1032
1033impl AcpThread {
1034 pub fn on_terminal_provider_event(
1035 &mut self,
1036 event: TerminalProviderEvent,
1037 cx: &mut Context<Self>,
1038 ) {
1039 match event {
1040 TerminalProviderEvent::Created {
1041 terminal_id,
1042 label,
1043 cwd,
1044 output_byte_limit,
1045 terminal,
1046 } => {
1047 let entity = self.register_terminal_created(
1048 terminal_id.clone(),
1049 label,
1050 cwd,
1051 output_byte_limit,
1052 terminal,
1053 cx,
1054 );
1055
1056 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1057 for data in chunks.drain(..) {
1058 entity.update(cx, |term, cx| {
1059 term.inner().update(cx, |inner, cx| {
1060 inner.write_output(&data, cx);
1061 })
1062 });
1063 }
1064 }
1065
1066 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1067 entity.update(cx, |_term, cx| {
1068 cx.notify();
1069 });
1070 }
1071
1072 cx.notify();
1073 }
1074 TerminalProviderEvent::Output { terminal_id, data } => {
1075 if let Some(entity) = self.terminals.get(&terminal_id) {
1076 entity.update(cx, |term, cx| {
1077 term.inner().update(cx, |inner, cx| {
1078 inner.write_output(&data, cx);
1079 })
1080 });
1081 } else {
1082 self.pending_terminal_output
1083 .entry(terminal_id)
1084 .or_default()
1085 .push(data);
1086 }
1087 }
1088 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1089 if let Some(entity) = self.terminals.get(&terminal_id) {
1090 entity.update(cx, |term, cx| {
1091 term.inner().update(cx, |inner, cx| {
1092 inner.breadcrumb_text = title;
1093 cx.emit(::terminal::Event::BreadcrumbsChanged);
1094 })
1095 });
1096 }
1097 }
1098 TerminalProviderEvent::Exit {
1099 terminal_id,
1100 status,
1101 } => {
1102 if let Some(entity) = self.terminals.get(&terminal_id) {
1103 entity.update(cx, |_term, cx| {
1104 cx.notify();
1105 });
1106 } else {
1107 self.pending_terminal_exit.insert(terminal_id, status);
1108 }
1109 }
1110 }
1111 }
1112}
1113
/// Whether a thread is currently generating, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running.
    Generating,
}
1119
/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent's version is older than the minimum supported one.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; carries a message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure; carries a message.
    Other(SharedString),
}
1133
1134impl Display for LoadError {
1135 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1136 match self {
1137 LoadError::Unsupported {
1138 command: path,
1139 current_version,
1140 minimum_version,
1141 } => {
1142 write!(
1143 f,
1144 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1145 )
1146 }
1147 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1148 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1149 LoadError::Other(msg) => write!(f, "{msg}"),
1150 }
1151 }
1152}
1153
1154impl Error for LoadError {}
1155
1156impl AcpThread {
1157 pub fn new(
1158 parent_session_id: Option<acp::SessionId>,
1159 title: impl Into<SharedString>,
1160 connection: Rc<dyn AgentConnection>,
1161 project: Entity<Project>,
1162 action_log: Entity<ActionLog>,
1163 session_id: acp::SessionId,
1164 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1165 cx: &mut Context<Self>,
1166 ) -> Self {
1167 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1168 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1169 loop {
1170 let caps = prompt_capabilities_rx.recv().await?;
1171 this.update(cx, |this, cx| {
1172 this.prompt_capabilities = caps;
1173 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1174 })?;
1175 }
1176 });
1177
1178 Self {
1179 parent_session_id,
1180 action_log,
1181 shared_buffers: Default::default(),
1182 entries: Default::default(),
1183 plan: Default::default(),
1184 title: title.into(),
1185 project,
1186 running_turn: None,
1187 turn_id: 0,
1188 connection,
1189 session_id,
1190 token_usage: None,
1191 prompt_capabilities,
1192 _observe_prompt_capabilities: task,
1193 terminals: HashMap::default(),
1194 pending_terminal_output: HashMap::default(),
1195 pending_terminal_exit: HashMap::default(),
1196 }
1197 }
1198
1199 pub fn parent_session_id(&self) -> Option<&acp::SessionId> {
1200 self.parent_session_id.as_ref()
1201 }
1202
1203 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1204 self.prompt_capabilities.clone()
1205 }
1206
1207 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1208 &self.connection
1209 }
1210
1211 pub fn action_log(&self) -> &Entity<ActionLog> {
1212 &self.action_log
1213 }
1214
1215 pub fn project(&self) -> &Entity<Project> {
1216 &self.project
1217 }
1218
1219 pub fn title(&self) -> SharedString {
1220 self.title.clone()
1221 }
1222
1223 pub fn entries(&self) -> &[AgentThreadEntry] {
1224 &self.entries
1225 }
1226
1227 pub fn session_id(&self) -> &acp::SessionId {
1228 &self.session_id
1229 }
1230
1231 pub fn status(&self) -> ThreadStatus {
1232 if self.running_turn.is_some() {
1233 ThreadStatus::Generating
1234 } else {
1235 ThreadStatus::Idle
1236 }
1237 }
1238
1239 pub fn token_usage(&self) -> Option<&TokenUsage> {
1240 self.token_usage.as_ref()
1241 }
1242
1243 pub fn has_pending_edit_tool_calls(&self) -> bool {
1244 for entry in self.entries.iter().rev() {
1245 match entry {
1246 AgentThreadEntry::UserMessage(_) => return false,
1247 AgentThreadEntry::ToolCall(
1248 call @ ToolCall {
1249 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1250 ..
1251 },
1252 ) if call.diffs().next().is_some() => {
1253 return true;
1254 }
1255 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1256 }
1257 }
1258
1259 false
1260 }
1261
1262 pub fn has_in_progress_tool_calls(&self) -> bool {
1263 for entry in self.entries.iter().rev() {
1264 match entry {
1265 AgentThreadEntry::UserMessage(_) => return false,
1266 AgentThreadEntry::ToolCall(ToolCall {
1267 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1268 ..
1269 }) => {
1270 return true;
1271 }
1272 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1273 }
1274 }
1275
1276 false
1277 }
1278
1279 pub fn used_tools_since_last_user_message(&self) -> bool {
1280 for entry in self.entries.iter().rev() {
1281 match entry {
1282 AgentThreadEntry::UserMessage(..) => return false,
1283 AgentThreadEntry::AssistantMessage(..) => continue,
1284 AgentThreadEntry::ToolCall(..) => return true,
1285 }
1286 }
1287
1288 false
1289 }
1290
1291 pub fn handle_session_update(
1292 &mut self,
1293 update: acp::SessionUpdate,
1294 cx: &mut Context<Self>,
1295 ) -> Result<(), acp::Error> {
1296 match update {
1297 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1298 self.push_user_content_block(None, content, cx);
1299 }
1300 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1301 self.push_assistant_content_block(content, false, cx);
1302 }
1303 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1304 self.push_assistant_content_block(content, true, cx);
1305 }
1306 acp::SessionUpdate::ToolCall(tool_call) => {
1307 self.upsert_tool_call(tool_call, cx)?;
1308 }
1309 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1310 self.update_tool_call(tool_call_update, cx)?;
1311 }
1312 acp::SessionUpdate::Plan(plan) => {
1313 self.update_plan(plan, cx);
1314 }
1315 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1316 available_commands,
1317 ..
1318 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1319 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1320 current_mode_id,
1321 ..
1322 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1323 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1324 config_options,
1325 ..
1326 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1327 _ => {}
1328 }
1329 Ok(())
1330 }
1331
1332 pub fn push_user_content_block(
1333 &mut self,
1334 message_id: Option<UserMessageId>,
1335 chunk: acp::ContentBlock,
1336 cx: &mut Context<Self>,
1337 ) {
1338 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1339 }
1340
1341 pub fn push_user_content_block_with_indent(
1342 &mut self,
1343 message_id: Option<UserMessageId>,
1344 chunk: acp::ContentBlock,
1345 indented: bool,
1346 cx: &mut Context<Self>,
1347 ) {
1348 let language_registry = self.project.read(cx).languages().clone();
1349 let path_style = self.project.read(cx).path_style(cx);
1350 let entries_len = self.entries.len();
1351
1352 if let Some(last_entry) = self.entries.last_mut()
1353 && let AgentThreadEntry::UserMessage(UserMessage {
1354 id,
1355 content,
1356 chunks,
1357 indented: existing_indented,
1358 ..
1359 }) = last_entry
1360 && *existing_indented == indented
1361 {
1362 *id = message_id.or(id.take());
1363 content.append(chunk.clone(), &language_registry, path_style, cx);
1364 chunks.push(chunk);
1365 let idx = entries_len - 1;
1366 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1367 } else {
1368 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1369 self.push_entry(
1370 AgentThreadEntry::UserMessage(UserMessage {
1371 id: message_id,
1372 content,
1373 chunks: vec![chunk],
1374 checkpoint: None,
1375 indented,
1376 }),
1377 cx,
1378 );
1379 }
1380 }
1381
1382 pub fn push_assistant_content_block(
1383 &mut self,
1384 chunk: acp::ContentBlock,
1385 is_thought: bool,
1386 cx: &mut Context<Self>,
1387 ) {
1388 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1389 }
1390
1391 pub fn push_assistant_content_block_with_indent(
1392 &mut self,
1393 chunk: acp::ContentBlock,
1394 is_thought: bool,
1395 indented: bool,
1396 cx: &mut Context<Self>,
1397 ) {
1398 let language_registry = self.project.read(cx).languages().clone();
1399 let path_style = self.project.read(cx).path_style(cx);
1400 let entries_len = self.entries.len();
1401 if let Some(last_entry) = self.entries.last_mut()
1402 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1403 chunks,
1404 indented: existing_indented,
1405 }) = last_entry
1406 && *existing_indented == indented
1407 {
1408 let idx = entries_len - 1;
1409 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1410 match (chunks.last_mut(), is_thought) {
1411 (Some(AssistantMessageChunk::Message { block }), false)
1412 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1413 block.append(chunk, &language_registry, path_style, cx)
1414 }
1415 _ => {
1416 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1417 if is_thought {
1418 chunks.push(AssistantMessageChunk::Thought { block })
1419 } else {
1420 chunks.push(AssistantMessageChunk::Message { block })
1421 }
1422 }
1423 }
1424 } else {
1425 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1426 let chunk = if is_thought {
1427 AssistantMessageChunk::Thought { block }
1428 } else {
1429 AssistantMessageChunk::Message { block }
1430 };
1431
1432 self.push_entry(
1433 AgentThreadEntry::AssistantMessage(AssistantMessage {
1434 chunks: vec![chunk],
1435 indented,
1436 }),
1437 cx,
1438 );
1439 }
1440 }
1441
1442 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1443 self.entries.push(entry);
1444 cx.emit(AcpThreadEvent::NewEntry);
1445 }
1446
1447 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1448 self.connection.set_title(&self.session_id, cx).is_some()
1449 }
1450
1451 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1452 if title != self.title {
1453 self.title = title.clone();
1454 cx.emit(AcpThreadEvent::TitleUpdated);
1455 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1456 return set_title.run(title, cx);
1457 }
1458 }
1459 Task::ready(Ok(()))
1460 }
1461
1462 pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
1463 cx.emit(AcpThreadEvent::SubagentSpawned(session_id));
1464 }
1465
1466 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1467 self.token_usage = usage;
1468 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1469 }
1470
1471 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1472 cx.emit(AcpThreadEvent::Retry(status));
1473 }
1474
1475 pub fn update_tool_call(
1476 &mut self,
1477 update: impl Into<ToolCallUpdate>,
1478 cx: &mut Context<Self>,
1479 ) -> Result<()> {
1480 let update = update.into();
1481 let languages = self.project.read(cx).languages().clone();
1482 let path_style = self.project.read(cx).path_style(cx);
1483
1484 let ix = match self.index_for_tool_call(update.id()) {
1485 Some(ix) => ix,
1486 None => {
1487 // Tool call not found - create a failed tool call entry
1488 let failed_tool_call = ToolCall {
1489 id: update.id().clone(),
1490 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1491 kind: acp::ToolKind::Fetch,
1492 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1493 "Tool call not found".into(),
1494 &languages,
1495 path_style,
1496 cx,
1497 ))],
1498 status: ToolCallStatus::Failed,
1499 locations: Vec::new(),
1500 resolved_locations: Vec::new(),
1501 raw_input: None,
1502 raw_input_markdown: None,
1503 raw_output: None,
1504 tool_name: None,
1505 subagent_session_id: None,
1506 };
1507 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1508 return Ok(());
1509 }
1510 };
1511 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1512 unreachable!()
1513 };
1514
1515 match update {
1516 ToolCallUpdate::UpdateFields(update) => {
1517 let location_updated = update.fields.locations.is_some();
1518 call.update_fields(
1519 update.fields,
1520 update.meta,
1521 languages,
1522 path_style,
1523 &self.terminals,
1524 cx,
1525 )?;
1526 if location_updated {
1527 self.resolve_locations(update.tool_call_id, cx);
1528 }
1529 }
1530 ToolCallUpdate::UpdateDiff(update) => {
1531 call.content.clear();
1532 call.content.push(ToolCallContent::Diff(update.diff));
1533 }
1534 ToolCallUpdate::UpdateTerminal(update) => {
1535 call.content.clear();
1536 call.content
1537 .push(ToolCallContent::Terminal(update.terminal));
1538 }
1539 }
1540
1541 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1542
1543 Ok(())
1544 }
1545
1546 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1547 pub fn upsert_tool_call(
1548 &mut self,
1549 tool_call: acp::ToolCall,
1550 cx: &mut Context<Self>,
1551 ) -> Result<(), acp::Error> {
1552 let status = tool_call.status.into();
1553 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1554 }
1555
1556 /// Fails if id does not match an existing entry.
1557 pub fn upsert_tool_call_inner(
1558 &mut self,
1559 update: acp::ToolCallUpdate,
1560 status: ToolCallStatus,
1561 cx: &mut Context<Self>,
1562 ) -> Result<(), acp::Error> {
1563 let language_registry = self.project.read(cx).languages().clone();
1564 let path_style = self.project.read(cx).path_style(cx);
1565 let id = update.tool_call_id.clone();
1566
1567 let agent_telemetry_id = self.connection().telemetry_id();
1568 let session = self.session_id();
1569 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1570 let status = if matches!(status, ToolCallStatus::Completed) {
1571 "completed"
1572 } else {
1573 "failed"
1574 };
1575 telemetry::event!(
1576 "Agent Tool Call Completed",
1577 agent_telemetry_id,
1578 session,
1579 status
1580 );
1581 }
1582
1583 if let Some(ix) = self.index_for_tool_call(&id) {
1584 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1585 unreachable!()
1586 };
1587
1588 call.update_fields(
1589 update.fields,
1590 update.meta,
1591 language_registry,
1592 path_style,
1593 &self.terminals,
1594 cx,
1595 )?;
1596 call.status = status;
1597
1598 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1599 } else {
1600 let call = ToolCall::from_acp(
1601 update.try_into()?,
1602 status,
1603 language_registry,
1604 self.project.read(cx).path_style(cx),
1605 &self.terminals,
1606 cx,
1607 )?;
1608 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1609 };
1610
1611 self.resolve_locations(id, cx);
1612 Ok(())
1613 }
1614
1615 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1616 self.entries
1617 .iter()
1618 .enumerate()
1619 .rev()
1620 .find_map(|(index, entry)| {
1621 if let AgentThreadEntry::ToolCall(tool_call) = entry
1622 && &tool_call.id == id
1623 {
1624 Some(index)
1625 } else {
1626 None
1627 }
1628 })
1629 }
1630
1631 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1632 // The tool call we are looking for is typically the last one, or very close to the end.
1633 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1634 self.entries
1635 .iter_mut()
1636 .enumerate()
1637 .rev()
1638 .find_map(|(index, tool_call)| {
1639 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1640 && &tool_call.id == id
1641 {
1642 Some((index, tool_call))
1643 } else {
1644 None
1645 }
1646 })
1647 }
1648
1649 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1650 self.entries
1651 .iter()
1652 .enumerate()
1653 .rev()
1654 .find_map(|(index, tool_call)| {
1655 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1656 && &tool_call.id == id
1657 {
1658 Some((index, tool_call))
1659 } else {
1660 None
1661 }
1662 })
1663 }
1664
    /// Resolves the locations of the tool call with `id` on a detached background task.
    ///
    /// Does nothing if no such tool call exists. Once resolution finishes:
    /// 1. Every resolved buffer's current snapshot is cached in `shared_buffers`.
    /// 2. If the tool call still exists and the last location resolved, the project's
    ///    agent location moves there. The move is skipped when the current agent location
    ///    is in the same buffer on the same row at a larger column.
    /// 3. The tool call's `resolved_locations` are replaced, and
    ///    [`AcpThreadEvent::EntryUpdated`] is emitted only when they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache snapshots of every resolved buffer before looking the tool call up
                // again; it may have been removed while resolution was running.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1721
1722 pub fn request_tool_call_authorization(
1723 &mut self,
1724 tool_call: acp::ToolCallUpdate,
1725 options: PermissionOptions,
1726 cx: &mut Context<Self>,
1727 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1728 let (tx, rx) = oneshot::channel();
1729
1730 let status = ToolCallStatus::WaitingForConfirmation {
1731 options,
1732 respond_tx: tx,
1733 };
1734
1735 self.upsert_tool_call_inner(tool_call, status, cx)?;
1736 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1737
1738 let fut = async {
1739 match rx.await {
1740 Ok(option) => acp::RequestPermissionOutcome::Selected(
1741 acp::SelectedPermissionOutcome::new(option),
1742 ),
1743 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1744 }
1745 }
1746 .boxed();
1747
1748 Ok(fut)
1749 }
1750
1751 pub fn authorize_tool_call(
1752 &mut self,
1753 id: acp::ToolCallId,
1754 option_id: acp::PermissionOptionId,
1755 option_kind: acp::PermissionOptionKind,
1756 cx: &mut Context<Self>,
1757 ) {
1758 let Some((ix, call)) = self.tool_call_mut(&id) else {
1759 return;
1760 };
1761
1762 let new_status = match option_kind {
1763 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1764 ToolCallStatus::Rejected
1765 }
1766 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1767 ToolCallStatus::InProgress
1768 }
1769 _ => ToolCallStatus::InProgress,
1770 };
1771
1772 let curr_status = mem::replace(&mut call.status, new_status);
1773
1774 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1775 respond_tx.send(option_id).log_err();
1776 } else if cfg!(debug_assertions) {
1777 panic!("tried to authorize an already authorized tool call");
1778 }
1779
1780 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1781 }
1782
1783 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1784 let mut first_tool_call = None;
1785
1786 for entry in self.entries.iter().rev() {
1787 match &entry {
1788 AgentThreadEntry::ToolCall(call) => {
1789 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1790 first_tool_call = Some(call);
1791 } else {
1792 continue;
1793 }
1794 }
1795 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1796 // Reached the beginning of the turn.
1797 // If we had pending permission requests in the previous turn, they have been cancelled.
1798 break;
1799 }
1800 }
1801 }
1802
1803 first_tool_call
1804 }
1805
1806 pub fn plan(&self) -> &Plan {
1807 &self.plan
1808 }
1809
1810 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1811 let new_entries_len = request.entries.len();
1812 let mut new_entries = request.entries.into_iter();
1813
1814 // Reuse existing markdown to prevent flickering
1815 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1816 let PlanEntry {
1817 content,
1818 priority,
1819 status,
1820 } = old;
1821 content.update(cx, |old, cx| {
1822 old.replace(new.content, cx);
1823 });
1824 *priority = new.priority;
1825 *status = new.status;
1826 }
1827 for new in new_entries {
1828 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1829 }
1830 self.plan.entries.truncate(new_entries_len);
1831
1832 cx.notify();
1833 }
1834
1835 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1836 self.plan
1837 .entries
1838 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1839 cx.notify();
1840 }
1841
1842 #[cfg(any(test, feature = "test-support"))]
1843 pub fn send_raw(
1844 &mut self,
1845 message: &str,
1846 cx: &mut Context<Self>,
1847 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1848 self.send(vec![message.into()], cx)
1849 }
1850
1851 pub fn send(
1852 &mut self,
1853 message: Vec<acp::ContentBlock>,
1854 cx: &mut Context<Self>,
1855 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1856 let block = ContentBlock::new_combined(
1857 message.clone(),
1858 self.project.read(cx).languages().clone(),
1859 self.project.read(cx).path_style(cx),
1860 cx,
1861 );
1862 let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
1863 let git_store = self.project.read(cx).git_store().clone();
1864
1865 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1866 Some(UserMessageId::new())
1867 } else {
1868 None
1869 };
1870
1871 self.run_turn(cx, async move |this, cx| {
1872 this.update(cx, |this, cx| {
1873 this.push_entry(
1874 AgentThreadEntry::UserMessage(UserMessage {
1875 id: message_id.clone(),
1876 content: block,
1877 chunks: message,
1878 checkpoint: None,
1879 indented: false,
1880 }),
1881 cx,
1882 );
1883 })
1884 .ok();
1885
1886 let old_checkpoint = git_store
1887 .update(cx, |git, cx| git.checkpoint(cx))
1888 .await
1889 .context("failed to get old checkpoint")
1890 .log_err();
1891 this.update(cx, |this, cx| {
1892 if let Some((_ix, message)) = this.last_user_message() {
1893 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1894 git_checkpoint,
1895 show: false,
1896 });
1897 }
1898 this.connection.prompt(message_id, request, cx)
1899 })?
1900 .await
1901 })
1902 }
1903
1904 pub fn can_retry(&self, cx: &App) -> bool {
1905 self.connection.retry(&self.session_id, cx).is_some()
1906 }
1907
1908 pub fn retry(
1909 &mut self,
1910 cx: &mut Context<Self>,
1911 ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
1912 self.run_turn(cx, async move |this, cx| {
1913 this.update(cx, |this, cx| {
1914 this.connection
1915 .retry(&this.session_id, cx)
1916 .map(|retry| retry.run(cx))
1917 })?
1918 .context("retrying a session is not supported")?
1919 .await
1920 })
1921 }
1922
    /// Starts a new turn whose work is performed by `f` and returns a future for its outcome.
    ///
    /// Completed plan entries are cleared and any running turn is cancelled. `f` only
    /// starts once that cancellation task has finished. The new turn is stored in
    /// `running_turn` under a fresh `turn_id`.
    ///
    /// After the turn's task sends its result (or is dropped), the last user message's
    /// checkpoint is refreshed and the agent location is cleared. Errors from refreshing
    /// the checkpoint are propagated. Then:
    /// - if no result was sent, the future resolves to `Ok(None)`;
    /// - `running_turn` is cleared only if it still refers to this turn;
    /// - a `MaxTokens` stop emits [`AcpThreadEvent::Error`] and resolves to an error;
    /// - a `Cancelled` stop marks still-pending tool calls as canceled;
    /// - a `Refusal` stop emits [`AcpThreadEvent::Refusal`]. Unless a completed tool call
    ///   with raw output follows the last user message, that message and everything after
    ///   it are also removed;
    /// - every other successful response, including `Cancelled` and `Refusal`, emits
    ///   [`AcpThreadEvent::Stopped`] and resolves to `Ok(Some(response))`;
    /// - an error from `f` emits [`AcpThreadEvent::Error`] and is returned.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.turn_id += 1;
        let turn_id = self.turn_id;
        self.running_turn = Some(RunningTurn {
            id: turn_id,
            send_task: cx.spawn(async move |this, cx| {
                // Let the previous turn finish winding down before this one starts.
                cancel_task.await;
                tx.send(f(this, cx).await).ok();
            }),
        });

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));

                let Ok(response) = response else {
                    // tx dropped, just return
                    return Ok(None);
                };

                let is_same_turn = this
                    .running_turn
                    .as_ref()
                    .is_some_and(|turn| turn_id == turn.id);

                // If the user submitted a follow up message, running_turn might
                // already point to a different turn. Therefore we only want to
                // take the task if it's the same turn.
                if is_same_turn {
                    this.running_turn.take();
                }

                match response {
                    Ok(r) => {
                        if r.stop_reason == acp::StopReason::MaxTokens {
                            cx.emit(AcpThreadEvent::Error);
                            log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
                            return Err(anyhow!("Max tokens reached"));
                        }

                        let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
                        if canceled {
                            this.mark_pending_tools_as_canceled();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let acp::StopReason::Refusal = r.stop_reason {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(Some(r))
                    }
                    Err(e) => {
                        cx.emit(AcpThreadEvent::Error);
                        log::error!("Error in run turn: {:?}", e);
                        Err(e)
                    }
                }
            })?
        })
        .boxed()
    }
2031
2032 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2033 let Some(turn) = self.running_turn.take() else {
2034 return Task::ready(());
2035 };
2036 self.connection.cancel(&self.session_id, cx);
2037
2038 self.mark_pending_tools_as_canceled();
2039
2040 // Wait for the send task to complete
2041 cx.background_spawn(turn.send_task)
2042 }
2043
2044 fn mark_pending_tools_as_canceled(&mut self) {
2045 for entry in self.entries.iter_mut() {
2046 if let AgentThreadEntry::ToolCall(call) = entry {
2047 let cancel = matches!(
2048 call.status,
2049 ToolCallStatus::Pending
2050 | ToolCallStatus::WaitingForConfirmation { .. }
2051 | ToolCallStatus::InProgress
2052 );
2053
2054 if cancel {
2055 call.status = ToolCallStatus::Canceled;
2056 }
2057 }
2058 }
2059 }
2060
2061 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2062 pub fn restore_checkpoint(
2063 &mut self,
2064 id: UserMessageId,
2065 cx: &mut Context<Self>,
2066 ) -> Task<Result<()>> {
2067 let Some((_, message)) = self.user_message_mut(&id) else {
2068 return Task::ready(Err(anyhow!("message not found")));
2069 };
2070
2071 let checkpoint = message
2072 .checkpoint
2073 .as_ref()
2074 .map(|c| c.git_checkpoint.clone());
2075
2076 // Cancel any in-progress generation before restoring
2077 let cancel_task = self.cancel(cx);
2078 let rewind = self.rewind(id.clone(), cx);
2079 let git_store = self.project.read(cx).git_store().clone();
2080
2081 cx.spawn(async move |_, cx| {
2082 cancel_task.await;
2083 rewind.await?;
2084 if let Some(checkpoint) = checkpoint {
2085 git_store
2086 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2087 .await?;
2088 }
2089
2090 Ok(())
2091 })
2092 }
2093
2094 /// Rewinds this thread to before the entry at `index`, removing it and all
2095 /// subsequent entries while rejecting any action_log changes made from that point.
2096 /// Unlike `restore_checkpoint`, this method does not restore from git.
2097 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2098 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2099 return Task::ready(Err(anyhow!("not supported")));
2100 };
2101
2102 let telemetry = ActionLogTelemetry::from(&*self);
2103 cx.spawn(async move |this, cx| {
2104 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2105 this.update(cx, |this, cx| {
2106 if let Some((ix, _)) = this.user_message_mut(&id) {
2107 // Collect all terminals from entries that will be removed
2108 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2109 .iter()
2110 .flat_map(|entry| entry.terminals())
2111 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2112 .collect();
2113
2114 let range = ix..this.entries.len();
2115 this.entries.truncate(ix);
2116 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2117
2118 // Kill and remove the terminals
2119 for terminal_id in terminals_to_remove {
2120 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2121 terminal.update(cx, |terminal, cx| {
2122 terminal.kill(cx);
2123 });
2124 }
2125 }
2126 }
2127 this.action_log().update(cx, |action_log, cx| {
2128 action_log.reject_all_edits(Some(telemetry), cx)
2129 })
2130 })?
2131 .await;
2132 Ok(())
2133 })
2134 }
2135
2136 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2137 let git_store = self.project.read(cx).git_store().clone();
2138
2139 let Some((_, message)) = self.last_user_message() else {
2140 return Task::ready(Ok(()));
2141 };
2142 let Some(user_message_id) = message.id.clone() else {
2143 return Task::ready(Ok(()));
2144 };
2145 let Some(checkpoint) = message.checkpoint.as_ref() else {
2146 return Task::ready(Ok(()));
2147 };
2148 let old_checkpoint = checkpoint.git_checkpoint.clone();
2149
2150 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2151 cx.spawn(async move |this, cx| {
2152 let Some(new_checkpoint) = new_checkpoint
2153 .await
2154 .context("failed to get new checkpoint")
2155 .log_err()
2156 else {
2157 return Ok(());
2158 };
2159
2160 let equal = git_store
2161 .update(cx, |git, cx| {
2162 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2163 })
2164 .await
2165 .unwrap_or(true);
2166
2167 this.update(cx, |this, cx| {
2168 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2169 if let Some(checkpoint) = message.checkpoint.as_mut() {
2170 checkpoint.show = !equal;
2171 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2172 }
2173 }
2174 })?;
2175
2176 Ok(())
2177 })
2178 }
2179
2180 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2181 self.entries
2182 .iter_mut()
2183 .enumerate()
2184 .rev()
2185 .find_map(|(ix, entry)| {
2186 if let AgentThreadEntry::UserMessage(message) = entry {
2187 Some((ix, message))
2188 } else {
2189 None
2190 }
2191 })
2192 }
2193
2194 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2195 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2196 if let AgentThreadEntry::UserMessage(message) = entry {
2197 if message.id.as_ref() == Some(id) {
2198 Some((ix, message))
2199 } else {
2200 None
2201 }
2202 } else {
2203 None
2204 }
2205 })
2206 }
2207
    /// Reads lines from the file at the absolute `path` through its project buffer.
    ///
    /// - `line` is 1-based; `None` or `0` both mean the first line.
    /// - `limit` caps the number of lines returned and defaults to `u32::MAX`.
    ///
    /// If `reuse_shared_snapshot` is true and a snapshot of the buffer is cached in
    /// `shared_buffers`, that snapshot is read. Otherwise the read is recorded with the
    /// action log, and the buffer's current snapshot is cached and used. The project's
    /// agent location is moved to the start of the requested range.
    ///
    /// # Errors
    /// - `resource_not_found` if `path` does not map to a project path.
    /// - `invalid_params` if `line` is past the end of the file.
    /// - Any error from opening the buffer or updating the thread.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .ok_or_else(|| {
                        acp::Error::resource_not_found(Some(path.display().to_string()))
                    })?;
                Ok::<_, acp::Error>(project.open_buffer(path, cx))
            })?;

            let buffer = load.await?;

            // Optionally reuse the snapshot cached by an earlier read/resolve; if the thread
            // is gone the error is logged and we fall back to a fresh snapshot.
            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            // `end` is the start of row `line + limit`, so at most `limit` lines are
            // returned; `saturating_add` keeps the default `u32::MAX` limit from overflowing.
            let start = snapshot.anchor_before(start_position);
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            });

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2284
    /// Replaces the contents of the project buffer at `path` with `content`, then saves it.
    ///
    /// The new text is not written wholesale. It is diffed, on the background executor, against
    /// the snapshot recorded for this buffer in `shared_buffers`. If no snapshot is recorded,
    /// the buffer's current snapshot is used. Each resulting hunk is anchored in that snapshot,
    /// so edits made to the buffer after the snapshot was taken (for example by the user) are
    /// preserved rather than overwritten.
    ///
    /// Side effects, in order:
    /// 1. The project's agent location moves to the end of the last hunk, or to
    ///    `Anchor::min_for_buffer` when there are no hunks.
    /// 2. The action log records a read, the hunks are applied, and the action log records an
    ///    edit.
    /// 3. If the buffer's language settings have `format_on_save` enabled, the project formatter
    ///    runs with `FormatTrigger::Save`. Formatting failures are logged and ignored, and a
    ///    further edit is recorded.
    /// 4. The buffer is saved.
    ///
    /// Returns an error if:
    /// - `path` does not resolve to a project path ("invalid path");
    /// - the buffer cannot be opened or saved;
    /// - this thread has been released.
    ///
    /// NOTE(review): `shared_buffers` is read but never updated here. Unless something else
    /// refreshes it, a second write without an intervening `read_text_file` diffs against the
    /// pre-write snapshot. Confirm this is intended.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against the snapshot last shared with the agent, if any, so that the
            // hunks describe the agent's change relative to what it actually read.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    // The start is anchored after and the end before each hunk boundary.
                    // Presumably this keeps text inserted exactly at a boundary (after the
                    // snapshot was taken) out of the replaced range.
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // Apply the edits, bracketed by action-log read/edit notifications, and report
            // whether the buffer's language settings ask for formatting on save.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: a failure is logged but does not block the save.
                format_task.await.log_err();

                // Record the formatter's changes as part of the agent's edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2381
2382 pub fn create_terminal(
2383 &self,
2384 command: String,
2385 args: Vec<String>,
2386 extra_env: Vec<acp::EnvVariable>,
2387 cwd: Option<PathBuf>,
2388 output_byte_limit: Option<u64>,
2389 cx: &mut Context<Self>,
2390 ) -> Task<Result<Entity<Terminal>>> {
2391 let env = match &cwd {
2392 Some(dir) => self.project.update(cx, |project, cx| {
2393 project.environment().update(cx, |env, cx| {
2394 env.directory_environment(dir.as_path().into(), cx)
2395 })
2396 }),
2397 None => Task::ready(None).shared(),
2398 };
2399 let env = cx.spawn(async move |_, _| {
2400 let mut env = env.await.unwrap_or_default();
2401 // Disables paging for `git` and hopefully other commands
2402 env.insert("PAGER".into(), "".into());
2403 for var in extra_env {
2404 env.insert(var.name, var.value);
2405 }
2406 env
2407 });
2408
2409 let project = self.project.clone();
2410 let language_registry = project.read(cx).languages().clone();
2411 let is_windows = project.read(cx).path_style(cx).is_windows();
2412
2413 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2414 let terminal_task = cx.spawn({
2415 let terminal_id = terminal_id.clone();
2416 async move |_this, cx| {
2417 let env = env.await;
2418 let shell = project
2419 .update(cx, |project, cx| {
2420 project
2421 .remote_client()
2422 .and_then(|r| r.read(cx).default_system_shell())
2423 })
2424 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2425 let (task_command, task_args) =
2426 ShellBuilder::new(&Shell::Program(shell), is_windows)
2427 .redirect_stdin_to_dev_null()
2428 .build(Some(command.clone()), &args);
2429 let terminal = project
2430 .update(cx, |project, cx| {
2431 project.create_terminal_task(
2432 task::SpawnInTerminal {
2433 command: Some(task_command),
2434 args: task_args,
2435 cwd: cwd.clone(),
2436 env,
2437 ..Default::default()
2438 },
2439 cx,
2440 )
2441 })
2442 .await?;
2443
2444 anyhow::Ok(cx.new(|cx| {
2445 Terminal::new(
2446 terminal_id,
2447 &format!("{} {}", command, args.join(" ")),
2448 cwd,
2449 output_byte_limit.map(|l| l as usize),
2450 terminal,
2451 language_registry,
2452 cx,
2453 )
2454 }))
2455 }
2456 });
2457
2458 cx.spawn(async move |this, cx| {
2459 let terminal = terminal_task.await?;
2460 this.update(cx, |this, _cx| {
2461 this.terminals.insert(terminal_id, terminal.clone());
2462 terminal
2463 })
2464 })
2465 }
2466
2467 pub fn kill_terminal(
2468 &mut self,
2469 terminal_id: acp::TerminalId,
2470 cx: &mut Context<Self>,
2471 ) -> Result<()> {
2472 self.terminals
2473 .get(&terminal_id)
2474 .context("Terminal not found")?
2475 .update(cx, |terminal, cx| {
2476 terminal.kill(cx);
2477 });
2478
2479 Ok(())
2480 }
2481
2482 pub fn release_terminal(
2483 &mut self,
2484 terminal_id: acp::TerminalId,
2485 cx: &mut Context<Self>,
2486 ) -> Result<()> {
2487 self.terminals
2488 .remove(&terminal_id)
2489 .context("Terminal not found")?
2490 .update(cx, |terminal, cx| {
2491 terminal.kill(cx);
2492 });
2493
2494 Ok(())
2495 }
2496
2497 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2498 self.terminals
2499 .get(&terminal_id)
2500 .context("Terminal not found")
2501 .cloned()
2502 }
2503
2504 pub fn to_markdown(&self, cx: &App) -> String {
2505 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2506 }
2507
2508 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2509 cx.emit(AcpThreadEvent::LoadError(error));
2510 }
2511
2512 pub fn register_terminal_created(
2513 &mut self,
2514 terminal_id: acp::TerminalId,
2515 command_label: String,
2516 working_dir: Option<PathBuf>,
2517 output_byte_limit: Option<u64>,
2518 terminal: Entity<::terminal::Terminal>,
2519 cx: &mut Context<Self>,
2520 ) -> Entity<Terminal> {
2521 let language_registry = self.project.read(cx).languages().clone();
2522
2523 let entity = cx.new(|cx| {
2524 Terminal::new(
2525 terminal_id.clone(),
2526 &command_label,
2527 working_dir.clone(),
2528 output_byte_limit.map(|l| l as usize),
2529 terminal,
2530 language_registry,
2531 cx,
2532 )
2533 });
2534 self.terminals.insert(terminal_id.clone(), entity.clone());
2535 entity
2536 }
2537}
2538
2539fn markdown_for_raw_output(
2540 raw_output: &serde_json::Value,
2541 language_registry: &Arc<LanguageRegistry>,
2542 cx: &mut App,
2543) -> Option<Entity<Markdown>> {
2544 match raw_output {
2545 serde_json::Value::Null => None,
2546 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2547 Markdown::new(
2548 value.to_string().into(),
2549 Some(language_registry.clone()),
2550 None,
2551 cx,
2552 )
2553 })),
2554 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2555 Markdown::new(
2556 value.to_string().into(),
2557 Some(language_registry.clone()),
2558 None,
2559 cx,
2560 )
2561 })),
2562 serde_json::Value::String(value) => Some(cx.new(|cx| {
2563 Markdown::new(
2564 value.clone().into(),
2565 Some(language_registry.clone()),
2566 None,
2567 cx,
2568 )
2569 })),
2570 value => Some(cx.new(|cx| {
2571 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2572
2573 Markdown::new(
2574 format!("```json\n{}\n```", pretty_json).into(),
2575 Some(language_registry.clone()),
2576 None,
2577 cx,
2578 )
2579 })),
2580 }
2581}
2582
2583#[cfg(test)]
2584mod tests {
2585 use super::*;
2586 use anyhow::anyhow;
2587 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2588 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2589 use indoc::indoc;
2590 use project::{FakeFs, Fs};
2591 use rand::{distr, prelude::*};
2592 use serde_json::json;
2593 use settings::SettingsStore;
2594 use smol::stream::StreamExt as _;
2595 use std::{
2596 any::Any,
2597 cell::RefCell,
2598 path::Path,
2599 rc::Rc,
2600 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2601 time::Duration,
2602 };
2603 use util::path;
2604
2605 fn init_test(cx: &mut TestAppContext) {
2606 env_logger::try_init().ok();
2607 cx.update(|cx| {
2608 let settings_store = SettingsStore::test(cx);
2609 cx.set_global(settings_store);
2610 });
2611 }
2612
2613 #[gpui::test]
2614 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2615 init_test(cx);
2616
2617 let fs = FakeFs::new(cx.executor());
2618 let project = Project::test(fs, [], cx).await;
2619 let connection = Rc::new(FakeAgentConnection::new());
2620 let thread = cx
2621 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2622 .await
2623 .unwrap();
2624
2625 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2626
2627 // Send Output BEFORE Created - should be buffered by acp_thread
2628 thread.update(cx, |thread, cx| {
2629 thread.on_terminal_provider_event(
2630 TerminalProviderEvent::Output {
2631 terminal_id: terminal_id.clone(),
2632 data: b"hello buffered".to_vec(),
2633 },
2634 cx,
2635 );
2636 });
2637
2638 // Create a display-only terminal and then send Created
2639 let lower = cx.new(|cx| {
2640 let builder = ::terminal::TerminalBuilder::new_display_only(
2641 ::terminal::terminal_settings::CursorShape::default(),
2642 ::terminal::terminal_settings::AlternateScroll::On,
2643 None,
2644 0,
2645 cx.background_executor(),
2646 PathStyle::local(),
2647 )
2648 .unwrap();
2649 builder.subscribe(cx)
2650 });
2651
2652 thread.update(cx, |thread, cx| {
2653 thread.on_terminal_provider_event(
2654 TerminalProviderEvent::Created {
2655 terminal_id: terminal_id.clone(),
2656 label: "Buffered Test".to_string(),
2657 cwd: None,
2658 output_byte_limit: None,
2659 terminal: lower.clone(),
2660 },
2661 cx,
2662 );
2663 });
2664
2665 // After Created, buffered Output should have been flushed into the renderer
2666 let content = thread.read_with(cx, |thread, cx| {
2667 let term = thread.terminal(terminal_id.clone()).unwrap();
2668 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2669 });
2670
2671 assert!(
2672 content.contains("hello buffered"),
2673 "expected buffered output to render, got: {content}"
2674 );
2675 }
2676
2677 #[gpui::test]
2678 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2679 init_test(cx);
2680
2681 let fs = FakeFs::new(cx.executor());
2682 let project = Project::test(fs, [], cx).await;
2683 let connection = Rc::new(FakeAgentConnection::new());
2684 let thread = cx
2685 .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
2686 .await
2687 .unwrap();
2688
2689 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2690
2691 // Send Output BEFORE Created
2692 thread.update(cx, |thread, cx| {
2693 thread.on_terminal_provider_event(
2694 TerminalProviderEvent::Output {
2695 terminal_id: terminal_id.clone(),
2696 data: b"pre-exit data".to_vec(),
2697 },
2698 cx,
2699 );
2700 });
2701
2702 // Send Exit BEFORE Created
2703 thread.update(cx, |thread, cx| {
2704 thread.on_terminal_provider_event(
2705 TerminalProviderEvent::Exit {
2706 terminal_id: terminal_id.clone(),
2707 status: acp::TerminalExitStatus::new().exit_code(0),
2708 },
2709 cx,
2710 );
2711 });
2712
2713 // Now create a display-only lower-level terminal and send Created
2714 let lower = cx.new(|cx| {
2715 let builder = ::terminal::TerminalBuilder::new_display_only(
2716 ::terminal::terminal_settings::CursorShape::default(),
2717 ::terminal::terminal_settings::AlternateScroll::On,
2718 None,
2719 0,
2720 cx.background_executor(),
2721 PathStyle::local(),
2722 )
2723 .unwrap();
2724 builder.subscribe(cx)
2725 });
2726
2727 thread.update(cx, |thread, cx| {
2728 thread.on_terminal_provider_event(
2729 TerminalProviderEvent::Created {
2730 terminal_id: terminal_id.clone(),
2731 label: "Buffered Exit Test".to_string(),
2732 cwd: None,
2733 output_byte_limit: None,
2734 terminal: lower.clone(),
2735 },
2736 cx,
2737 );
2738 });
2739
2740 // Output should be present after Created (flushed from buffer)
2741 let content = thread.read_with(cx, |thread, cx| {
2742 let term = thread.terminal(terminal_id.clone()).unwrap();
2743 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2744 });
2745
2746 assert!(
2747 content.contains("pre-exit data"),
2748 "expected pre-exit data to render, got: {content}"
2749 );
2750 }
2751
2752 /// Test that killing a terminal via Terminal::kill properly:
2753 /// 1. Causes wait_for_exit to complete (doesn't hang forever)
2754 /// 2. The underlying terminal still has the output that was written before the kill
2755 ///
2756 /// This test verifies that the fix to kill_active_task (which now also kills
2757 /// the shell process in addition to the foreground process) properly allows
2758 /// wait_for_exit to complete instead of hanging indefinitely.
2759 #[cfg(unix)]
2760 #[gpui::test]
2761 async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
2762 use std::collections::HashMap;
2763 use task::Shell;
2764 use util::shell_builder::ShellBuilder;
2765
2766 init_test(cx);
2767 cx.executor().allow_parking();
2768
2769 let fs = FakeFs::new(cx.executor());
2770 let project = Project::test(fs, [], cx).await;
2771 let connection = Rc::new(FakeAgentConnection::new());
2772 let thread = cx
2773 .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
2774 .await
2775 .unwrap();
2776
2777 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2778
2779 // Create a real PTY terminal that runs a command which prints output then sleeps
2780 // We use printf instead of echo and chain with && sleep to ensure proper execution
2781 let (completion_tx, _completion_rx) = smol::channel::unbounded();
2782 let (program, args) = ShellBuilder::new(&Shell::System, false).build(
2783 Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
2784 &[],
2785 );
2786
2787 let builder = cx
2788 .update(|cx| {
2789 ::terminal::TerminalBuilder::new(
2790 None,
2791 None,
2792 task::Shell::WithArguments {
2793 program,
2794 args,
2795 title_override: None,
2796 },
2797 HashMap::default(),
2798 ::terminal::terminal_settings::CursorShape::default(),
2799 ::terminal::terminal_settings::AlternateScroll::On,
2800 None,
2801 vec![],
2802 0,
2803 false,
2804 0,
2805 Some(completion_tx),
2806 cx,
2807 vec![],
2808 PathStyle::local(),
2809 )
2810 })
2811 .await
2812 .unwrap();
2813
2814 let lower_terminal = cx.new(|cx| builder.subscribe(cx));
2815
2816 // Create the acp_thread Terminal wrapper
2817 thread.update(cx, |thread, cx| {
2818 thread.on_terminal_provider_event(
2819 TerminalProviderEvent::Created {
2820 terminal_id: terminal_id.clone(),
2821 label: "printf output_before_kill && sleep 60".to_string(),
2822 cwd: None,
2823 output_byte_limit: None,
2824 terminal: lower_terminal.clone(),
2825 },
2826 cx,
2827 );
2828 });
2829
2830 // Wait for the printf command to execute and produce output
2831 // Use real time since parking is enabled
2832 cx.executor().timer(Duration::from_millis(500)).await;
2833
2834 // Get the acp_thread Terminal and kill it
2835 let wait_for_exit = thread.update(cx, |thread, cx| {
2836 let term = thread.terminals.get(&terminal_id).unwrap();
2837 let wait_for_exit = term.read(cx).wait_for_exit();
2838 term.update(cx, |term, cx| {
2839 term.kill(cx);
2840 });
2841 wait_for_exit
2842 });
2843
2844 // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
2845 // Before the fix to kill_active_task, this would hang forever because
2846 // only the foreground process was killed, not the shell, so the PTY
2847 // child never exited and wait_for_completed_task never completed.
2848 let exit_result = futures::select! {
2849 result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
2850 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
2851 };
2852
2853 assert!(
2854 exit_result.is_some(),
2855 "wait_for_exit should complete after kill, but it timed out. \
2856 This indicates kill_active_task is not properly killing the shell process."
2857 );
2858
2859 // Give the system a chance to process any pending updates
2860 cx.run_until_parked();
2861
2862 // Verify that the underlying terminal still has the output that was
2863 // written before the kill. This verifies that killing doesn't lose output.
2864 let inner_content = thread.read_with(cx, |thread, cx| {
2865 let term = thread.terminals.get(&terminal_id).unwrap();
2866 term.read(cx).inner().read(cx).get_content()
2867 });
2868
2869 assert!(
2870 inner_content.contains("output_before_kill"),
2871 "Underlying terminal should contain output from before kill, got: {}",
2872 inner_content
2873 );
2874 }
2875
2876 #[gpui::test]
2877 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2878 init_test(cx);
2879
2880 let fs = FakeFs::new(cx.executor());
2881 let project = Project::test(fs, [], cx).await;
2882 let connection = Rc::new(FakeAgentConnection::new());
2883 let thread = cx
2884 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2885 .await
2886 .unwrap();
2887
2888 // Test creating a new user message
2889 thread.update(cx, |thread, cx| {
2890 thread.push_user_content_block(None, "Hello, ".into(), cx);
2891 });
2892
2893 thread.update(cx, |thread, cx| {
2894 assert_eq!(thread.entries.len(), 1);
2895 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2896 assert_eq!(user_msg.id, None);
2897 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2898 } else {
2899 panic!("Expected UserMessage");
2900 }
2901 });
2902
2903 // Test appending to existing user message
2904 let message_1_id = UserMessageId::new();
2905 thread.update(cx, |thread, cx| {
2906 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2907 });
2908
2909 thread.update(cx, |thread, cx| {
2910 assert_eq!(thread.entries.len(), 1);
2911 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2912 assert_eq!(user_msg.id, Some(message_1_id));
2913 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2914 } else {
2915 panic!("Expected UserMessage");
2916 }
2917 });
2918
2919 // Test creating new user message after assistant message
2920 thread.update(cx, |thread, cx| {
2921 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2922 });
2923
2924 let message_2_id = UserMessageId::new();
2925 thread.update(cx, |thread, cx| {
2926 thread.push_user_content_block(
2927 Some(message_2_id.clone()),
2928 "New user message".into(),
2929 cx,
2930 );
2931 });
2932
2933 thread.update(cx, |thread, cx| {
2934 assert_eq!(thread.entries.len(), 3);
2935 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2936 assert_eq!(user_msg.id, Some(message_2_id));
2937 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2938 } else {
2939 panic!("Expected UserMessage at index 2");
2940 }
2941 });
2942 }
2943
2944 #[gpui::test]
2945 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2946 init_test(cx);
2947
2948 let fs = FakeFs::new(cx.executor());
2949 let project = Project::test(fs, [], cx).await;
2950 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2951 |_, thread, mut cx| {
2952 async move {
2953 thread.update(&mut cx, |thread, cx| {
2954 thread
2955 .handle_session_update(
2956 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2957 "Thinking ".into(),
2958 )),
2959 cx,
2960 )
2961 .unwrap();
2962 thread
2963 .handle_session_update(
2964 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2965 "hard!".into(),
2966 )),
2967 cx,
2968 )
2969 .unwrap();
2970 })?;
2971 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2972 }
2973 .boxed_local()
2974 },
2975 ));
2976
2977 let thread = cx
2978 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
2979 .await
2980 .unwrap();
2981
2982 thread
2983 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2984 .await
2985 .unwrap();
2986
2987 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2988 assert_eq!(
2989 output,
2990 indoc! {r#"
2991 ## User
2992
2993 Hello from Zed!
2994
2995 ## Assistant
2996
2997 <thinking>
2998 Thinking hard!
2999 </thinking>
3000
3001 "#}
3002 );
3003 }
3004
3005 #[gpui::test]
3006 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
3007 init_test(cx);
3008
3009 let fs = FakeFs::new(cx.executor());
3010 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
3011 .await;
3012 let project = Project::test(fs.clone(), [], cx).await;
3013 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
3014 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
3015 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
3016 move |_, thread, mut cx| {
3017 let read_file_tx = read_file_tx.clone();
3018 async move {
3019 let content = thread
3020 .update(&mut cx, |thread, cx| {
3021 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3022 })
3023 .unwrap()
3024 .await
3025 .unwrap();
3026 assert_eq!(content, "one\ntwo\nthree\n");
3027 read_file_tx.take().unwrap().send(()).unwrap();
3028 thread
3029 .update(&mut cx, |thread, cx| {
3030 thread.write_text_file(
3031 path!("/tmp/foo").into(),
3032 "one\ntwo\nthree\nfour\nfive\n".to_string(),
3033 cx,
3034 )
3035 })
3036 .unwrap()
3037 .await
3038 .unwrap();
3039 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3040 }
3041 .boxed_local()
3042 },
3043 ));
3044
3045 let (worktree, pathbuf) = project
3046 .update(cx, |project, cx| {
3047 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3048 })
3049 .await
3050 .unwrap();
3051 let buffer = project
3052 .update(cx, |project, cx| {
3053 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
3054 })
3055 .await
3056 .unwrap();
3057
3058 let thread = cx
3059 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3060 .await
3061 .unwrap();
3062
3063 let request = thread.update(cx, |thread, cx| {
3064 thread.send_raw("Extend the count in /tmp/foo", cx)
3065 });
3066 read_file_rx.await.ok();
3067 buffer.update(cx, |buffer, cx| {
3068 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
3069 });
3070 cx.run_until_parked();
3071 assert_eq!(
3072 buffer.read_with(cx, |buffer, _| buffer.text()),
3073 "zero\none\ntwo\nthree\nfour\nfive\n"
3074 );
3075 assert_eq!(
3076 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
3077 "zero\none\ntwo\nthree\nfour\nfive\n"
3078 );
3079 request.await.unwrap();
3080 }
3081
3082 #[gpui::test]
3083 async fn test_reading_from_line(cx: &mut TestAppContext) {
3084 init_test(cx);
3085
3086 let fs = FakeFs::new(cx.executor());
3087 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
3088 .await;
3089 let project = Project::test(fs.clone(), [], cx).await;
3090 project
3091 .update(cx, |project, cx| {
3092 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3093 })
3094 .await
3095 .unwrap();
3096
3097 let connection = Rc::new(FakeAgentConnection::new());
3098
3099 let thread = cx
3100 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3101 .await
3102 .unwrap();
3103
3104 // Whole file
3105 let content = thread
3106 .update(cx, |thread, cx| {
3107 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3108 })
3109 .await
3110 .unwrap();
3111
3112 assert_eq!(content, "one\ntwo\nthree\nfour\n");
3113
3114 // Only start line
3115 let content = thread
3116 .update(cx, |thread, cx| {
3117 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
3118 })
3119 .await
3120 .unwrap();
3121
3122 assert_eq!(content, "three\nfour\n");
3123
3124 // Only limit
3125 let content = thread
3126 .update(cx, |thread, cx| {
3127 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3128 })
3129 .await
3130 .unwrap();
3131
3132 assert_eq!(content, "one\ntwo\n");
3133
3134 // Range
3135 let content = thread
3136 .update(cx, |thread, cx| {
3137 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
3138 })
3139 .await
3140 .unwrap();
3141
3142 assert_eq!(content, "two\nthree\n");
3143
3144 // Invalid
3145 let err = thread
3146 .update(cx, |thread, cx| {
3147 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
3148 })
3149 .await
3150 .unwrap_err();
3151
3152 assert_eq!(
3153 err.to_string(),
3154 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
3155 );
3156 }
3157
3158 #[gpui::test]
3159 async fn test_reading_empty_file(cx: &mut TestAppContext) {
3160 init_test(cx);
3161
3162 let fs = FakeFs::new(cx.executor());
3163 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
3164 let project = Project::test(fs.clone(), [], cx).await;
3165 project
3166 .update(cx, |project, cx| {
3167 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3168 })
3169 .await
3170 .unwrap();
3171
3172 let connection = Rc::new(FakeAgentConnection::new());
3173
3174 let thread = cx
3175 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3176 .await
3177 .unwrap();
3178
3179 // Whole file
3180 let content = thread
3181 .update(cx, |thread, cx| {
3182 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3183 })
3184 .await
3185 .unwrap();
3186
3187 assert_eq!(content, "");
3188
3189 // Only start line
3190 let content = thread
3191 .update(cx, |thread, cx| {
3192 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
3193 })
3194 .await
3195 .unwrap();
3196
3197 assert_eq!(content, "");
3198
3199 // Only limit
3200 let content = thread
3201 .update(cx, |thread, cx| {
3202 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3203 })
3204 .await
3205 .unwrap();
3206
3207 assert_eq!(content, "");
3208
3209 // Range
3210 let content = thread
3211 .update(cx, |thread, cx| {
3212 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3213 })
3214 .await
3215 .unwrap();
3216
3217 assert_eq!(content, "");
3218
3219 // Invalid
3220 let err = thread
3221 .update(cx, |thread, cx| {
3222 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3223 })
3224 .await
3225 .unwrap_err();
3226
3227 assert_eq!(
3228 err.to_string(),
3229 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3230 );
3231 }
3232 #[gpui::test]
3233 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3234 init_test(cx);
3235
3236 let fs = FakeFs::new(cx.executor());
3237 fs.insert_tree(path!("/tmp"), json!({})).await;
3238 let project = Project::test(fs.clone(), [], cx).await;
3239 project
3240 .update(cx, |project, cx| {
3241 project.find_or_create_worktree(path!("/tmp"), true, cx)
3242 })
3243 .await
3244 .unwrap();
3245
3246 let connection = Rc::new(FakeAgentConnection::new());
3247
3248 let thread = cx
3249 .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
3250 .await
3251 .unwrap();
3252
3253 // Out of project file
3254 let err = thread
3255 .update(cx, |thread, cx| {
3256 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3257 })
3258 .await
3259 .unwrap_err();
3260
3261 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3262 }
3263
3264 #[gpui::test]
3265 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3266 init_test(cx);
3267
3268 let fs = FakeFs::new(cx.executor());
3269 let project = Project::test(fs, [], cx).await;
3270 let id = acp::ToolCallId::new("test");
3271
3272 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3273 let id = id.clone();
3274 move |_, thread, mut cx| {
3275 let id = id.clone();
3276 async move {
3277 thread
3278 .update(&mut cx, |thread, cx| {
3279 thread.handle_session_update(
3280 acp::SessionUpdate::ToolCall(
3281 acp::ToolCall::new(id.clone(), "Label")
3282 .kind(acp::ToolKind::Fetch)
3283 .status(acp::ToolCallStatus::InProgress),
3284 ),
3285 cx,
3286 )
3287 })
3288 .unwrap()
3289 .unwrap();
3290 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3291 }
3292 .boxed_local()
3293 }
3294 }));
3295
3296 let thread = cx
3297 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3298 .await
3299 .unwrap();
3300
3301 let request = thread.update(cx, |thread, cx| {
3302 thread.send_raw("Fetch https://example.com", cx)
3303 });
3304
3305 run_until_first_tool_call(&thread, cx).await;
3306
3307 thread.read_with(cx, |thread, _| {
3308 assert!(matches!(
3309 thread.entries[1],
3310 AgentThreadEntry::ToolCall(ToolCall {
3311 status: ToolCallStatus::InProgress,
3312 ..
3313 })
3314 ));
3315 });
3316
3317 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3318
3319 thread.read_with(cx, |thread, _| {
3320 assert!(matches!(
3321 &thread.entries[1],
3322 AgentThreadEntry::ToolCall(ToolCall {
3323 status: ToolCallStatus::Canceled,
3324 ..
3325 })
3326 ));
3327 });
3328
3329 thread
3330 .update(cx, |thread, cx| {
3331 thread.handle_session_update(
3332 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3333 id,
3334 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3335 )),
3336 cx,
3337 )
3338 })
3339 .unwrap();
3340
3341 request.await.unwrap();
3342
3343 thread.read_with(cx, |thread, _| {
3344 assert!(matches!(
3345 thread.entries[1],
3346 AgentThreadEntry::ToolCall(ToolCall {
3347 status: ToolCallStatus::Completed,
3348 ..
3349 })
3350 ));
3351 });
3352 }
3353
3354 #[gpui::test]
3355 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3356 init_test(cx);
3357 let fs = FakeFs::new(cx.background_executor.clone());
3358 fs.insert_tree(path!("/test"), json!({})).await;
3359 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3360
3361 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3362 move |_, thread, mut cx| {
3363 async move {
3364 thread
3365 .update(&mut cx, |thread, cx| {
3366 thread.handle_session_update(
3367 acp::SessionUpdate::ToolCall(
3368 acp::ToolCall::new("test", "Label")
3369 .kind(acp::ToolKind::Edit)
3370 .status(acp::ToolCallStatus::Completed)
3371 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3372 "/test/test.txt",
3373 "foo",
3374 ))]),
3375 ),
3376 cx,
3377 )
3378 })
3379 .unwrap()
3380 .unwrap();
3381 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3382 }
3383 .boxed_local()
3384 }
3385 }));
3386
3387 let thread = cx
3388 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3389 .await
3390 .unwrap();
3391
3392 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3393 .await
3394 .unwrap();
3395
3396 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3397 }
3398
    /// Verifies that a user message carries a checkpoint only when its turn changed
    /// files, and that restoring a checkpoint both truncates the conversation and
    /// rolls the filesystem back to the state it captured.
    ///
    /// `iterations = 10` asks the gpui test harness to run this test repeatedly
    /// (presumably with varied scheduling — TODO confirm against the macro docs).
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // The worktree contains a `.git` directory; presumably required so the git
        // store can create checkpoints for this project.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is true, every prompt writes a fresh
        // `/test/file-N`, so the turn produces a filesystem change.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in upper case as the assistant's reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn changes files, so its user message is rendered with "(checkpoint)".
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn also changes files and gets its own checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message (entries: user, assistant, user, ...);
        // restoring it removes that message and everything after it, and rolls the
        // filesystem back to before `file-1` was written.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3576
3577 #[gpui::test]
3578 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3579 use std::sync::atomic::AtomicUsize;
3580 init_test(cx);
3581
3582 let fs = FakeFs::new(cx.executor());
3583 let project = Project::test(fs, None, cx).await;
3584
3585 // Create a connection that simulates refusal after tool result
3586 let prompt_count = Arc::new(AtomicUsize::new(0));
3587 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3588 let prompt_count = prompt_count.clone();
3589 move |_request, thread, mut cx| {
3590 let count = prompt_count.fetch_add(1, SeqCst);
3591 async move {
3592 if count == 0 {
3593 // First prompt: Generate a tool call with result
3594 thread.update(&mut cx, |thread, cx| {
3595 thread
3596 .handle_session_update(
3597 acp::SessionUpdate::ToolCall(
3598 acp::ToolCall::new("tool1", "Test Tool")
3599 .kind(acp::ToolKind::Fetch)
3600 .status(acp::ToolCallStatus::Completed)
3601 .raw_input(serde_json::json!({"query": "test"}))
3602 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3603 ),
3604 cx,
3605 )
3606 .unwrap();
3607 })?;
3608
3609 // Now return refusal because of the tool result
3610 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3611 } else {
3612 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3613 }
3614 }
3615 .boxed_local()
3616 }
3617 }));
3618
3619 let thread = cx
3620 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3621 .await
3622 .unwrap();
3623
3624 // Track if we see a Refusal event
3625 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3626 let saw_refusal_event_captured = saw_refusal_event.clone();
3627 thread.update(cx, |_thread, cx| {
3628 cx.subscribe(
3629 &thread,
3630 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3631 if matches!(event, AcpThreadEvent::Refusal) {
3632 *saw_refusal_event_captured.lock().unwrap() = true;
3633 }
3634 },
3635 )
3636 .detach();
3637 });
3638
3639 // Send a user message - this will trigger tool call and then refusal
3640 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3641 cx.background_executor.spawn(send_task).detach();
3642 cx.run_until_parked();
3643
3644 // Verify that:
3645 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3646 // 2. The user message was NOT truncated
3647 assert!(
3648 *saw_refusal_event.lock().unwrap(),
3649 "Refusal event should be emitted for tool result refusals"
3650 );
3651
3652 thread.read_with(cx, |thread, _| {
3653 let entries = thread.entries();
3654 assert!(entries.len() >= 2, "Should have user message and tool call");
3655
3656 // Verify user message is still there
3657 assert!(
3658 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3659 "User message should not be truncated"
3660 );
3661
3662 // Verify tool call is there with result
3663 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3664 assert!(
3665 tool_call.raw_output.is_some(),
3666 "Tool call should have output"
3667 );
3668 } else {
3669 panic!("Expected tool call at index 1");
3670 }
3671 });
3672 }
3673
3674 #[gpui::test]
3675 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3676 init_test(cx);
3677
3678 let fs = FakeFs::new(cx.executor());
3679 let project = Project::test(fs, None, cx).await;
3680
3681 let refuse_next = Arc::new(AtomicBool::new(false));
3682 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3683 let refuse_next = refuse_next.clone();
3684 move |_request, _thread, _cx| {
3685 if refuse_next.load(SeqCst) {
3686 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3687 .boxed_local()
3688 } else {
3689 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3690 .boxed_local()
3691 }
3692 }
3693 }));
3694
3695 let thread = cx
3696 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3697 .await
3698 .unwrap();
3699
3700 // Track if we see a Refusal event
3701 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3702 let saw_refusal_event_captured = saw_refusal_event.clone();
3703 thread.update(cx, |_thread, cx| {
3704 cx.subscribe(
3705 &thread,
3706 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3707 if matches!(event, AcpThreadEvent::Refusal) {
3708 *saw_refusal_event_captured.lock().unwrap() = true;
3709 }
3710 },
3711 )
3712 .detach();
3713 });
3714
3715 // Send a message that will be refused
3716 refuse_next.store(true, SeqCst);
3717 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3718 .await
3719 .unwrap();
3720
3721 // Verify that a Refusal event WAS emitted for user prompt refusal
3722 assert!(
3723 *saw_refusal_event.lock().unwrap(),
3724 "Refusal event should be emitted for user prompt refusals"
3725 );
3726
3727 // Verify the message was truncated (user prompt refusal)
3728 thread.read_with(cx, |thread, cx| {
3729 assert_eq!(thread.to_markdown(cx), "");
3730 });
3731 }
3732
3733 #[gpui::test]
3734 async fn test_refusal(cx: &mut TestAppContext) {
3735 init_test(cx);
3736 let fs = FakeFs::new(cx.background_executor.clone());
3737 fs.insert_tree(path!("/"), json!({})).await;
3738 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3739
3740 let refuse_next = Arc::new(AtomicBool::new(false));
3741 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3742 let refuse_next = refuse_next.clone();
3743 move |request, thread, mut cx| {
3744 let refuse_next = refuse_next.clone();
3745 async move {
3746 if refuse_next.load(SeqCst) {
3747 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3748 }
3749
3750 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3751 panic!("expected text content block");
3752 };
3753 thread.update(&mut cx, |thread, cx| {
3754 thread
3755 .handle_session_update(
3756 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3757 content.text.to_uppercase().into(),
3758 )),
3759 cx,
3760 )
3761 .unwrap();
3762 })?;
3763 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3764 }
3765 .boxed_local()
3766 }
3767 }));
3768 let thread = cx
3769 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3770 .await
3771 .unwrap();
3772
3773 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3774 .await
3775 .unwrap();
3776 thread.read_with(cx, |thread, cx| {
3777 assert_eq!(
3778 thread.to_markdown(cx),
3779 indoc! {"
3780 ## User
3781
3782 hello
3783
3784 ## Assistant
3785
3786 HELLO
3787
3788 "}
3789 );
3790 });
3791
3792 // Simulate refusing the second message. The message should be truncated
3793 // when a user prompt is refused.
3794 refuse_next.store(true, SeqCst);
3795 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3796 .await
3797 .unwrap();
3798 thread.read_with(cx, |thread, cx| {
3799 assert_eq!(
3800 thread.to_markdown(cx),
3801 indoc! {"
3802 ## User
3803
3804 hello
3805
3806 ## Assistant
3807
3808 HELLO
3809
3810 "}
3811 );
3812 });
3813 }
3814
3815 async fn run_until_first_tool_call(
3816 thread: &Entity<AcpThread>,
3817 cx: &mut TestAppContext,
3818 ) -> usize {
3819 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3820
3821 let subscription = cx.update(|cx| {
3822 cx.subscribe(thread, move |thread, _, cx| {
3823 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3824 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3825 return tx.try_send(ix).unwrap();
3826 }
3827 }
3828 })
3829 });
3830
3831 select! {
3832 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3833 panic!("Timeout waiting for tool call")
3834 }
3835 ix = rx.next().fuse() => {
3836 drop(subscription);
3837 ix.unwrap()
3838 }
3839 }
3840 }
3841
3842 #[derive(Clone, Default)]
3843 struct FakeAgentConnection {
3844 auth_methods: Vec<acp::AuthMethod>,
3845 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3846 on_user_message: Option<
3847 Rc<
3848 dyn Fn(
3849 acp::PromptRequest,
3850 WeakEntity<AcpThread>,
3851 AsyncApp,
3852 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3853 + 'static,
3854 >,
3855 >,
3856 }
3857
3858 impl FakeAgentConnection {
3859 fn new() -> Self {
3860 Self {
3861 auth_methods: Vec::new(),
3862 on_user_message: None,
3863 sessions: Arc::default(),
3864 }
3865 }
3866
3867 #[expect(unused)]
3868 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3869 self.auth_methods = auth_methods;
3870 self
3871 }
3872
3873 fn on_user_message(
3874 mut self,
3875 handler: impl Fn(
3876 acp::PromptRequest,
3877 WeakEntity<AcpThread>,
3878 AsyncApp,
3879 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3880 + 'static,
3881 ) -> Self {
3882 self.on_user_message.replace(Rc::new(handler));
3883 self
3884 }
3885 }
3886
3887 impl AgentConnection for FakeAgentConnection {
3888 fn telemetry_id(&self) -> SharedString {
3889 "fake".into()
3890 }
3891
3892 fn auth_methods(&self) -> &[acp::AuthMethod] {
3893 &self.auth_methods
3894 }
3895
3896 fn new_session(
3897 self: Rc<Self>,
3898 project: Entity<Project>,
3899 _cwd: &Path,
3900 cx: &mut App,
3901 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3902 let session_id = acp::SessionId::new(
3903 rand::rng()
3904 .sample_iter(&distr::Alphanumeric)
3905 .take(7)
3906 .map(char::from)
3907 .collect::<String>(),
3908 );
3909 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3910 let thread = cx.new(|cx| {
3911 AcpThread::new(
3912 None,
3913 "Test",
3914 self.clone(),
3915 project,
3916 action_log,
3917 session_id.clone(),
3918 watch::Receiver::constant(
3919 acp::PromptCapabilities::new()
3920 .image(true)
3921 .audio(true)
3922 .embedded_context(true),
3923 ),
3924 cx,
3925 )
3926 });
3927 self.sessions.lock().insert(session_id, thread.downgrade());
3928 Task::ready(Ok(thread))
3929 }
3930
3931 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3932 if self.auth_methods().iter().any(|m| m.id == method) {
3933 Task::ready(Ok(()))
3934 } else {
3935 Task::ready(Err(anyhow!("Invalid Auth Method")))
3936 }
3937 }
3938
3939 fn prompt(
3940 &self,
3941 _id: Option<UserMessageId>,
3942 params: acp::PromptRequest,
3943 cx: &mut App,
3944 ) -> Task<gpui::Result<acp::PromptResponse>> {
3945 let sessions = self.sessions.lock();
3946 let thread = sessions.get(¶ms.session_id).unwrap();
3947 if let Some(handler) = &self.on_user_message {
3948 let handler = handler.clone();
3949 let thread = thread.clone();
3950 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3951 } else {
3952 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3953 }
3954 }
3955
3956 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
3957
3958 fn truncate(
3959 &self,
3960 session_id: &acp::SessionId,
3961 _cx: &App,
3962 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3963 Some(Rc::new(FakeAgentSessionEditor {
3964 _session_id: session_id.clone(),
3965 }))
3966 }
3967
3968 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3969 self
3970 }
3971 }
3972
3973 struct FakeAgentSessionEditor {
3974 _session_id: acp::SessionId,
3975 }
3976
3977 impl AgentSessionTruncate for FakeAgentSessionEditor {
3978 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3979 Task::ready(Ok(()))
3980 }
3981 }
3982
3983 #[gpui::test]
3984 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3985 init_test(cx);
3986
3987 let fs = FakeFs::new(cx.executor());
3988 let project = Project::test(fs, [], cx).await;
3989 let connection = Rc::new(FakeAgentConnection::new());
3990 let thread = cx
3991 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
3992 .await
3993 .unwrap();
3994
3995 // Try to update a tool call that doesn't exist
3996 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3997 thread.update(cx, |thread, cx| {
3998 let result = thread.handle_session_update(
3999 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4000 nonexistent_id.clone(),
4001 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4002 )),
4003 cx,
4004 );
4005
4006 // The update should succeed (not return an error)
4007 assert!(result.is_ok());
4008
4009 // There should now be exactly one entry in the thread
4010 assert_eq!(thread.entries.len(), 1);
4011
4012 // The entry should be a failed tool call
4013 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4014 assert_eq!(tool_call.id, nonexistent_id);
4015 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4016 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4017
4018 // Check that the content contains the error message
4019 assert_eq!(tool_call.content.len(), 1);
4020 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4021 match content_block {
4022 ContentBlock::Markdown { markdown } => {
4023 let markdown_text = markdown.read(cx).source();
4024 assert!(markdown_text.contains("Tool call not found"));
4025 }
4026 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4027 ContentBlock::ResourceLink { .. } => {
4028 panic!("Expected markdown content, got resource link")
4029 }
4030 ContentBlock::Image { .. } => {
4031 panic!("Expected markdown content, got image")
4032 }
4033 }
4034 } else {
4035 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4036 }
4037 } else {
4038 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4039 }
4040 });
4041 }
4042
    /// Tests that restoring a checkpoint properly cleans up terminals that were
    /// created after that checkpoint, and cancels any in-progress generation.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// Setup:
    /// - Two user messages are sent.
    /// - Two display-only terminals are registered and marked as exited.
    /// - A third terminal is registered, left running, and referenced by an
    ///   in-progress tool call.
    ///
    /// Restoring to the second message must remove only the running terminal,
    /// keep the two exited ones, and leave no `running_turn`.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that have completed running.
        // NOTE(review): these are registered *after* the second message was sent,
        // just like the third terminal below. What distinguishes them is that they
        // have exited and no tool call references them. Confirm which property
        // `restore_checkpoint` actually keys on.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register terminal 1, feed it output, then mark it exited with code 0.
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Same sequence for terminal 2.
        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a terminal AFTER the checkpoint we'll restore to.
        // This simulates the AI agent starting a long-running terminal command.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
                cx.background_executor(),
                PathStyle::local(),
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running; no Exit event is sent)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry that references this terminal
        // This represents the agent requesting a terminal command
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (via the cancel() call)
        // 2. Remove the terminal that was created after that point
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (cancel() clears the send_task)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed terminals still exist
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state (output is Some once exited)
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal (created after checkpoint) was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify total count is 2 (the two completed terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4344
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    ///
    /// The race is produced by ticking the executor by hand: first until the prompt
    /// handler has run, then a few extra ticks. Only then is a checkpoint-less user
    /// message pushed. The tick counts are tuned to this interleaving and are
    /// order-sensitive.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // `.git` makes the worktree a repository; presumably needed for checkpoints.
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // Set as soon as the prompt handler is invoked, so the test knows when to
        // start ticking toward the checkpoint update.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message without a checkpoint while the checkpoint update
        // is (expected to be) still pending.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4407
    /// Tests that when a follow-up message is sent during generation,
    /// the first turn completing does NOT clear `running_turn` because
    /// it now belongs to the second turn.
    ///
    /// The first prompt's handler blocks on a oneshot signal. This keeps turn 1 in
    /// flight while turn 2 is started. Turn 1 is then released, and the test checks
    /// that only turn 2's completion clears `running_turn`.
    #[gpui::test]
    async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // First handler waits for this signal before completing.
        // The receiver sits in a RefCell because the handler must be `Fn` (see
        // `FakeAgentConnection::on_user_message`), yet the receiver is taken by value once.
        let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
        let first_complete_rx = RefCell::new(Some(first_complete_rx));

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |params, _thread, _cx| {
                // Whichever prompt arrives first takes the receiver; later prompts get None.
                let first_complete_rx = first_complete_rx.borrow_mut().take();
                let is_first = params
                    .prompt
                    .iter()
                    .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));

                async move {
                    if is_first {
                        // First handler waits until signaled
                        if let Some(rx) = first_complete_rx {
                            rx.await.ok();
                        }
                    }
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first message (turn_id=1) - handler will block
        let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
        assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);

        // Send second message (turn_id=2) while first is still blocked
        // This calls cancel() which takes turn 1's running_turn and sets turn 2's
        let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
        assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);

        let running_turn_after_second_send =
            thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
        assert_eq!(
            running_turn_after_second_send,
            Some(2),
            "running_turn should be set to turn 2 after sending second message"
        );

        // Now signal first handler to complete
        first_complete_tx.send(()).ok();

        // First request completes - should NOT clear running_turn
        // because running_turn now belongs to turn 2
        first_request.await.unwrap();

        let running_turn_after_first =
            thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
        assert_eq!(
            running_turn_after_first,
            Some(2),
            "first turn completing should not clear running_turn (belongs to turn 2)"
        );

        // Second request completes - SHOULD clear running_turn
        second_request.await.unwrap();

        let running_turn_after_second =
            thread.read_with(cx, |thread, _| thread.running_turn.is_some());
        assert!(
            !running_turn_after_second,
            "second turn completing should clear running_turn"
        );
    }
4490
4491 #[gpui::test]
4492 async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
4493 cx: &mut TestAppContext,
4494 ) {
4495 init_test(cx);
4496
4497 let fs = FakeFs::new(cx.executor());
4498 let project = Project::test(fs, [], cx).await;
4499
4500 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4501 move |_params, thread, mut cx| {
4502 async move {
4503 thread
4504 .update(&mut cx, |thread, cx| {
4505 thread.handle_session_update(
4506 acp::SessionUpdate::ToolCall(
4507 acp::ToolCall::new(
4508 acp::ToolCallId::new("test-tool"),
4509 "Test Tool",
4510 )
4511 .kind(acp::ToolKind::Fetch)
4512 .status(acp::ToolCallStatus::InProgress),
4513 ),
4514 cx,
4515 )
4516 })
4517 .unwrap()
4518 .unwrap();
4519
4520 Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
4521 }
4522 .boxed_local()
4523 },
4524 ));
4525
4526 let thread = cx
4527 .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
4528 .await
4529 .unwrap();
4530
4531 let response = thread
4532 .update(cx, |thread, cx| thread.send_raw("test message", cx))
4533 .await;
4534
4535 let response = response
4536 .expect("send should succeed")
4537 .expect("should have response");
4538 assert_eq!(
4539 response.stop_reason,
4540 acp::StopReason::Cancelled,
4541 "response should have Cancelled stop_reason"
4542 );
4543
4544 thread.read_with(cx, |thread, _| {
4545 let tool_entry = thread
4546 .entries
4547 .iter()
4548 .find_map(|e| {
4549 if let AgentThreadEntry::ToolCall(call) = e {
4550 Some(call)
4551 } else {
4552 None
4553 }
4554 })
4555 .expect("should have tool call entry");
4556
4557 assert!(
4558 matches!(tool_entry.status, ToolCallStatus::Canceled),
4559 "tool should be marked as Canceled when response is Cancelled, got {:?}",
4560 tool_entry.status
4561 );
4562 });
4563 }
4564}