1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7
/// Metadata key under which a tool call's programmatic name is stored in the
/// ACP `ToolCall` meta map.
///
/// ACP's `ToolCall` has no dedicated name field, so the name travels in the
/// meta map instead.
pub const TOOL_NAME_META_KEY: &str = "tool_name";

/// Programmatic tool name that identifies subagent-spawning tool calls.
pub const SUBAGENT_TOOL_NAME: &str = "subagent";
14
15/// Helper to extract tool name from ACP meta
16pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
17 meta.as_ref()
18 .and_then(|m| m.get(TOOL_NAME_META_KEY))
19 .and_then(|v| v.as_str())
20 .map(|s| SharedString::from(s.to_owned()))
21}
22
23/// Helper to create meta with tool name
24pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
25 acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
26}
27use collections::HashSet;
28pub use connection::*;
29pub use diff::*;
30use language::language_settings::FormatOnSave;
31pub use mention::*;
32use project::lsp_store::{FormatTrigger, LspFormatTarget};
33use serde::{Deserialize, Serialize};
34use serde_json::to_string_pretty;
35use settings::Settings as _;
36use task::{Shell, ShellBuilder};
37pub use terminal::*;
38
39use action_log::{ActionLog, ActionLogTelemetry};
40use agent_client_protocol::{self as acp};
41use anyhow::{Context as _, Result, anyhow};
42use editor::Bias;
43use futures::{FutureExt, channel::oneshot, future::BoxFuture};
44use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
45use itertools::Itertools;
46use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
47use markdown::Markdown;
48use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
49use std::collections::HashMap;
50use std::error::Error;
51use std::fmt::{Formatter, Write};
52use std::ops::Range;
53use std::process::ExitStatus;
54use std::rc::Rc;
55use std::time::{Duration, Instant};
56use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
57use ui::App;
58use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
59use uuid::Uuid;
60
61#[derive(Debug)]
62pub struct UserMessage {
63 pub id: Option<UserMessageId>,
64 pub content: ContentBlock,
65 pub chunks: Vec<acp::ContentBlock>,
66 pub checkpoint: Option<Checkpoint>,
67 pub indented: bool,
68}
69
70#[derive(Debug)]
71pub struct Checkpoint {
72 git_checkpoint: GitStoreCheckpoint,
73 pub show: bool,
74}
75
76impl UserMessage {
77 fn to_markdown(&self, cx: &App) -> String {
78 let mut markdown = String::new();
79 if self
80 .checkpoint
81 .as_ref()
82 .is_some_and(|checkpoint| checkpoint.show)
83 {
84 writeln!(markdown, "## User (checkpoint)").unwrap();
85 } else {
86 writeln!(markdown, "## User").unwrap();
87 }
88 writeln!(markdown).unwrap();
89 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
90 writeln!(markdown).unwrap();
91 markdown
92 }
93}
94
95#[derive(Debug, PartialEq)]
96pub struct AssistantMessage {
97 pub chunks: Vec<AssistantMessageChunk>,
98 pub indented: bool,
99}
100
101impl AssistantMessage {
102 pub fn to_markdown(&self, cx: &App) -> String {
103 format!(
104 "## Assistant\n\n{}\n\n",
105 self.chunks
106 .iter()
107 .map(|chunk| chunk.to_markdown(cx))
108 .join("\n\n")
109 )
110 }
111}
112
113#[derive(Debug, PartialEq)]
114pub enum AssistantMessageChunk {
115 Message { block: ContentBlock },
116 Thought { block: ContentBlock },
117}
118
119impl AssistantMessageChunk {
120 pub fn from_str(
121 chunk: &str,
122 language_registry: &Arc<LanguageRegistry>,
123 path_style: PathStyle,
124 cx: &mut App,
125 ) -> Self {
126 Self::Message {
127 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
128 }
129 }
130
131 fn to_markdown(&self, cx: &App) -> String {
132 match self {
133 Self::Message { block } => block.to_markdown(cx).to_string(),
134 Self::Thought { block } => {
135 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
136 }
137 }
138 }
139}
140
141#[derive(Debug)]
142pub enum AgentThreadEntry {
143 UserMessage(UserMessage),
144 AssistantMessage(AssistantMessage),
145 ToolCall(ToolCall),
146}
147
148impl AgentThreadEntry {
149 pub fn is_indented(&self) -> bool {
150 match self {
151 Self::UserMessage(message) => message.indented,
152 Self::AssistantMessage(message) => message.indented,
153 Self::ToolCall(_) => false,
154 }
155 }
156
157 pub fn to_markdown(&self, cx: &App) -> String {
158 match self {
159 Self::UserMessage(message) => message.to_markdown(cx),
160 Self::AssistantMessage(message) => message.to_markdown(cx),
161 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
162 }
163 }
164
165 pub fn user_message(&self) -> Option<&UserMessage> {
166 if let AgentThreadEntry::UserMessage(message) = self {
167 Some(message)
168 } else {
169 None
170 }
171 }
172
173 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
174 if let AgentThreadEntry::ToolCall(call) = self {
175 itertools::Either::Left(call.diffs())
176 } else {
177 itertools::Either::Right(std::iter::empty())
178 }
179 }
180
181 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
182 if let AgentThreadEntry::ToolCall(call) = self {
183 itertools::Either::Left(call.terminals())
184 } else {
185 itertools::Either::Right(std::iter::empty())
186 }
187 }
188
189 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
190 if let AgentThreadEntry::ToolCall(ToolCall {
191 locations,
192 resolved_locations,
193 ..
194 }) = self
195 {
196 Some((
197 locations.get(ix)?.clone(),
198 resolved_locations.get(ix)?.clone()?,
199 ))
200 } else {
201 None
202 }
203 }
204}
205
206#[derive(Debug)]
207pub struct ToolCall {
208 pub id: acp::ToolCallId,
209 pub label: Entity<Markdown>,
210 pub kind: acp::ToolKind,
211 pub content: Vec<ToolCallContent>,
212 pub status: ToolCallStatus,
213 pub locations: Vec<acp::ToolCallLocation>,
214 pub resolved_locations: Vec<Option<AgentLocation>>,
215 pub raw_input: Option<serde_json::Value>,
216 pub raw_input_markdown: Option<Entity<Markdown>>,
217 pub raw_output: Option<serde_json::Value>,
218 pub tool_name: Option<SharedString>,
219}
220
221impl ToolCall {
222 fn from_acp(
223 tool_call: acp::ToolCall,
224 status: ToolCallStatus,
225 language_registry: Arc<LanguageRegistry>,
226 path_style: PathStyle,
227 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
228 cx: &mut App,
229 ) -> Result<Self> {
230 let title = if tool_call.kind == acp::ToolKind::Execute {
231 tool_call.title
232 } else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
233 first_line.to_owned() + "…"
234 } else {
235 tool_call.title
236 };
237 let mut content = Vec::with_capacity(tool_call.content.len());
238 for item in tool_call.content {
239 if let Some(item) = ToolCallContent::from_acp(
240 item,
241 language_registry.clone(),
242 path_style,
243 terminals,
244 cx,
245 )? {
246 content.push(item);
247 }
248 }
249
250 let raw_input_markdown = tool_call
251 .raw_input
252 .as_ref()
253 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
254
255 let tool_name = tool_name_from_meta(&tool_call.meta);
256
257 let result = Self {
258 id: tool_call.tool_call_id,
259 label: cx
260 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
261 kind: tool_call.kind,
262 content,
263 locations: tool_call.locations,
264 resolved_locations: Vec::default(),
265 status,
266 raw_input: tool_call.raw_input,
267 raw_input_markdown,
268 raw_output: tool_call.raw_output,
269 tool_name,
270 };
271 Ok(result)
272 }
273
274 fn update_fields(
275 &mut self,
276 fields: acp::ToolCallUpdateFields,
277 language_registry: Arc<LanguageRegistry>,
278 path_style: PathStyle,
279 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
280 cx: &mut App,
281 ) -> Result<()> {
282 let acp::ToolCallUpdateFields {
283 kind,
284 status,
285 title,
286 content,
287 locations,
288 raw_input,
289 raw_output,
290 ..
291 } = fields;
292
293 if let Some(kind) = kind {
294 self.kind = kind;
295 }
296
297 if let Some(status) = status {
298 self.status = status.into();
299 }
300
301 if let Some(title) = title {
302 self.label.update(cx, |label, cx| {
303 if self.kind == acp::ToolKind::Execute {
304 label.replace(title, cx);
305 } else if let Some((first_line, _)) = title.split_once("\n") {
306 label.replace(first_line.to_owned() + "…", cx);
307 } else {
308 label.replace(title, cx);
309 }
310 });
311 }
312
313 if let Some(content) = content {
314 let mut new_content_len = content.len();
315 let mut content = content.into_iter();
316
317 // Reuse existing content if we can
318 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
319 let valid_content =
320 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
321 if !valid_content {
322 new_content_len -= 1;
323 }
324 }
325 for new in content {
326 if let Some(new) = ToolCallContent::from_acp(
327 new,
328 language_registry.clone(),
329 path_style,
330 terminals,
331 cx,
332 )? {
333 self.content.push(new);
334 } else {
335 new_content_len -= 1;
336 }
337 }
338 self.content.truncate(new_content_len);
339 }
340
341 if let Some(locations) = locations {
342 self.locations = locations;
343 }
344
345 if let Some(raw_input) = raw_input {
346 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
347 self.raw_input = Some(raw_input);
348 }
349
350 if let Some(raw_output) = raw_output {
351 if self.content.is_empty()
352 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
353 {
354 self.content
355 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
356 markdown,
357 }));
358 }
359 self.raw_output = Some(raw_output);
360 }
361 Ok(())
362 }
363
364 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
365 self.content.iter().filter_map(|content| match content {
366 ToolCallContent::Diff(diff) => Some(diff),
367 ToolCallContent::ContentBlock(_) => None,
368 ToolCallContent::Terminal(_) => None,
369 ToolCallContent::SubagentThread(_) => None,
370 })
371 }
372
373 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
374 self.content.iter().filter_map(|content| match content {
375 ToolCallContent::Terminal(terminal) => Some(terminal),
376 ToolCallContent::ContentBlock(_) => None,
377 ToolCallContent::Diff(_) => None,
378 ToolCallContent::SubagentThread(_) => None,
379 })
380 }
381
382 pub fn subagent_thread(&self) -> Option<&Entity<AcpThread>> {
383 self.content.iter().find_map(|content| match content {
384 ToolCallContent::SubagentThread(thread) => Some(thread),
385 _ => None,
386 })
387 }
388
389 pub fn is_subagent(&self) -> bool {
390 matches!(self.kind, acp::ToolKind::Other)
391 && self
392 .tool_name
393 .as_ref()
394 .map(|n| n.as_ref() == SUBAGENT_TOOL_NAME)
395 .unwrap_or(false)
396 }
397
398 pub fn to_markdown(&self, cx: &App) -> String {
399 let mut markdown = format!(
400 "**Tool Call: {}**\nStatus: {}\n\n",
401 self.label.read(cx).source(),
402 self.status
403 );
404 for content in &self.content {
405 markdown.push_str(content.to_markdown(cx).as_str());
406 markdown.push_str("\n\n");
407 }
408 markdown
409 }
410
411 async fn resolve_location(
412 location: acp::ToolCallLocation,
413 project: WeakEntity<Project>,
414 cx: &mut AsyncApp,
415 ) -> Option<ResolvedLocation> {
416 let buffer = project
417 .update(cx, |project, cx| {
418 project
419 .project_path_for_absolute_path(&location.path, cx)
420 .map(|path| project.open_buffer(path, cx))
421 })
422 .ok()??;
423 let buffer = buffer.await.log_err()?;
424 let position = buffer.update(cx, |buffer, _| {
425 let snapshot = buffer.snapshot();
426 if let Some(row) = location.line {
427 let column = snapshot.indent_size_for_line(row).len;
428 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
429 snapshot.anchor_before(point)
430 } else {
431 Anchor::min_for_buffer(snapshot.remote_id())
432 }
433 });
434
435 Some(ResolvedLocation { buffer, position })
436 }
437
438 fn resolve_locations(
439 &self,
440 project: Entity<Project>,
441 cx: &mut App,
442 ) -> Task<Vec<Option<ResolvedLocation>>> {
443 let locations = self.locations.clone();
444 project.update(cx, |_, cx| {
445 cx.spawn(async move |project, cx| {
446 let mut new_locations = Vec::new();
447 for location in locations {
448 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
449 }
450 new_locations
451 })
452 })
453 }
454}
455
456// Separate so we can hold a strong reference to the buffer
457// for saving on the thread
458#[derive(Clone, Debug, PartialEq, Eq)]
459struct ResolvedLocation {
460 buffer: Entity<Buffer>,
461 position: Anchor,
462}
463
464impl From<&ResolvedLocation> for AgentLocation {
465 fn from(value: &ResolvedLocation) -> Self {
466 Self {
467 buffer: value.buffer.downgrade(),
468 position: value.position,
469 }
470 }
471}
472
473#[derive(Debug)]
474pub enum ToolCallStatus {
475 /// The tool call hasn't started running yet, but we start showing it to
476 /// the user.
477 Pending,
478 /// The tool call is waiting for confirmation from the user.
479 WaitingForConfirmation {
480 options: Vec<acp::PermissionOption>,
481 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
482 },
483 /// The tool call is currently running.
484 InProgress,
485 /// The tool call completed successfully.
486 Completed,
487 /// The tool call failed.
488 Failed,
489 /// The user rejected the tool call.
490 Rejected,
491 /// The user canceled generation so the tool call was canceled.
492 Canceled,
493}
494
495impl From<acp::ToolCallStatus> for ToolCallStatus {
496 fn from(status: acp::ToolCallStatus) -> Self {
497 match status {
498 acp::ToolCallStatus::Pending => Self::Pending,
499 acp::ToolCallStatus::InProgress => Self::InProgress,
500 acp::ToolCallStatus::Completed => Self::Completed,
501 acp::ToolCallStatus::Failed => Self::Failed,
502 _ => Self::Pending,
503 }
504 }
505}
506
507impl Display for ToolCallStatus {
508 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
509 write!(
510 f,
511 "{}",
512 match self {
513 ToolCallStatus::Pending => "Pending",
514 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
515 ToolCallStatus::InProgress => "In Progress",
516 ToolCallStatus::Completed => "Completed",
517 ToolCallStatus::Failed => "Failed",
518 ToolCallStatus::Rejected => "Rejected",
519 ToolCallStatus::Canceled => "Canceled",
520 }
521 )
522 }
523}
524
525#[derive(Debug, PartialEq, Clone)]
526pub enum ContentBlock {
527 Empty,
528 Markdown { markdown: Entity<Markdown> },
529 ResourceLink { resource_link: acp::ResourceLink },
530 Image { image: Arc<gpui::Image> },
531}
532
533impl ContentBlock {
534 pub fn new(
535 block: acp::ContentBlock,
536 language_registry: &Arc<LanguageRegistry>,
537 path_style: PathStyle,
538 cx: &mut App,
539 ) -> Self {
540 let mut this = Self::Empty;
541 this.append(block, language_registry, path_style, cx);
542 this
543 }
544
545 pub fn new_combined(
546 blocks: impl IntoIterator<Item = acp::ContentBlock>,
547 language_registry: Arc<LanguageRegistry>,
548 path_style: PathStyle,
549 cx: &mut App,
550 ) -> Self {
551 let mut this = Self::Empty;
552 for block in blocks {
553 this.append(block, &language_registry, path_style, cx);
554 }
555 this
556 }
557
558 pub fn append(
559 &mut self,
560 block: acp::ContentBlock,
561 language_registry: &Arc<LanguageRegistry>,
562 path_style: PathStyle,
563 cx: &mut App,
564 ) {
565 match (&mut *self, &block) {
566 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
567 *self = ContentBlock::ResourceLink {
568 resource_link: resource_link.clone(),
569 };
570 }
571 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
572 if let Some(image) = Self::decode_image(image_content) {
573 *self = ContentBlock::Image { image };
574 } else {
575 let new_content = Self::image_md(image_content);
576 *self = Self::create_markdown_block(new_content, language_registry, cx);
577 }
578 }
579 (ContentBlock::Empty, _) => {
580 let new_content = Self::block_string_contents(&block, path_style);
581 *self = Self::create_markdown_block(new_content, language_registry, cx);
582 }
583 (ContentBlock::Markdown { markdown }, _) => {
584 let new_content = Self::block_string_contents(&block, path_style);
585 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
586 }
587 (ContentBlock::ResourceLink { resource_link }, _) => {
588 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
589 let new_content = Self::block_string_contents(&block, path_style);
590 let combined = format!("{}\n{}", existing_content, new_content);
591 *self = Self::create_markdown_block(combined, language_registry, cx);
592 }
593 (ContentBlock::Image { .. }, _) => {
594 let new_content = Self::block_string_contents(&block, path_style);
595 let combined = format!("`Image`\n{}", new_content);
596 *self = Self::create_markdown_block(combined, language_registry, cx);
597 }
598 }
599 }
600
601 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
602 use base64::Engine as _;
603
604 let bytes = base64::engine::general_purpose::STANDARD
605 .decode(image_content.data.as_bytes())
606 .ok()?;
607 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
608 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
609 }
610
611 fn create_markdown_block(
612 content: String,
613 language_registry: &Arc<LanguageRegistry>,
614 cx: &mut App,
615 ) -> ContentBlock {
616 ContentBlock::Markdown {
617 markdown: cx
618 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
619 }
620 }
621
622 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
623 match block {
624 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
625 acp::ContentBlock::ResourceLink(resource_link) => {
626 Self::resource_link_md(&resource_link.uri, path_style)
627 }
628 acp::ContentBlock::Resource(acp::EmbeddedResource {
629 resource:
630 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
631 uri,
632 ..
633 }),
634 ..
635 }) => Self::resource_link_md(uri, path_style),
636 acp::ContentBlock::Image(image) => Self::image_md(image),
637 _ => String::new(),
638 }
639 }
640
641 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
642 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
643 uri.as_link().to_string()
644 } else {
645 uri.to_string()
646 }
647 }
648
649 fn image_md(_image: &acp::ImageContent) -> String {
650 "`Image`".into()
651 }
652
653 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
654 match self {
655 ContentBlock::Empty => "",
656 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
657 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
658 ContentBlock::Image { .. } => "`Image`",
659 }
660 }
661
662 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
663 match self {
664 ContentBlock::Empty => None,
665 ContentBlock::Markdown { markdown } => Some(markdown),
666 ContentBlock::ResourceLink { .. } => None,
667 ContentBlock::Image { .. } => None,
668 }
669 }
670
671 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
672 match self {
673 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
674 _ => None,
675 }
676 }
677
678 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
679 match self {
680 ContentBlock::Image { image } => Some(image),
681 _ => None,
682 }
683 }
684}
685
686#[derive(Debug)]
687pub enum ToolCallContent {
688 ContentBlock(ContentBlock),
689 Diff(Entity<Diff>),
690 Terminal(Entity<Terminal>),
691 SubagentThread(Entity<AcpThread>),
692}
693
694impl ToolCallContent {
695 pub fn from_acp(
696 content: acp::ToolCallContent,
697 language_registry: Arc<LanguageRegistry>,
698 path_style: PathStyle,
699 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
700 cx: &mut App,
701 ) -> Result<Option<Self>> {
702 match content {
703 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
704 Ok(Some(Self::ContentBlock(ContentBlock::new(
705 content,
706 &language_registry,
707 path_style,
708 cx,
709 ))))
710 }
711 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
712 Diff::finalized(
713 diff.path.to_string_lossy().into_owned(),
714 diff.old_text,
715 diff.new_text,
716 language_registry,
717 cx,
718 )
719 })))),
720 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
721 .get(&terminal_id)
722 .cloned()
723 .map(|terminal| Some(Self::Terminal(terminal)))
724 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
725 _ => Ok(None),
726 }
727 }
728
729 pub fn update_from_acp(
730 &mut self,
731 new: acp::ToolCallContent,
732 language_registry: Arc<LanguageRegistry>,
733 path_style: PathStyle,
734 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
735 cx: &mut App,
736 ) -> Result<bool> {
737 let needs_update = match (&self, &new) {
738 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
739 old_diff.read(cx).needs_update(
740 new_diff.old_text.as_deref().unwrap_or(""),
741 &new_diff.new_text,
742 cx,
743 )
744 }
745 _ => true,
746 };
747
748 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
749 if needs_update {
750 *self = update;
751 }
752 Ok(true)
753 } else {
754 Ok(false)
755 }
756 }
757
758 pub fn to_markdown(&self, cx: &App) -> String {
759 match self {
760 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
761 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
762 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
763 Self::SubagentThread(thread) => thread.read(cx).to_markdown(cx),
764 }
765 }
766
767 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
768 match self {
769 Self::ContentBlock(content) => content.image(),
770 _ => None,
771 }
772 }
773
774 pub fn subagent_thread(&self) -> Option<&Entity<AcpThread>> {
775 match self {
776 Self::SubagentThread(thread) => Some(thread),
777 _ => None,
778 }
779 }
780}
781
782#[derive(Debug, PartialEq)]
783pub enum ToolCallUpdate {
784 UpdateFields(acp::ToolCallUpdate),
785 UpdateDiff(ToolCallUpdateDiff),
786 UpdateTerminal(ToolCallUpdateTerminal),
787 UpdateSubagentThread(ToolCallUpdateSubagentThread),
788}
789
790impl ToolCallUpdate {
791 fn id(&self) -> &acp::ToolCallId {
792 match self {
793 Self::UpdateFields(update) => &update.tool_call_id,
794 Self::UpdateDiff(diff) => &diff.id,
795 Self::UpdateTerminal(terminal) => &terminal.id,
796 Self::UpdateSubagentThread(subagent) => &subagent.id,
797 }
798 }
799}
800
801impl From<acp::ToolCallUpdate> for ToolCallUpdate {
802 fn from(update: acp::ToolCallUpdate) -> Self {
803 Self::UpdateFields(update)
804 }
805}
806
807impl From<ToolCallUpdateDiff> for ToolCallUpdate {
808 fn from(diff: ToolCallUpdateDiff) -> Self {
809 Self::UpdateDiff(diff)
810 }
811}
812
813#[derive(Debug, PartialEq)]
814pub struct ToolCallUpdateDiff {
815 pub id: acp::ToolCallId,
816 pub diff: Entity<Diff>,
817}
818
819impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
820 fn from(terminal: ToolCallUpdateTerminal) -> Self {
821 Self::UpdateTerminal(terminal)
822 }
823}
824
825#[derive(Debug, PartialEq)]
826pub struct ToolCallUpdateTerminal {
827 pub id: acp::ToolCallId,
828 pub terminal: Entity<Terminal>,
829}
830
831impl From<ToolCallUpdateSubagentThread> for ToolCallUpdate {
832 fn from(subagent: ToolCallUpdateSubagentThread) -> Self {
833 Self::UpdateSubagentThread(subagent)
834 }
835}
836
837#[derive(Debug, PartialEq)]
838pub struct ToolCallUpdateSubagentThread {
839 pub id: acp::ToolCallId,
840 pub thread: Entity<AcpThread>,
841}
842
843#[derive(Debug, Default)]
844pub struct Plan {
845 pub entries: Vec<PlanEntry>,
846}
847
848#[derive(Debug)]
849pub struct PlanStats<'a> {
850 pub in_progress_entry: Option<&'a PlanEntry>,
851 pub pending: u32,
852 pub completed: u32,
853}
854
855impl Plan {
856 pub fn is_empty(&self) -> bool {
857 self.entries.is_empty()
858 }
859
860 pub fn stats(&self) -> PlanStats<'_> {
861 let mut stats = PlanStats {
862 in_progress_entry: None,
863 pending: 0,
864 completed: 0,
865 };
866
867 for entry in &self.entries {
868 match &entry.status {
869 acp::PlanEntryStatus::Pending => {
870 stats.pending += 1;
871 }
872 acp::PlanEntryStatus::InProgress => {
873 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
874 }
875 acp::PlanEntryStatus::Completed => {
876 stats.completed += 1;
877 }
878 _ => {}
879 }
880 }
881
882 stats
883 }
884}
885
886#[derive(Debug)]
887pub struct PlanEntry {
888 pub content: Entity<Markdown>,
889 pub priority: acp::PlanEntryPriority,
890 pub status: acp::PlanEntryStatus,
891}
892
893impl PlanEntry {
894 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
895 Self {
896 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
897 priority: entry.priority,
898 status: entry.status,
899 }
900 }
901}
902
903#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
904pub struct TokenUsage {
905 pub max_tokens: u64,
906 pub used_tokens: u64,
907 pub input_tokens: u64,
908 pub output_tokens: u64,
909}
910
911impl TokenUsage {
912 pub fn ratio(&self) -> TokenUsageRatio {
913 #[cfg(debug_assertions)]
914 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
915 .unwrap_or("0.8".to_string())
916 .parse()
917 .unwrap();
918 #[cfg(not(debug_assertions))]
919 let warning_threshold: f32 = 0.8;
920
921 // When the maximum is unknown because there is no selected model,
922 // avoid showing the token limit warning.
923 if self.max_tokens == 0 {
924 TokenUsageRatio::Normal
925 } else if self.used_tokens >= self.max_tokens {
926 TokenUsageRatio::Exceeded
927 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
928 TokenUsageRatio::Warning
929 } else {
930 TokenUsageRatio::Normal
931 }
932 }
933}
934
935#[derive(Debug, Clone, PartialEq, Eq)]
936pub enum TokenUsageRatio {
937 Normal,
938 Warning,
939 Exceeded,
940}
941
942#[derive(Debug, Clone)]
943pub struct RetryStatus {
944 pub last_error: SharedString,
945 pub attempt: usize,
946 pub max_attempts: usize,
947 pub started_at: Instant,
948 pub duration: Duration,
949}
950
951pub struct AcpThread {
952 title: SharedString,
953 entries: Vec<AgentThreadEntry>,
954 plan: Plan,
955 project: Entity<Project>,
956 action_log: Entity<ActionLog>,
957 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
958 send_task: Option<Task<()>>,
959 connection: Rc<dyn AgentConnection>,
960 session_id: acp::SessionId,
961 token_usage: Option<TokenUsage>,
962 prompt_capabilities: acp::PromptCapabilities,
963 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
964 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
965 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
966 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
967}
968
969impl From<&AcpThread> for ActionLogTelemetry {
970 fn from(value: &AcpThread) -> Self {
971 Self {
972 agent_telemetry_id: value.connection().telemetry_id(),
973 session_id: value.session_id.0.clone(),
974 }
975 }
976}
977
978#[derive(Debug)]
979pub enum AcpThreadEvent {
980 NewEntry,
981 TitleUpdated,
982 TokenUsageUpdated,
983 EntryUpdated(usize),
984 EntriesRemoved(Range<usize>),
985 ToolAuthorizationRequired,
986 Retry(RetryStatus),
987 Stopped,
988 Error,
989 LoadError(LoadError),
990 PromptCapabilitiesUpdated,
991 Refusal,
992 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
993 ModeUpdated(acp::SessionModeId),
994 ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
995}
996
997impl EventEmitter<AcpThreadEvent> for AcpThread {}
998
999#[derive(Debug, Clone)]
1000pub enum TerminalProviderEvent {
1001 Created {
1002 terminal_id: acp::TerminalId,
1003 label: String,
1004 cwd: Option<PathBuf>,
1005 output_byte_limit: Option<u64>,
1006 terminal: Entity<::terminal::Terminal>,
1007 },
1008 Output {
1009 terminal_id: acp::TerminalId,
1010 data: Vec<u8>,
1011 },
1012 TitleChanged {
1013 terminal_id: acp::TerminalId,
1014 title: String,
1015 },
1016 Exit {
1017 terminal_id: acp::TerminalId,
1018 status: acp::TerminalExitStatus,
1019 },
1020}
1021
1022#[derive(Debug, Clone)]
1023pub enum TerminalProviderCommand {
1024 WriteInput {
1025 terminal_id: acp::TerminalId,
1026 bytes: Vec<u8>,
1027 },
1028 Resize {
1029 terminal_id: acp::TerminalId,
1030 cols: u16,
1031 rows: u16,
1032 },
1033 Close {
1034 terminal_id: acp::TerminalId,
1035 },
1036}
1037
1038impl AcpThread {
1039 pub fn on_terminal_provider_event(
1040 &mut self,
1041 event: TerminalProviderEvent,
1042 cx: &mut Context<Self>,
1043 ) {
1044 match event {
1045 TerminalProviderEvent::Created {
1046 terminal_id,
1047 label,
1048 cwd,
1049 output_byte_limit,
1050 terminal,
1051 } => {
1052 let entity = self.register_terminal_created(
1053 terminal_id.clone(),
1054 label,
1055 cwd,
1056 output_byte_limit,
1057 terminal,
1058 cx,
1059 );
1060
1061 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
1062 for data in chunks.drain(..) {
1063 entity.update(cx, |term, cx| {
1064 term.inner().update(cx, |inner, cx| {
1065 inner.write_output(&data, cx);
1066 })
1067 });
1068 }
1069 }
1070
1071 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1072 entity.update(cx, |_term, cx| {
1073 cx.notify();
1074 });
1075 }
1076
1077 cx.notify();
1078 }
1079 TerminalProviderEvent::Output { terminal_id, data } => {
1080 if let Some(entity) = self.terminals.get(&terminal_id) {
1081 entity.update(cx, |term, cx| {
1082 term.inner().update(cx, |inner, cx| {
1083 inner.write_output(&data, cx);
1084 })
1085 });
1086 } else {
1087 self.pending_terminal_output
1088 .entry(terminal_id)
1089 .or_default()
1090 .push(data);
1091 }
1092 }
1093 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1094 if let Some(entity) = self.terminals.get(&terminal_id) {
1095 entity.update(cx, |term, cx| {
1096 term.inner().update(cx, |inner, cx| {
1097 inner.breadcrumb_text = title;
1098 cx.emit(::terminal::Event::BreadcrumbsChanged);
1099 })
1100 });
1101 }
1102 }
1103 TerminalProviderEvent::Exit {
1104 terminal_id,
1105 status,
1106 } => {
1107 if let Some(entity) = self.terminals.get(&terminal_id) {
1108 entity.update(cx, |_term, cx| {
1109 cx.notify();
1110 });
1111 } else {
1112 self.pending_terminal_exit.insert(terminal_id, status);
1113 }
1114 }
1115 }
1116 }
1117}
1118
/// Whether a thread is currently generating a response.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in flight.
    Generating,
}
1124
1125#[derive(Debug, Clone)]
1126pub enum LoadError {
1127 Unsupported {
1128 command: SharedString,
1129 current_version: SharedString,
1130 minimum_version: SharedString,
1131 },
1132 FailedToInstall(SharedString),
1133 Exited {
1134 status: ExitStatus,
1135 },
1136 Other(SharedString),
1137}
1138
1139impl Display for LoadError {
1140 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1141 match self {
1142 LoadError::Unsupported {
1143 command: path,
1144 current_version,
1145 minimum_version,
1146 } => {
1147 write!(
1148 f,
1149 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1150 )
1151 }
1152 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1153 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1154 LoadError::Other(msg) => write!(f, "{msg}"),
1155 }
1156 }
1157}
1158
1159impl Error for LoadError {}
1160
1161impl AcpThread {
1162 pub fn new(
1163 title: impl Into<SharedString>,
1164 connection: Rc<dyn AgentConnection>,
1165 project: Entity<Project>,
1166 action_log: Entity<ActionLog>,
1167 session_id: acp::SessionId,
1168 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1169 cx: &mut Context<Self>,
1170 ) -> Self {
1171 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1172 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1173 loop {
1174 let caps = prompt_capabilities_rx.recv().await?;
1175 this.update(cx, |this, cx| {
1176 this.prompt_capabilities = caps;
1177 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1178 })?;
1179 }
1180 });
1181
1182 Self {
1183 action_log,
1184 shared_buffers: Default::default(),
1185 entries: Default::default(),
1186 plan: Default::default(),
1187 title: title.into(),
1188 project,
1189 send_task: None,
1190 connection,
1191 session_id,
1192 token_usage: None,
1193 prompt_capabilities,
1194 _observe_prompt_capabilities: task,
1195 terminals: HashMap::default(),
1196 pending_terminal_output: HashMap::default(),
1197 pending_terminal_exit: HashMap::default(),
1198 }
1199 }
1200
1201 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1202 self.prompt_capabilities.clone()
1203 }
1204
1205 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1206 &self.connection
1207 }
1208
1209 pub fn action_log(&self) -> &Entity<ActionLog> {
1210 &self.action_log
1211 }
1212
1213 pub fn project(&self) -> &Entity<Project> {
1214 &self.project
1215 }
1216
1217 pub fn title(&self) -> SharedString {
1218 self.title.clone()
1219 }
1220
1221 pub fn entries(&self) -> &[AgentThreadEntry] {
1222 &self.entries
1223 }
1224
1225 pub fn session_id(&self) -> &acp::SessionId {
1226 &self.session_id
1227 }
1228
1229 pub fn status(&self) -> ThreadStatus {
1230 if self.send_task.is_some() {
1231 ThreadStatus::Generating
1232 } else {
1233 ThreadStatus::Idle
1234 }
1235 }
1236
1237 pub fn token_usage(&self) -> Option<&TokenUsage> {
1238 self.token_usage.as_ref()
1239 }
1240
1241 pub fn has_pending_edit_tool_calls(&self) -> bool {
1242 for entry in self.entries.iter().rev() {
1243 match entry {
1244 AgentThreadEntry::UserMessage(_) => return false,
1245 AgentThreadEntry::ToolCall(
1246 call @ ToolCall {
1247 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1248 ..
1249 },
1250 ) if call.diffs().next().is_some() => {
1251 return true;
1252 }
1253 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1254 }
1255 }
1256
1257 false
1258 }
1259
1260 pub fn has_in_progress_tool_calls(&self) -> bool {
1261 for entry in self.entries.iter().rev() {
1262 match entry {
1263 AgentThreadEntry::UserMessage(_) => return false,
1264 AgentThreadEntry::ToolCall(ToolCall {
1265 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1266 ..
1267 }) => {
1268 return true;
1269 }
1270 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1271 }
1272 }
1273
1274 false
1275 }
1276
1277 pub fn used_tools_since_last_user_message(&self) -> bool {
1278 for entry in self.entries.iter().rev() {
1279 match entry {
1280 AgentThreadEntry::UserMessage(..) => return false,
1281 AgentThreadEntry::AssistantMessage(..) => continue,
1282 AgentThreadEntry::ToolCall(..) => return true,
1283 }
1284 }
1285
1286 false
1287 }
1288
1289 pub fn handle_session_update(
1290 &mut self,
1291 update: acp::SessionUpdate,
1292 cx: &mut Context<Self>,
1293 ) -> Result<(), acp::Error> {
1294 match update {
1295 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1296 self.push_user_content_block(None, content, cx);
1297 }
1298 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1299 self.push_assistant_content_block(content, false, cx);
1300 }
1301 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1302 self.push_assistant_content_block(content, true, cx);
1303 }
1304 acp::SessionUpdate::ToolCall(tool_call) => {
1305 self.upsert_tool_call(tool_call, cx)?;
1306 }
1307 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1308 self.update_tool_call(tool_call_update, cx)?;
1309 }
1310 acp::SessionUpdate::Plan(plan) => {
1311 self.update_plan(plan, cx);
1312 }
1313 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1314 available_commands,
1315 ..
1316 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1317 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1318 current_mode_id,
1319 ..
1320 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1321 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1322 config_options,
1323 ..
1324 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1325 _ => {}
1326 }
1327 Ok(())
1328 }
1329
1330 pub fn push_user_content_block(
1331 &mut self,
1332 message_id: Option<UserMessageId>,
1333 chunk: acp::ContentBlock,
1334 cx: &mut Context<Self>,
1335 ) {
1336 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1337 }
1338
    /// Appends a user content chunk to the thread, merging it into the trailing user message
    /// when possible.
    ///
    /// If the last entry is a `UserMessage` whose `indented` flag equals `indented`, the chunk is
    /// appended to that message's rendered `content` and to its raw `chunks`, and
    /// `AcpThreadEvent::EntryUpdated` is emitted for that entry. Otherwise a new `UserMessage`
    /// entry (with no checkpoint) is pushed via `push_entry`.
    ///
    /// When merging, a `Some(message_id)` replaces the message's id. `None` keeps the existing id.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read the length up front: `last_mut()` below holds a mutable borrow of `self.entries`.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Prefer the incoming id; fall back to the one already on the message.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1379
1380 pub fn push_assistant_content_block(
1381 &mut self,
1382 chunk: acp::ContentBlock,
1383 is_thought: bool,
1384 cx: &mut Context<Self>,
1385 ) {
1386 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1387 }
1388
    /// Appends an assistant chunk, either visible message text or a "thought" when `is_thought`
    /// is true.
    ///
    /// If the last entry is an `AssistantMessage` with the same `indented` flag, the chunk is
    /// merged into it. When that message's final chunk has the same kind (message vs. thought),
    /// the text is appended to it; otherwise a new chunk of the right kind is added. In either
    /// case `AcpThreadEvent::EntryUpdated` is emitted. If the last entry is not such a message,
    /// a new `AssistantMessage` entry holding just this chunk is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read the length up front: `last_mut()` below holds a mutable borrow of `self.entries`.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): the event is emitted before the chunk is appended below; presumably
            // fine because listeners observe the entity after this update completes — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1439
1440 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1441 self.entries.push(entry);
1442 cx.emit(AcpThreadEvent::NewEntry);
1443 }
1444
1445 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1446 self.connection.set_title(&self.session_id, cx).is_some()
1447 }
1448
1449 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1450 if title != self.title {
1451 self.title = title.clone();
1452 cx.emit(AcpThreadEvent::TitleUpdated);
1453 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1454 return set_title.run(title, cx);
1455 }
1456 }
1457 Task::ready(Ok(()))
1458 }
1459
1460 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1461 self.token_usage = usage;
1462 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1463 }
1464
1465 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1466 cx.emit(AcpThreadEvent::Retry(status));
1467 }
1468
1469 pub fn update_tool_call(
1470 &mut self,
1471 update: impl Into<ToolCallUpdate>,
1472 cx: &mut Context<Self>,
1473 ) -> Result<()> {
1474 let update = update.into();
1475 let languages = self.project.read(cx).languages().clone();
1476 let path_style = self.project.read(cx).path_style(cx);
1477
1478 let ix = match self.index_for_tool_call(update.id()) {
1479 Some(ix) => ix,
1480 None => {
1481 // Tool call not found - create a failed tool call entry
1482 let failed_tool_call = ToolCall {
1483 id: update.id().clone(),
1484 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1485 kind: acp::ToolKind::Fetch,
1486 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1487 "Tool call not found".into(),
1488 &languages,
1489 path_style,
1490 cx,
1491 ))],
1492 status: ToolCallStatus::Failed,
1493 locations: Vec::new(),
1494 resolved_locations: Vec::new(),
1495 raw_input: None,
1496 raw_input_markdown: None,
1497 raw_output: None,
1498 tool_name: None,
1499 };
1500 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1501 return Ok(());
1502 }
1503 };
1504 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1505 unreachable!()
1506 };
1507
1508 match update {
1509 ToolCallUpdate::UpdateFields(update) => {
1510 let location_updated = update.fields.locations.is_some();
1511 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1512 if location_updated {
1513 self.resolve_locations(update.tool_call_id, cx);
1514 }
1515 }
1516 ToolCallUpdate::UpdateDiff(update) => {
1517 call.content.clear();
1518 call.content.push(ToolCallContent::Diff(update.diff));
1519 }
1520 ToolCallUpdate::UpdateTerminal(update) => {
1521 call.content.clear();
1522 call.content
1523 .push(ToolCallContent::Terminal(update.terminal));
1524 }
1525 ToolCallUpdate::UpdateSubagentThread(update) => {
1526 debug_assert!(
1527 !call.content.iter().any(|c| {
1528 matches!(c, ToolCallContent::SubagentThread(existing) if existing == &update.thread)
1529 }),
1530 "Duplicate SubagentThread update for the same AcpThread entity"
1531 );
1532 call.content
1533 .push(ToolCallContent::SubagentThread(update.thread));
1534 }
1535 }
1536
1537 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1538
1539 Ok(())
1540 }
1541
1542 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1543 pub fn upsert_tool_call(
1544 &mut self,
1545 tool_call: acp::ToolCall,
1546 cx: &mut Context<Self>,
1547 ) -> Result<(), acp::Error> {
1548 let status = tool_call.status.into();
1549 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1550 }
1551
1552 /// Fails if id does not match an existing entry.
1553 pub fn upsert_tool_call_inner(
1554 &mut self,
1555 update: acp::ToolCallUpdate,
1556 status: ToolCallStatus,
1557 cx: &mut Context<Self>,
1558 ) -> Result<(), acp::Error> {
1559 let language_registry = self.project.read(cx).languages().clone();
1560 let path_style = self.project.read(cx).path_style(cx);
1561 let id = update.tool_call_id.clone();
1562
1563 let agent_telemetry_id = self.connection().telemetry_id();
1564 let session = self.session_id();
1565 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1566 let status = if matches!(status, ToolCallStatus::Completed) {
1567 "completed"
1568 } else {
1569 "failed"
1570 };
1571 telemetry::event!(
1572 "Agent Tool Call Completed",
1573 agent_telemetry_id,
1574 session,
1575 status
1576 );
1577 }
1578
1579 if let Some(ix) = self.index_for_tool_call(&id) {
1580 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1581 unreachable!()
1582 };
1583
1584 call.update_fields(
1585 update.fields,
1586 language_registry,
1587 path_style,
1588 &self.terminals,
1589 cx,
1590 )?;
1591 call.status = status;
1592
1593 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1594 } else {
1595 let call = ToolCall::from_acp(
1596 update.try_into()?,
1597 status,
1598 language_registry,
1599 self.project.read(cx).path_style(cx),
1600 &self.terminals,
1601 cx,
1602 )?;
1603 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1604 };
1605
1606 self.resolve_locations(id, cx);
1607 Ok(())
1608 }
1609
1610 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1611 self.entries
1612 .iter()
1613 .enumerate()
1614 .rev()
1615 .find_map(|(index, entry)| {
1616 if let AgentThreadEntry::ToolCall(tool_call) = entry
1617 && &tool_call.id == id
1618 {
1619 Some(index)
1620 } else {
1621 None
1622 }
1623 })
1624 }
1625
1626 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1627 // The tool call we are looking for is typically the last one, or very close to the end.
1628 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1629 self.entries
1630 .iter_mut()
1631 .enumerate()
1632 .rev()
1633 .find_map(|(index, tool_call)| {
1634 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1635 && &tool_call.id == id
1636 {
1637 Some((index, tool_call))
1638 } else {
1639 None
1640 }
1641 })
1642 }
1643
1644 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1645 self.entries
1646 .iter()
1647 .enumerate()
1648 .rev()
1649 .find_map(|(index, tool_call)| {
1650 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1651 && &tool_call.id == id
1652 {
1653 Some((index, tool_call))
1654 } else {
1655 None
1656 }
1657 })
1658 }
1659
    /// Resolves the locations of tool call `id` in the background and applies the result.
    ///
    /// Does nothing if no tool call with `id` exists. When the resolution task from
    /// `ToolCall::resolve_locations` completes, this:
    /// - records a snapshot of every resolved buffer in `shared_buffers`;
    /// - moves the project's agent location to the last resolved location, unless that would
    ///   only move it backwards within the same row of the same buffer;
    /// - stores the resolved locations on the tool call and emits `EntryUpdated` if they changed.
    ///
    /// The spawned task is detached, so its outcome is not reported to the caller.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache a snapshot of each buffer that a location resolved into.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Look the tool call up again: entries may have changed (e.g. been truncated by
                // `rewind`) while resolution was in flight.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify listeners when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1716
1717 pub fn request_tool_call_authorization(
1718 &mut self,
1719 tool_call: acp::ToolCallUpdate,
1720 options: Vec<acp::PermissionOption>,
1721 respect_always_allow_setting: bool,
1722 cx: &mut Context<Self>,
1723 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1724 let (tx, rx) = oneshot::channel();
1725
1726 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1727 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1728 // some tools would (incorrectly) continue to auto-accept.
1729 if let Some(allow_once_option) = options.iter().find_map(|option| {
1730 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1731 Some(option.option_id.clone())
1732 } else {
1733 None
1734 }
1735 }) {
1736 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1737 return Ok(async {
1738 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1739 allow_once_option,
1740 ))
1741 }
1742 .boxed());
1743 }
1744 }
1745
1746 let status = ToolCallStatus::WaitingForConfirmation {
1747 options,
1748 respond_tx: tx,
1749 };
1750
1751 self.upsert_tool_call_inner(tool_call, status, cx)?;
1752 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1753
1754 let fut = async {
1755 match rx.await {
1756 Ok(option) => acp::RequestPermissionOutcome::Selected(
1757 acp::SelectedPermissionOutcome::new(option),
1758 ),
1759 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1760 }
1761 }
1762 .boxed();
1763
1764 Ok(fut)
1765 }
1766
1767 pub fn authorize_tool_call(
1768 &mut self,
1769 id: acp::ToolCallId,
1770 option_id: acp::PermissionOptionId,
1771 option_kind: acp::PermissionOptionKind,
1772 cx: &mut Context<Self>,
1773 ) {
1774 let Some((ix, call)) = self.tool_call_mut(&id) else {
1775 return;
1776 };
1777
1778 let new_status = match option_kind {
1779 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1780 ToolCallStatus::Rejected
1781 }
1782 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1783 ToolCallStatus::InProgress
1784 }
1785 _ => ToolCallStatus::InProgress,
1786 };
1787
1788 let curr_status = mem::replace(&mut call.status, new_status);
1789
1790 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1791 respond_tx.send(option_id).log_err();
1792 } else if cfg!(debug_assertions) {
1793 panic!("tried to authorize an already authorized tool call");
1794 }
1795
1796 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1797 }
1798
1799 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1800 let mut first_tool_call = None;
1801
1802 for entry in self.entries.iter().rev() {
1803 match &entry {
1804 AgentThreadEntry::ToolCall(call) => {
1805 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1806 first_tool_call = Some(call);
1807 } else {
1808 continue;
1809 }
1810 }
1811 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1812 // Reached the beginning of the turn.
1813 // If we had pending permission requests in the previous turn, they have been cancelled.
1814 break;
1815 }
1816 }
1817 }
1818
1819 first_tool_call
1820 }
1821
1822 pub fn plan(&self) -> &Plan {
1823 &self.plan
1824 }
1825
1826 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1827 let new_entries_len = request.entries.len();
1828 let mut new_entries = request.entries.into_iter();
1829
1830 // Reuse existing markdown to prevent flickering
1831 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1832 let PlanEntry {
1833 content,
1834 priority,
1835 status,
1836 } = old;
1837 content.update(cx, |old, cx| {
1838 old.replace(new.content, cx);
1839 });
1840 *priority = new.priority;
1841 *status = new.status;
1842 }
1843 for new in new_entries {
1844 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1845 }
1846 self.plan.entries.truncate(new_entries_len);
1847
1848 cx.notify();
1849 }
1850
1851 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1852 self.plan
1853 .entries
1854 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1855 cx.notify();
1856 }
1857
1858 #[cfg(any(test, feature = "test-support"))]
1859 pub fn send_raw(
1860 &mut self,
1861 message: &str,
1862 cx: &mut Context<Self>,
1863 ) -> BoxFuture<'static, Result<()>> {
1864 self.send(vec![message.into()], cx)
1865 }
1866
    /// Sends a user prompt made of the `message` blocks to the agent as a new turn.
    ///
    /// Inside the turn, the message is pushed as a new `UserMessage` entry. A git checkpoint is
    /// then taken and attached to the last user message with `show: false`, and finally the
    /// prompt is forwarded to the connection.
    ///
    /// A `UserMessageId` is assigned only when the connection supports truncation. Messages are
    /// looked up by id in `rewind` and `restore_checkpoint`, which depend on truncation.
    ///
    /// The returned future resolves when the turn ends (see `run_turn`).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // A checkpoint failure is logged and leaves the message without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden for now; `update_last_checkpoint` decides whether to show it.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1919
1920 pub fn can_resume(&self, cx: &App) -> bool {
1921 self.connection.resume(&self.session_id, cx).is_some()
1922 }
1923
1924 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1925 self.run_turn(cx, async move |this, cx| {
1926 this.update(cx, |this, cx| {
1927 this.connection
1928 .resume(&this.session_id, cx)
1929 .map(|resume| resume.run(cx))
1930 })?
1931 .context("resuming a session is not supported")?
1932 .await
1933 })
1934 }
1935
    /// Runs one agent turn produced by `f` and returns a future that resolves when it ends.
    ///
    /// Before starting, completed plan entries are cleared and any in-flight turn is cancelled.
    /// `f` only runs once that cancellation has finished. The spawned work is stored in
    /// `send_task`, which makes `status()` report `Generating`.
    ///
    /// Once `f`'s result arrives (or its channel closes), the last user message's checkpoint is
    /// refreshed and the project's agent location is cleared. Then:
    /// - if `f` returned an error, `send_task` is cleared, `AcpThreadEvent::Error` is emitted,
    ///   and the error is returned;
    /// - otherwise `send_task` is cleared unless the stop reason was `Cancelled`, refusals are
    ///   handled (see below), `AcpThreadEvent::Stopped` is emitted, and `Ok(())` is returned.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Wait for the previous turn to wind down before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
2031
2032 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
2033 let Some(send_task) = self.send_task.take() else {
2034 return Task::ready(());
2035 };
2036
2037 for entry in self.entries.iter_mut() {
2038 if let AgentThreadEntry::ToolCall(call) = entry {
2039 let cancel = matches!(
2040 call.status,
2041 ToolCallStatus::Pending
2042 | ToolCallStatus::WaitingForConfirmation { .. }
2043 | ToolCallStatus::InProgress
2044 );
2045
2046 if cancel {
2047 call.status = ToolCallStatus::Canceled;
2048 }
2049 }
2050 }
2051
2052 self.connection.cancel(&self.session_id, cx);
2053
2054 // Wait for the send task to complete
2055 cx.foreground_executor().spawn(send_task)
2056 }
2057
2058 /// Restores the git working tree to the state at the given checkpoint (if one exists)
2059 pub fn restore_checkpoint(
2060 &mut self,
2061 id: UserMessageId,
2062 cx: &mut Context<Self>,
2063 ) -> Task<Result<()>> {
2064 let Some((_, message)) = self.user_message_mut(&id) else {
2065 return Task::ready(Err(anyhow!("message not found")));
2066 };
2067
2068 let checkpoint = message
2069 .checkpoint
2070 .as_ref()
2071 .map(|c| c.git_checkpoint.clone());
2072
2073 // Cancel any in-progress generation before restoring
2074 let cancel_task = self.cancel(cx);
2075 let rewind = self.rewind(id.clone(), cx);
2076 let git_store = self.project.read(cx).git_store().clone();
2077
2078 cx.spawn(async move |_, cx| {
2079 cancel_task.await;
2080 rewind.await?;
2081 if let Some(checkpoint) = checkpoint {
2082 git_store
2083 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2084 .await?;
2085 }
2086
2087 Ok(())
2088 })
2089 }
2090
2091 /// Rewinds this thread to before the entry at `index`, removing it and all
2092 /// subsequent entries while rejecting any action_log changes made from that point.
2093 /// Unlike `restore_checkpoint`, this method does not restore from git.
2094 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2095 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2096 return Task::ready(Err(anyhow!("not supported")));
2097 };
2098
2099 let telemetry = ActionLogTelemetry::from(&*self);
2100 cx.spawn(async move |this, cx| {
2101 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2102 this.update(cx, |this, cx| {
2103 if let Some((ix, _)) = this.user_message_mut(&id) {
2104 // Collect all terminals from entries that will be removed
2105 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2106 .iter()
2107 .flat_map(|entry| entry.terminals())
2108 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2109 .collect();
2110
2111 let range = ix..this.entries.len();
2112 this.entries.truncate(ix);
2113 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2114
2115 // Kill and remove the terminals
2116 for terminal_id in terminals_to_remove {
2117 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2118 terminal.update(cx, |terminal, cx| {
2119 terminal.kill(cx);
2120 });
2121 }
2122 }
2123 }
2124 this.action_log().update(cx, |action_log, cx| {
2125 action_log.reject_all_edits(Some(telemetry), cx)
2126 })
2127 })?
2128 .await;
2129 Ok(())
2130 })
2131 }
2132
2133 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2134 let git_store = self.project.read(cx).git_store().clone();
2135
2136 let Some((_, message)) = self.last_user_message() else {
2137 return Task::ready(Ok(()));
2138 };
2139 let Some(user_message_id) = message.id.clone() else {
2140 return Task::ready(Ok(()));
2141 };
2142 let Some(checkpoint) = message.checkpoint.as_ref() else {
2143 return Task::ready(Ok(()));
2144 };
2145 let old_checkpoint = checkpoint.git_checkpoint.clone();
2146
2147 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2148 cx.spawn(async move |this, cx| {
2149 let Some(new_checkpoint) = new_checkpoint
2150 .await
2151 .context("failed to get new checkpoint")
2152 .log_err()
2153 else {
2154 return Ok(());
2155 };
2156
2157 let equal = git_store
2158 .update(cx, |git, cx| {
2159 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2160 })
2161 .await
2162 .unwrap_or(true);
2163
2164 this.update(cx, |this, cx| {
2165 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2166 if let Some(checkpoint) = message.checkpoint.as_mut() {
2167 checkpoint.show = !equal;
2168 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2169 }
2170 }
2171 })?;
2172
2173 Ok(())
2174 })
2175 }
2176
2177 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2178 self.entries
2179 .iter_mut()
2180 .enumerate()
2181 .rev()
2182 .find_map(|(ix, entry)| {
2183 if let AgentThreadEntry::UserMessage(message) = entry {
2184 Some((ix, message))
2185 } else {
2186 None
2187 }
2188 })
2189 }
2190
2191 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2192 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2193 if let AgentThreadEntry::UserMessage(message) = entry {
2194 if message.id.as_ref() == Some(id) {
2195 Some((ix, message))
2196 } else {
2197 None
2198 }
2199 } else {
2200 None
2201 }
2202 })
2203 }
2204
2205 pub fn read_text_file(
2206 &self,
2207 path: PathBuf,
2208 line: Option<u32>,
2209 limit: Option<u32>,
2210 reuse_shared_snapshot: bool,
2211 cx: &mut Context<Self>,
2212 ) -> Task<Result<String, acp::Error>> {
2213 // Args are 1-based, move to 0-based
2214 let line = line.unwrap_or_default().saturating_sub(1);
2215 let limit = limit.unwrap_or(u32::MAX);
2216 let project = self.project.clone();
2217 let action_log = self.action_log.clone();
2218 cx.spawn(async move |this, cx| {
2219 let load = project.update(cx, |project, cx| {
2220 let path = project
2221 .project_path_for_absolute_path(&path, cx)
2222 .ok_or_else(|| {
2223 acp::Error::resource_not_found(Some(path.display().to_string()))
2224 })?;
2225 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2226 })?;
2227
2228 let buffer = load.await?;
2229
2230 let snapshot = if reuse_shared_snapshot {
2231 this.read_with(cx, |this, _| {
2232 this.shared_buffers.get(&buffer.clone()).cloned()
2233 })
2234 .log_err()
2235 .flatten()
2236 } else {
2237 None
2238 };
2239
2240 let snapshot = if let Some(snapshot) = snapshot {
2241 snapshot
2242 } else {
2243 action_log.update(cx, |action_log, cx| {
2244 action_log.buffer_read(buffer.clone(), cx);
2245 });
2246
2247 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2248 this.update(cx, |this, _| {
2249 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2250 })?;
2251 snapshot
2252 };
2253
2254 let max_point = snapshot.max_point();
2255 let start_position = Point::new(line, 0);
2256
2257 if start_position > max_point {
2258 return Err(acp::Error::invalid_params().data(format!(
2259 "Attempting to read beyond the end of the file, line {}:{}",
2260 max_point.row + 1,
2261 max_point.column
2262 )));
2263 }
2264
2265 let start = snapshot.anchor_before(start_position);
2266 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2267
2268 project.update(cx, |project, cx| {
2269 project.set_agent_location(
2270 Some(AgentLocation {
2271 buffer: buffer.downgrade(),
2272 position: start,
2273 }),
2274 cx,
2275 );
2276 });
2277
2278 Ok(snapshot.text_for_range(start..end).collect::<String>())
2279 })
2280 }
2281
    /// Replaces the contents of the file at the absolute `path` with `content`, then
    /// saves it.
    ///
    /// The edit is a text diff between a base snapshot and `content`. The base is the
    /// snapshot recorded in `shared_buffers` (by `read_text_file`) if there is one;
    /// otherwise it is the buffer's current snapshot. The diff hunks are converted
    /// into anchors in that base snapshot before they are applied. As a result, edits
    /// made to the buffer after the base snapshot was taken are generally preserved
    /// rather than overwritten.
    ///
    /// Side effects, in order:
    ///
    /// 1. The agent location moves to the end of the last hunk, or to the start of
    ///    the buffer when nothing changed.
    /// 2. The action log records a read, the edit is applied, and the action log
    ///    records an edit.
    /// 3. If the buffer's language settings enable format-on-save, the buffer is
    ///    formatted and the edit is recorded again. Format errors are logged and
    ///    ignored.
    /// 4. The buffer is saved.
    ///
    /// NOTE(review): `shared_buffers` is not refreshed after the write. A second write
    /// without an intervening `read_text_file` therefore diffs against the pre-write
    /// snapshot. Confirm that callers always re-read between writes.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" if `path` is not inside the project. Also propagates
    /// errors from opening the buffer, updating a released thread, or saving.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against what the agent last read, if recorded; otherwise against
            // the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // The diff runs off the main thread. Each hunk becomes an inward-biased
            // anchor range (after the start, before the end), presumably so that text
            // inserted at a hunk boundary since the snapshot stays outside the
            // replaced range.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent at the end of the final hunk, or at the buffer start
            // when the diff is empty.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // Bracket the edit with action-log read/edited notifications, and read the
            // format-on-save setting for the buffer's language.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: failures are logged and the file is still saved.
                format_task.await.log_err();

                // Report the buffer as edited again, since the formatter may have changed it.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2378
2379 pub fn create_terminal(
2380 &self,
2381 command: String,
2382 args: Vec<String>,
2383 extra_env: Vec<acp::EnvVariable>,
2384 cwd: Option<PathBuf>,
2385 output_byte_limit: Option<u64>,
2386 cx: &mut Context<Self>,
2387 ) -> Task<Result<Entity<Terminal>>> {
2388 let env = match &cwd {
2389 Some(dir) => self.project.update(cx, |project, cx| {
2390 project.environment().update(cx, |env, cx| {
2391 env.directory_environment(dir.as_path().into(), cx)
2392 })
2393 }),
2394 None => Task::ready(None).shared(),
2395 };
2396 let env = cx.spawn(async move |_, _| {
2397 let mut env = env.await.unwrap_or_default();
2398 // Disables paging for `git` and hopefully other commands
2399 env.insert("PAGER".into(), "".into());
2400 for var in extra_env {
2401 env.insert(var.name, var.value);
2402 }
2403 env
2404 });
2405
2406 let project = self.project.clone();
2407 let language_registry = project.read(cx).languages().clone();
2408 let is_windows = project.read(cx).path_style(cx).is_windows();
2409
2410 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2411 let terminal_task = cx.spawn({
2412 let terminal_id = terminal_id.clone();
2413 async move |_this, cx| {
2414 let env = env.await;
2415 let shell = project
2416 .update(cx, |project, cx| {
2417 project
2418 .remote_client()
2419 .and_then(|r| r.read(cx).default_system_shell())
2420 })
2421 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2422 let (task_command, task_args) =
2423 ShellBuilder::new(&Shell::Program(shell), is_windows)
2424 .redirect_stdin_to_dev_null()
2425 .build(Some(command.clone()), &args);
2426 let terminal = project
2427 .update(cx, |project, cx| {
2428 project.create_terminal_task(
2429 task::SpawnInTerminal {
2430 command: Some(task_command),
2431 args: task_args,
2432 cwd: cwd.clone(),
2433 env,
2434 ..Default::default()
2435 },
2436 cx,
2437 )
2438 })
2439 .await?;
2440
2441 anyhow::Ok(cx.new(|cx| {
2442 Terminal::new(
2443 terminal_id,
2444 &format!("{} {}", command, args.join(" ")),
2445 cwd,
2446 output_byte_limit.map(|l| l as usize),
2447 terminal,
2448 language_registry,
2449 cx,
2450 )
2451 }))
2452 }
2453 });
2454
2455 cx.spawn(async move |this, cx| {
2456 let terminal = terminal_task.await?;
2457 this.update(cx, |this, _cx| {
2458 this.terminals.insert(terminal_id, terminal.clone());
2459 terminal
2460 })
2461 })
2462 }
2463
2464 pub fn kill_terminal(
2465 &mut self,
2466 terminal_id: acp::TerminalId,
2467 cx: &mut Context<Self>,
2468 ) -> Result<()> {
2469 self.terminals
2470 .get(&terminal_id)
2471 .context("Terminal not found")?
2472 .update(cx, |terminal, cx| {
2473 terminal.kill(cx);
2474 });
2475
2476 Ok(())
2477 }
2478
2479 pub fn release_terminal(
2480 &mut self,
2481 terminal_id: acp::TerminalId,
2482 cx: &mut Context<Self>,
2483 ) -> Result<()> {
2484 self.terminals
2485 .remove(&terminal_id)
2486 .context("Terminal not found")?
2487 .update(cx, |terminal, cx| {
2488 terminal.kill(cx);
2489 });
2490
2491 Ok(())
2492 }
2493
2494 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2495 self.terminals
2496 .get(&terminal_id)
2497 .context("Terminal not found")
2498 .cloned()
2499 }
2500
2501 pub fn to_markdown(&self, cx: &App) -> String {
2502 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2503 }
2504
2505 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2506 cx.emit(AcpThreadEvent::LoadError(error));
2507 }
2508
2509 pub fn register_terminal_created(
2510 &mut self,
2511 terminal_id: acp::TerminalId,
2512 command_label: String,
2513 working_dir: Option<PathBuf>,
2514 output_byte_limit: Option<u64>,
2515 terminal: Entity<::terminal::Terminal>,
2516 cx: &mut Context<Self>,
2517 ) -> Entity<Terminal> {
2518 let language_registry = self.project.read(cx).languages().clone();
2519
2520 let entity = cx.new(|cx| {
2521 Terminal::new(
2522 terminal_id.clone(),
2523 &command_label,
2524 working_dir.clone(),
2525 output_byte_limit.map(|l| l as usize),
2526 terminal,
2527 language_registry,
2528 cx,
2529 )
2530 });
2531 self.terminals.insert(terminal_id.clone(), entity.clone());
2532 entity
2533 }
2534}
2535
2536fn markdown_for_raw_output(
2537 raw_output: &serde_json::Value,
2538 language_registry: &Arc<LanguageRegistry>,
2539 cx: &mut App,
2540) -> Option<Entity<Markdown>> {
2541 match raw_output {
2542 serde_json::Value::Null => None,
2543 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2544 Markdown::new(
2545 value.to_string().into(),
2546 Some(language_registry.clone()),
2547 None,
2548 cx,
2549 )
2550 })),
2551 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2552 Markdown::new(
2553 value.to_string().into(),
2554 Some(language_registry.clone()),
2555 None,
2556 cx,
2557 )
2558 })),
2559 serde_json::Value::String(value) => Some(cx.new(|cx| {
2560 Markdown::new(
2561 value.clone().into(),
2562 Some(language_registry.clone()),
2563 None,
2564 cx,
2565 )
2566 })),
2567 value => Some(cx.new(|cx| {
2568 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2569
2570 Markdown::new(
2571 format!("```json\n{}\n```", pretty_json).into(),
2572 Some(language_registry.clone()),
2573 None,
2574 cx,
2575 )
2576 })),
2577 }
2578}
2579
2580#[cfg(test)]
2581mod tests {
2582 use super::*;
2583 use anyhow::anyhow;
2584 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2585 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2586 use indoc::indoc;
2587 use project::{FakeFs, Fs};
2588 use rand::{distr, prelude::*};
2589 use serde_json::json;
2590 use settings::SettingsStore;
2591 use smol::stream::StreamExt as _;
2592 use std::{
2593 any::Any,
2594 cell::RefCell,
2595 path::Path,
2596 rc::Rc,
2597 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2598 time::Duration,
2599 };
2600 use util::path;
2601
2602 fn init_test(cx: &mut TestAppContext) {
2603 env_logger::try_init().ok();
2604 cx.update(|cx| {
2605 let settings_store = SettingsStore::test(cx);
2606 cx.set_global(settings_store);
2607 });
2608 }
2609
2610 #[gpui::test]
2611 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2612 init_test(cx);
2613
2614 let fs = FakeFs::new(cx.executor());
2615 let project = Project::test(fs, [], cx).await;
2616 let connection = Rc::new(FakeAgentConnection::new());
2617 let thread = cx
2618 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2619 .await
2620 .unwrap();
2621
2622 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2623
2624 // Send Output BEFORE Created - should be buffered by acp_thread
2625 thread.update(cx, |thread, cx| {
2626 thread.on_terminal_provider_event(
2627 TerminalProviderEvent::Output {
2628 terminal_id: terminal_id.clone(),
2629 data: b"hello buffered".to_vec(),
2630 },
2631 cx,
2632 );
2633 });
2634
2635 // Create a display-only terminal and then send Created
2636 let lower = cx.new(|cx| {
2637 let builder = ::terminal::TerminalBuilder::new_display_only(
2638 ::terminal::terminal_settings::CursorShape::default(),
2639 ::terminal::terminal_settings::AlternateScroll::On,
2640 None,
2641 0,
2642 )
2643 .unwrap();
2644 builder.subscribe(cx)
2645 });
2646
2647 thread.update(cx, |thread, cx| {
2648 thread.on_terminal_provider_event(
2649 TerminalProviderEvent::Created {
2650 terminal_id: terminal_id.clone(),
2651 label: "Buffered Test".to_string(),
2652 cwd: None,
2653 output_byte_limit: None,
2654 terminal: lower.clone(),
2655 },
2656 cx,
2657 );
2658 });
2659
2660 // After Created, buffered Output should have been flushed into the renderer
2661 let content = thread.read_with(cx, |thread, cx| {
2662 let term = thread.terminal(terminal_id.clone()).unwrap();
2663 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2664 });
2665
2666 assert!(
2667 content.contains("hello buffered"),
2668 "expected buffered output to render, got: {content}"
2669 );
2670 }
2671
2672 #[gpui::test]
2673 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2674 init_test(cx);
2675
2676 let fs = FakeFs::new(cx.executor());
2677 let project = Project::test(fs, [], cx).await;
2678 let connection = Rc::new(FakeAgentConnection::new());
2679 let thread = cx
2680 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2681 .await
2682 .unwrap();
2683
2684 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2685
2686 // Send Output BEFORE Created
2687 thread.update(cx, |thread, cx| {
2688 thread.on_terminal_provider_event(
2689 TerminalProviderEvent::Output {
2690 terminal_id: terminal_id.clone(),
2691 data: b"pre-exit data".to_vec(),
2692 },
2693 cx,
2694 );
2695 });
2696
2697 // Send Exit BEFORE Created
2698 thread.update(cx, |thread, cx| {
2699 thread.on_terminal_provider_event(
2700 TerminalProviderEvent::Exit {
2701 terminal_id: terminal_id.clone(),
2702 status: acp::TerminalExitStatus::new().exit_code(0),
2703 },
2704 cx,
2705 );
2706 });
2707
2708 // Now create a display-only lower-level terminal and send Created
2709 let lower = cx.new(|cx| {
2710 let builder = ::terminal::TerminalBuilder::new_display_only(
2711 ::terminal::terminal_settings::CursorShape::default(),
2712 ::terminal::terminal_settings::AlternateScroll::On,
2713 None,
2714 0,
2715 )
2716 .unwrap();
2717 builder.subscribe(cx)
2718 });
2719
2720 thread.update(cx, |thread, cx| {
2721 thread.on_terminal_provider_event(
2722 TerminalProviderEvent::Created {
2723 terminal_id: terminal_id.clone(),
2724 label: "Buffered Exit Test".to_string(),
2725 cwd: None,
2726 output_byte_limit: None,
2727 terminal: lower.clone(),
2728 },
2729 cx,
2730 );
2731 });
2732
2733 // Output should be present after Created (flushed from buffer)
2734 let content = thread.read_with(cx, |thread, cx| {
2735 let term = thread.terminal(terminal_id.clone()).unwrap();
2736 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2737 });
2738
2739 assert!(
2740 content.contains("pre-exit data"),
2741 "expected pre-exit data to render, got: {content}"
2742 );
2743 }
2744
2745 /// Test that killing a terminal via Terminal::kill properly:
2746 /// 1. Causes wait_for_exit to complete (doesn't hang forever)
2747 /// 2. The underlying terminal still has the output that was written before the kill
2748 ///
2749 /// This test verifies that the fix to kill_active_task (which now also kills
2750 /// the shell process in addition to the foreground process) properly allows
2751 /// wait_for_exit to complete instead of hanging indefinitely.
2752 #[cfg(unix)]
2753 #[gpui::test]
2754 async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
2755 use std::collections::HashMap;
2756 use task::Shell;
2757 use util::shell_builder::ShellBuilder;
2758
2759 init_test(cx);
2760 cx.executor().allow_parking();
2761
2762 let fs = FakeFs::new(cx.executor());
2763 let project = Project::test(fs, [], cx).await;
2764 let connection = Rc::new(FakeAgentConnection::new());
2765 let thread = cx
2766 .update(|cx| connection.new_thread(project.clone(), Path::new(path!("/test")), cx))
2767 .await
2768 .unwrap();
2769
2770 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2771
2772 // Create a real PTY terminal that runs a command which prints output then sleeps
2773 // We use printf instead of echo and chain with && sleep to ensure proper execution
2774 let (completion_tx, _completion_rx) = smol::channel::unbounded();
2775 let (program, args) = ShellBuilder::new(&Shell::System, false).build(
2776 Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
2777 &[],
2778 );
2779
2780 let builder = cx
2781 .update(|cx| {
2782 ::terminal::TerminalBuilder::new(
2783 None,
2784 None,
2785 task::Shell::WithArguments {
2786 program,
2787 args,
2788 title_override: None,
2789 },
2790 HashMap::default(),
2791 ::terminal::terminal_settings::CursorShape::default(),
2792 ::terminal::terminal_settings::AlternateScroll::On,
2793 None,
2794 vec![],
2795 0,
2796 false,
2797 0,
2798 Some(completion_tx),
2799 cx,
2800 vec![],
2801 )
2802 })
2803 .await
2804 .unwrap();
2805
2806 let lower_terminal = cx.new(|cx| builder.subscribe(cx));
2807
2808 // Create the acp_thread Terminal wrapper
2809 thread.update(cx, |thread, cx| {
2810 thread.on_terminal_provider_event(
2811 TerminalProviderEvent::Created {
2812 terminal_id: terminal_id.clone(),
2813 label: "printf output_before_kill && sleep 60".to_string(),
2814 cwd: None,
2815 output_byte_limit: None,
2816 terminal: lower_terminal.clone(),
2817 },
2818 cx,
2819 );
2820 });
2821
2822 // Wait for the printf command to execute and produce output
2823 // Use real time since parking is enabled
2824 cx.executor().timer(Duration::from_millis(500)).await;
2825
2826 // Get the acp_thread Terminal and kill it
2827 let wait_for_exit = thread.update(cx, |thread, cx| {
2828 let term = thread.terminals.get(&terminal_id).unwrap();
2829 let wait_for_exit = term.read(cx).wait_for_exit();
2830 term.update(cx, |term, cx| {
2831 term.kill(cx);
2832 });
2833 wait_for_exit
2834 });
2835
2836 // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
2837 // Before the fix to kill_active_task, this would hang forever because
2838 // only the foreground process was killed, not the shell, so the PTY
2839 // child never exited and wait_for_completed_task never completed.
2840 let exit_result = futures::select! {
2841 result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
2842 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(5))) => None,
2843 };
2844
2845 assert!(
2846 exit_result.is_some(),
2847 "wait_for_exit should complete after kill, but it timed out. \
2848 This indicates kill_active_task is not properly killing the shell process."
2849 );
2850
2851 // Give the system a chance to process any pending updates
2852 cx.run_until_parked();
2853
2854 // Verify that the underlying terminal still has the output that was
2855 // written before the kill. This verifies that killing doesn't lose output.
2856 let inner_content = thread.read_with(cx, |thread, cx| {
2857 let term = thread.terminals.get(&terminal_id).unwrap();
2858 term.read(cx).inner().read(cx).get_content()
2859 });
2860
2861 assert!(
2862 inner_content.contains("output_before_kill"),
2863 "Underlying terminal should contain output from before kill, got: {}",
2864 inner_content
2865 );
2866 }
2867
2868 #[gpui::test]
2869 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2870 init_test(cx);
2871
2872 let fs = FakeFs::new(cx.executor());
2873 let project = Project::test(fs, [], cx).await;
2874 let connection = Rc::new(FakeAgentConnection::new());
2875 let thread = cx
2876 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2877 .await
2878 .unwrap();
2879
2880 // Test creating a new user message
2881 thread.update(cx, |thread, cx| {
2882 thread.push_user_content_block(None, "Hello, ".into(), cx);
2883 });
2884
2885 thread.update(cx, |thread, cx| {
2886 assert_eq!(thread.entries.len(), 1);
2887 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2888 assert_eq!(user_msg.id, None);
2889 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2890 } else {
2891 panic!("Expected UserMessage");
2892 }
2893 });
2894
2895 // Test appending to existing user message
2896 let message_1_id = UserMessageId::new();
2897 thread.update(cx, |thread, cx| {
2898 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2899 });
2900
2901 thread.update(cx, |thread, cx| {
2902 assert_eq!(thread.entries.len(), 1);
2903 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2904 assert_eq!(user_msg.id, Some(message_1_id));
2905 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2906 } else {
2907 panic!("Expected UserMessage");
2908 }
2909 });
2910
2911 // Test creating new user message after assistant message
2912 thread.update(cx, |thread, cx| {
2913 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2914 });
2915
2916 let message_2_id = UserMessageId::new();
2917 thread.update(cx, |thread, cx| {
2918 thread.push_user_content_block(
2919 Some(message_2_id.clone()),
2920 "New user message".into(),
2921 cx,
2922 );
2923 });
2924
2925 thread.update(cx, |thread, cx| {
2926 assert_eq!(thread.entries.len(), 3);
2927 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2928 assert_eq!(user_msg.id, Some(message_2_id));
2929 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2930 } else {
2931 panic!("Expected UserMessage at index 2");
2932 }
2933 });
2934 }
2935
2936 #[gpui::test]
2937 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2938 init_test(cx);
2939
2940 let fs = FakeFs::new(cx.executor());
2941 let project = Project::test(fs, [], cx).await;
2942 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2943 |_, thread, mut cx| {
2944 async move {
2945 thread.update(&mut cx, |thread, cx| {
2946 thread
2947 .handle_session_update(
2948 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2949 "Thinking ".into(),
2950 )),
2951 cx,
2952 )
2953 .unwrap();
2954 thread
2955 .handle_session_update(
2956 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2957 "hard!".into(),
2958 )),
2959 cx,
2960 )
2961 .unwrap();
2962 })?;
2963 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2964 }
2965 .boxed_local()
2966 },
2967 ));
2968
2969 let thread = cx
2970 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2971 .await
2972 .unwrap();
2973
2974 thread
2975 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2976 .await
2977 .unwrap();
2978
2979 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2980 assert_eq!(
2981 output,
2982 indoc! {r#"
2983 ## User
2984
2985 Hello from Zed!
2986
2987 ## Assistant
2988
2989 <thinking>
2990 Thinking hard!
2991 </thinking>
2992
2993 "#}
2994 );
2995 }
2996
2997 #[gpui::test]
2998 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2999 init_test(cx);
3000
3001 let fs = FakeFs::new(cx.executor());
3002 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
3003 .await;
3004 let project = Project::test(fs.clone(), [], cx).await;
3005 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
3006 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
3007 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
3008 move |_, thread, mut cx| {
3009 let read_file_tx = read_file_tx.clone();
3010 async move {
3011 let content = thread
3012 .update(&mut cx, |thread, cx| {
3013 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3014 })
3015 .unwrap()
3016 .await
3017 .unwrap();
3018 assert_eq!(content, "one\ntwo\nthree\n");
3019 read_file_tx.take().unwrap().send(()).unwrap();
3020 thread
3021 .update(&mut cx, |thread, cx| {
3022 thread.write_text_file(
3023 path!("/tmp/foo").into(),
3024 "one\ntwo\nthree\nfour\nfive\n".to_string(),
3025 cx,
3026 )
3027 })
3028 .unwrap()
3029 .await
3030 .unwrap();
3031 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3032 }
3033 .boxed_local()
3034 },
3035 ));
3036
3037 let (worktree, pathbuf) = project
3038 .update(cx, |project, cx| {
3039 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3040 })
3041 .await
3042 .unwrap();
3043 let buffer = project
3044 .update(cx, |project, cx| {
3045 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
3046 })
3047 .await
3048 .unwrap();
3049
3050 let thread = cx
3051 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3052 .await
3053 .unwrap();
3054
3055 let request = thread.update(cx, |thread, cx| {
3056 thread.send_raw("Extend the count in /tmp/foo", cx)
3057 });
3058 read_file_rx.await.ok();
3059 buffer.update(cx, |buffer, cx| {
3060 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
3061 });
3062 cx.run_until_parked();
3063 assert_eq!(
3064 buffer.read_with(cx, |buffer, _| buffer.text()),
3065 "zero\none\ntwo\nthree\nfour\nfive\n"
3066 );
3067 assert_eq!(
3068 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
3069 "zero\none\ntwo\nthree\nfour\nfive\n"
3070 );
3071 request.await.unwrap();
3072 }
3073
3074 #[gpui::test]
3075 async fn test_reading_from_line(cx: &mut TestAppContext) {
3076 init_test(cx);
3077
3078 let fs = FakeFs::new(cx.executor());
3079 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
3080 .await;
3081 let project = Project::test(fs.clone(), [], cx).await;
3082 project
3083 .update(cx, |project, cx| {
3084 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3085 })
3086 .await
3087 .unwrap();
3088
3089 let connection = Rc::new(FakeAgentConnection::new());
3090
3091 let thread = cx
3092 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3093 .await
3094 .unwrap();
3095
3096 // Whole file
3097 let content = thread
3098 .update(cx, |thread, cx| {
3099 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3100 })
3101 .await
3102 .unwrap();
3103
3104 assert_eq!(content, "one\ntwo\nthree\nfour\n");
3105
3106 // Only start line
3107 let content = thread
3108 .update(cx, |thread, cx| {
3109 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
3110 })
3111 .await
3112 .unwrap();
3113
3114 assert_eq!(content, "three\nfour\n");
3115
3116 // Only limit
3117 let content = thread
3118 .update(cx, |thread, cx| {
3119 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3120 })
3121 .await
3122 .unwrap();
3123
3124 assert_eq!(content, "one\ntwo\n");
3125
3126 // Range
3127 let content = thread
3128 .update(cx, |thread, cx| {
3129 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
3130 })
3131 .await
3132 .unwrap();
3133
3134 assert_eq!(content, "two\nthree\n");
3135
3136 // Invalid
3137 let err = thread
3138 .update(cx, |thread, cx| {
3139 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
3140 })
3141 .await
3142 .unwrap_err();
3143
3144 assert_eq!(
3145 err.to_string(),
3146 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
3147 );
3148 }
3149
3150 #[gpui::test]
3151 async fn test_reading_empty_file(cx: &mut TestAppContext) {
3152 init_test(cx);
3153
3154 let fs = FakeFs::new(cx.executor());
3155 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
3156 let project = Project::test(fs.clone(), [], cx).await;
3157 project
3158 .update(cx, |project, cx| {
3159 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
3160 })
3161 .await
3162 .unwrap();
3163
3164 let connection = Rc::new(FakeAgentConnection::new());
3165
3166 let thread = cx
3167 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3168 .await
3169 .unwrap();
3170
3171 // Whole file
3172 let content = thread
3173 .update(cx, |thread, cx| {
3174 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
3175 })
3176 .await
3177 .unwrap();
3178
3179 assert_eq!(content, "");
3180
3181 // Only start line
3182 let content = thread
3183 .update(cx, |thread, cx| {
3184 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
3185 })
3186 .await
3187 .unwrap();
3188
3189 assert_eq!(content, "");
3190
3191 // Only limit
3192 let content = thread
3193 .update(cx, |thread, cx| {
3194 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
3195 })
3196 .await
3197 .unwrap();
3198
3199 assert_eq!(content, "");
3200
3201 // Range
3202 let content = thread
3203 .update(cx, |thread, cx| {
3204 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3205 })
3206 .await
3207 .unwrap();
3208
3209 assert_eq!(content, "");
3210
3211 // Invalid
3212 let err = thread
3213 .update(cx, |thread, cx| {
3214 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3215 })
3216 .await
3217 .unwrap_err();
3218
3219 assert_eq!(
3220 err.to_string(),
3221 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3222 );
3223 }
3224 #[gpui::test]
3225 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3226 init_test(cx);
3227
3228 let fs = FakeFs::new(cx.executor());
3229 fs.insert_tree(path!("/tmp"), json!({})).await;
3230 let project = Project::test(fs.clone(), [], cx).await;
3231 project
3232 .update(cx, |project, cx| {
3233 project.find_or_create_worktree(path!("/tmp"), true, cx)
3234 })
3235 .await
3236 .unwrap();
3237
3238 let connection = Rc::new(FakeAgentConnection::new());
3239
3240 let thread = cx
3241 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3242 .await
3243 .unwrap();
3244
3245 // Out of project file
3246 let err = thread
3247 .update(cx, |thread, cx| {
3248 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3249 })
3250 .await
3251 .unwrap_err();
3252
3253 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3254 }
3255
3256 #[gpui::test]
3257 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3258 init_test(cx);
3259
3260 let fs = FakeFs::new(cx.executor());
3261 let project = Project::test(fs, [], cx).await;
3262 let id = acp::ToolCallId::new("test");
3263
3264 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3265 let id = id.clone();
3266 move |_, thread, mut cx| {
3267 let id = id.clone();
3268 async move {
3269 thread
3270 .update(&mut cx, |thread, cx| {
3271 thread.handle_session_update(
3272 acp::SessionUpdate::ToolCall(
3273 acp::ToolCall::new(id.clone(), "Label")
3274 .kind(acp::ToolKind::Fetch)
3275 .status(acp::ToolCallStatus::InProgress),
3276 ),
3277 cx,
3278 )
3279 })
3280 .unwrap()
3281 .unwrap();
3282 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3283 }
3284 .boxed_local()
3285 }
3286 }));
3287
3288 let thread = cx
3289 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3290 .await
3291 .unwrap();
3292
3293 let request = thread.update(cx, |thread, cx| {
3294 thread.send_raw("Fetch https://example.com", cx)
3295 });
3296
3297 run_until_first_tool_call(&thread, cx).await;
3298
3299 thread.read_with(cx, |thread, _| {
3300 assert!(matches!(
3301 thread.entries[1],
3302 AgentThreadEntry::ToolCall(ToolCall {
3303 status: ToolCallStatus::InProgress,
3304 ..
3305 })
3306 ));
3307 });
3308
3309 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3310
3311 thread.read_with(cx, |thread, _| {
3312 assert!(matches!(
3313 &thread.entries[1],
3314 AgentThreadEntry::ToolCall(ToolCall {
3315 status: ToolCallStatus::Canceled,
3316 ..
3317 })
3318 ));
3319 });
3320
3321 thread
3322 .update(cx, |thread, cx| {
3323 thread.handle_session_update(
3324 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3325 id,
3326 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3327 )),
3328 cx,
3329 )
3330 })
3331 .unwrap();
3332
3333 request.await.unwrap();
3334
3335 thread.read_with(cx, |thread, _| {
3336 assert!(matches!(
3337 thread.entries[1],
3338 AgentThreadEntry::ToolCall(ToolCall {
3339 status: ToolCallStatus::Completed,
3340 ..
3341 })
3342 ));
3343 });
3344 }
3345
3346 #[gpui::test]
3347 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3348 init_test(cx);
3349 let fs = FakeFs::new(cx.background_executor.clone());
3350 fs.insert_tree(path!("/test"), json!({})).await;
3351 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3352
3353 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3354 move |_, thread, mut cx| {
3355 async move {
3356 thread
3357 .update(&mut cx, |thread, cx| {
3358 thread.handle_session_update(
3359 acp::SessionUpdate::ToolCall(
3360 acp::ToolCall::new("test", "Label")
3361 .kind(acp::ToolKind::Edit)
3362 .status(acp::ToolCallStatus::Completed)
3363 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3364 "/test/test.txt",
3365 "foo",
3366 ))]),
3367 ),
3368 cx,
3369 )
3370 })
3371 .unwrap()
3372 .unwrap();
3373 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3374 }
3375 .boxed_local()
3376 }
3377 }));
3378
3379 let thread = cx
3380 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3381 .await
3382 .unwrap();
3383
3384 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3385 .await
3386 .unwrap();
3387
3388 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3389 }
3390
    /// Checks that a user message is marked "(checkpoint)" only when the agent's turn
    /// changed files, and that `restore_checkpoint` both truncates the conversation and
    /// rolls the filesystem back.
    ///
    /// The fake agent writes a fresh empty file (`/test/file-N`) on each prompt while
    /// `simulate_changes` is set, then echoes the prompt back in upper case.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // The `.git` directory is presumably what makes git checkpoints possible here.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // `simulate_changes` toggles whether the agent touches the filesystem;
        // `next_filename` numbers the files it creates.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Produce a filesystem change for this turn, if enabled.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back as an upper-cased assistant message chunk.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Turn 1 writes file-0, so its user message carries a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Turn 2 writes file-1 and likewise gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message (entries alternate user/assistant), so
        // restoring it drops turns 2 and 3 and removes file-1 written during turn 2.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3568
3569 #[gpui::test]
3570 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3571 use std::sync::atomic::AtomicUsize;
3572 init_test(cx);
3573
3574 let fs = FakeFs::new(cx.executor());
3575 let project = Project::test(fs, None, cx).await;
3576
3577 // Create a connection that simulates refusal after tool result
3578 let prompt_count = Arc::new(AtomicUsize::new(0));
3579 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3580 let prompt_count = prompt_count.clone();
3581 move |_request, thread, mut cx| {
3582 let count = prompt_count.fetch_add(1, SeqCst);
3583 async move {
3584 if count == 0 {
3585 // First prompt: Generate a tool call with result
3586 thread.update(&mut cx, |thread, cx| {
3587 thread
3588 .handle_session_update(
3589 acp::SessionUpdate::ToolCall(
3590 acp::ToolCall::new("tool1", "Test Tool")
3591 .kind(acp::ToolKind::Fetch)
3592 .status(acp::ToolCallStatus::Completed)
3593 .raw_input(serde_json::json!({"query": "test"}))
3594 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3595 ),
3596 cx,
3597 )
3598 .unwrap();
3599 })?;
3600
3601 // Now return refusal because of the tool result
3602 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3603 } else {
3604 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3605 }
3606 }
3607 .boxed_local()
3608 }
3609 }));
3610
3611 let thread = cx
3612 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3613 .await
3614 .unwrap();
3615
3616 // Track if we see a Refusal event
3617 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3618 let saw_refusal_event_captured = saw_refusal_event.clone();
3619 thread.update(cx, |_thread, cx| {
3620 cx.subscribe(
3621 &thread,
3622 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3623 if matches!(event, AcpThreadEvent::Refusal) {
3624 *saw_refusal_event_captured.lock().unwrap() = true;
3625 }
3626 },
3627 )
3628 .detach();
3629 });
3630
3631 // Send a user message - this will trigger tool call and then refusal
3632 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3633 cx.background_executor.spawn(send_task).detach();
3634 cx.run_until_parked();
3635
3636 // Verify that:
3637 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3638 // 2. The user message was NOT truncated
3639 assert!(
3640 *saw_refusal_event.lock().unwrap(),
3641 "Refusal event should be emitted for tool result refusals"
3642 );
3643
3644 thread.read_with(cx, |thread, _| {
3645 let entries = thread.entries();
3646 assert!(entries.len() >= 2, "Should have user message and tool call");
3647
3648 // Verify user message is still there
3649 assert!(
3650 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3651 "User message should not be truncated"
3652 );
3653
3654 // Verify tool call is there with result
3655 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3656 assert!(
3657 tool_call.raw_output.is_some(),
3658 "Tool call should have output"
3659 );
3660 } else {
3661 panic!("Expected tool call at index 1");
3662 }
3663 });
3664 }
3665
3666 #[gpui::test]
3667 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3668 init_test(cx);
3669
3670 let fs = FakeFs::new(cx.executor());
3671 let project = Project::test(fs, None, cx).await;
3672
3673 let refuse_next = Arc::new(AtomicBool::new(false));
3674 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3675 let refuse_next = refuse_next.clone();
3676 move |_request, _thread, _cx| {
3677 if refuse_next.load(SeqCst) {
3678 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3679 .boxed_local()
3680 } else {
3681 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3682 .boxed_local()
3683 }
3684 }
3685 }));
3686
3687 let thread = cx
3688 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3689 .await
3690 .unwrap();
3691
3692 // Track if we see a Refusal event
3693 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3694 let saw_refusal_event_captured = saw_refusal_event.clone();
3695 thread.update(cx, |_thread, cx| {
3696 cx.subscribe(
3697 &thread,
3698 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3699 if matches!(event, AcpThreadEvent::Refusal) {
3700 *saw_refusal_event_captured.lock().unwrap() = true;
3701 }
3702 },
3703 )
3704 .detach();
3705 });
3706
3707 // Send a message that will be refused
3708 refuse_next.store(true, SeqCst);
3709 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3710 .await
3711 .unwrap();
3712
3713 // Verify that a Refusal event WAS emitted for user prompt refusal
3714 assert!(
3715 *saw_refusal_event.lock().unwrap(),
3716 "Refusal event should be emitted for user prompt refusals"
3717 );
3718
3719 // Verify the message was truncated (user prompt refusal)
3720 thread.read_with(cx, |thread, cx| {
3721 assert_eq!(thread.to_markdown(cx), "");
3722 });
3723 }
3724
3725 #[gpui::test]
3726 async fn test_refusal(cx: &mut TestAppContext) {
3727 init_test(cx);
3728 let fs = FakeFs::new(cx.background_executor.clone());
3729 fs.insert_tree(path!("/"), json!({})).await;
3730 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3731
3732 let refuse_next = Arc::new(AtomicBool::new(false));
3733 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3734 let refuse_next = refuse_next.clone();
3735 move |request, thread, mut cx| {
3736 let refuse_next = refuse_next.clone();
3737 async move {
3738 if refuse_next.load(SeqCst) {
3739 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3740 }
3741
3742 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3743 panic!("expected text content block");
3744 };
3745 thread.update(&mut cx, |thread, cx| {
3746 thread
3747 .handle_session_update(
3748 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3749 content.text.to_uppercase().into(),
3750 )),
3751 cx,
3752 )
3753 .unwrap();
3754 })?;
3755 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3756 }
3757 .boxed_local()
3758 }
3759 }));
3760 let thread = cx
3761 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3762 .await
3763 .unwrap();
3764
3765 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3766 .await
3767 .unwrap();
3768 thread.read_with(cx, |thread, cx| {
3769 assert_eq!(
3770 thread.to_markdown(cx),
3771 indoc! {"
3772 ## User
3773
3774 hello
3775
3776 ## Assistant
3777
3778 HELLO
3779
3780 "}
3781 );
3782 });
3783
3784 // Simulate refusing the second message. The message should be truncated
3785 // when a user prompt is refused.
3786 refuse_next.store(true, SeqCst);
3787 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3788 .await
3789 .unwrap();
3790 thread.read_with(cx, |thread, cx| {
3791 assert_eq!(
3792 thread.to_markdown(cx),
3793 indoc! {"
3794 ## User
3795
3796 hello
3797
3798 ## Assistant
3799
3800 HELLO
3801
3802 "}
3803 );
3804 });
3805 }
3806
3807 async fn run_until_first_tool_call(
3808 thread: &Entity<AcpThread>,
3809 cx: &mut TestAppContext,
3810 ) -> usize {
3811 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3812
3813 let subscription = cx.update(|cx| {
3814 cx.subscribe(thread, move |thread, _, cx| {
3815 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3816 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3817 return tx.try_send(ix).unwrap();
3818 }
3819 }
3820 })
3821 });
3822
3823 select! {
3824 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(10))) => {
3825 panic!("Timeout waiting for tool call")
3826 }
3827 ix = rx.next().fuse() => {
3828 drop(subscription);
3829 ix.unwrap()
3830 }
3831 }
3832 }
3833
3834 #[derive(Clone, Default)]
3835 struct FakeAgentConnection {
3836 auth_methods: Vec<acp::AuthMethod>,
3837 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3838 on_user_message: Option<
3839 Rc<
3840 dyn Fn(
3841 acp::PromptRequest,
3842 WeakEntity<AcpThread>,
3843 AsyncApp,
3844 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3845 + 'static,
3846 >,
3847 >,
3848 }
3849
3850 impl FakeAgentConnection {
3851 fn new() -> Self {
3852 Self {
3853 auth_methods: Vec::new(),
3854 on_user_message: None,
3855 sessions: Arc::default(),
3856 }
3857 }
3858
3859 #[expect(unused)]
3860 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3861 self.auth_methods = auth_methods;
3862 self
3863 }
3864
3865 fn on_user_message(
3866 mut self,
3867 handler: impl Fn(
3868 acp::PromptRequest,
3869 WeakEntity<AcpThread>,
3870 AsyncApp,
3871 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3872 + 'static,
3873 ) -> Self {
3874 self.on_user_message.replace(Rc::new(handler));
3875 self
3876 }
3877 }
3878
3879 impl AgentConnection for FakeAgentConnection {
3880 fn telemetry_id(&self) -> SharedString {
3881 "fake".into()
3882 }
3883
3884 fn auth_methods(&self) -> &[acp::AuthMethod] {
3885 &self.auth_methods
3886 }
3887
3888 fn new_thread(
3889 self: Rc<Self>,
3890 project: Entity<Project>,
3891 _cwd: &Path,
3892 cx: &mut App,
3893 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3894 let session_id = acp::SessionId::new(
3895 rand::rng()
3896 .sample_iter(&distr::Alphanumeric)
3897 .take(7)
3898 .map(char::from)
3899 .collect::<String>(),
3900 );
3901 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3902 let thread = cx.new(|cx| {
3903 AcpThread::new(
3904 "Test",
3905 self.clone(),
3906 project,
3907 action_log,
3908 session_id.clone(),
3909 watch::Receiver::constant(
3910 acp::PromptCapabilities::new()
3911 .image(true)
3912 .audio(true)
3913 .embedded_context(true),
3914 ),
3915 cx,
3916 )
3917 });
3918 self.sessions.lock().insert(session_id, thread.downgrade());
3919 Task::ready(Ok(thread))
3920 }
3921
3922 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3923 if self.auth_methods().iter().any(|m| m.id == method) {
3924 Task::ready(Ok(()))
3925 } else {
3926 Task::ready(Err(anyhow!("Invalid Auth Method")))
3927 }
3928 }
3929
3930 fn prompt(
3931 &self,
3932 _id: Option<UserMessageId>,
3933 params: acp::PromptRequest,
3934 cx: &mut App,
3935 ) -> Task<gpui::Result<acp::PromptResponse>> {
3936 let sessions = self.sessions.lock();
3937 let thread = sessions.get(¶ms.session_id).unwrap();
3938 if let Some(handler) = &self.on_user_message {
3939 let handler = handler.clone();
3940 let thread = thread.clone();
3941 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3942 } else {
3943 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3944 }
3945 }
3946
3947 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3948 let sessions = self.sessions.lock();
3949 let thread = sessions.get(session_id).unwrap().clone();
3950
3951 cx.spawn(async move |cx| {
3952 thread
3953 .update(cx, |thread, cx| thread.cancel(cx))
3954 .unwrap()
3955 .await
3956 })
3957 .detach();
3958 }
3959
3960 fn truncate(
3961 &self,
3962 session_id: &acp::SessionId,
3963 _cx: &App,
3964 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3965 Some(Rc::new(FakeAgentSessionEditor {
3966 _session_id: session_id.clone(),
3967 }))
3968 }
3969
3970 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3971 self
3972 }
3973 }
3974
3975 struct FakeAgentSessionEditor {
3976 _session_id: acp::SessionId,
3977 }
3978
3979 impl AgentSessionTruncate for FakeAgentSessionEditor {
3980 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3981 Task::ready(Ok(()))
3982 }
3983 }
3984
3985 #[gpui::test]
3986 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3987 init_test(cx);
3988
3989 let fs = FakeFs::new(cx.executor());
3990 let project = Project::test(fs, [], cx).await;
3991 let connection = Rc::new(FakeAgentConnection::new());
3992 let thread = cx
3993 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3994 .await
3995 .unwrap();
3996
3997 // Try to update a tool call that doesn't exist
3998 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3999 thread.update(cx, |thread, cx| {
4000 let result = thread.handle_session_update(
4001 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
4002 nonexistent_id.clone(),
4003 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
4004 )),
4005 cx,
4006 );
4007
4008 // The update should succeed (not return an error)
4009 assert!(result.is_ok());
4010
4011 // There should now be exactly one entry in the thread
4012 assert_eq!(thread.entries.len(), 1);
4013
4014 // The entry should be a failed tool call
4015 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
4016 assert_eq!(tool_call.id, nonexistent_id);
4017 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
4018 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
4019
4020 // Check that the content contains the error message
4021 assert_eq!(tool_call.content.len(), 1);
4022 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
4023 match content_block {
4024 ContentBlock::Markdown { markdown } => {
4025 let markdown_text = markdown.read(cx).source();
4026 assert!(markdown_text.contains("Tool call not found"));
4027 }
4028 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
4029 ContentBlock::ResourceLink { .. } => {
4030 panic!("Expected markdown content, got resource link")
4031 }
4032 ContentBlock::Image { .. } => {
4033 panic!("Expected markdown content, got image")
4034 }
4035 }
4036 } else {
4037 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
4038 }
4039 } else {
4040 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
4041 }
4042 });
4043 }
4044
    /// Tests that restoring a checkpoint removes the terminal belonging to the truncated
    /// part of the conversation, keeps the other terminals, and leaves no in-progress
    /// generation (`send_task`) behind.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// NOTE(review): all three terminals are registered after the second message was
    /// sent. The one that must be removed differs from the other two in two ways: it
    /// never receives an `Exit` event, and it is referenced by a tool-call entry that the
    /// restore truncates. Confirm which of these `restore_checkpoint` keys on.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that run to completion (Created -> Output -> Exit).
        // They are registered after both messages were sent, but no entry references
        // them and both have exited; the test expects them to survive the restore.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a third terminal that never exits and will be referenced by a
        // tool-call entry after the second message.
        // This simulates the AI agent starting a long-running terminal command.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running: no Exit event is sent)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry (index 2) that references this terminal.
        // This represents the agent requesting a terminal command.
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation (presumably via a cancel() call inside
        //    restore_checkpoint; not visible from this test)
        // 2. Remove the terminal that was created after that point
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (presumably cancel() clears the send_task)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed, unreferenced terminals still exist
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal referenced by the truncated tool call was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify total count is 2 (the two completed terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4340
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    ///
    /// The race is provoked by stepping the executor by hand. The send is spawned on the
    /// background executor and ticked until the fake agent's handler has been called,
    /// then ticked a few more times. Only then is a checkpoint-less user message pushed.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        // The `.git` directory is presumably what allows a checkpoint to be taken.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // The handler raises `handler_done` as soon as it is called, so the test can
        // tell when the prompt has reached the agent.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until the handler has been called, then a few more to let update_last_checkpoint start.
        // NOTE(review): the 5 extra ticks are a heuristic tied to how `send` is scheduled.
        // If `tick()` stops finding work before the handler runs, the `while` loop never
        // terminates. Consider asserting on `tick()`'s return value.
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message without a checkpoint while the first send's checkpoint
        // update may still be pending.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4403}