1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
41#[derive(Debug)]
42pub struct UserMessage {
43 pub id: Option<UserMessageId>,
44 pub content: ContentBlock,
45 pub chunks: Vec<acp::ContentBlock>,
46 pub checkpoint: Option<Checkpoint>,
47 pub indented: bool,
48}
49
50#[derive(Debug)]
51pub struct Checkpoint {
52 git_checkpoint: GitStoreCheckpoint,
53 pub show: bool,
54}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
75#[derive(Debug, PartialEq)]
76pub struct AssistantMessage {
77 pub chunks: Vec<AssistantMessageChunk>,
78 pub indented: bool,
79}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
93#[derive(Debug, PartialEq)]
94pub enum AssistantMessageChunk {
95 Message { block: ContentBlock },
96 Thought { block: ContentBlock },
97}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
121#[derive(Debug)]
122pub enum AgentThreadEntry {
123 UserMessage(UserMessage),
124 AssistantMessage(AssistantMessage),
125 ToolCall(ToolCall),
126}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
186#[derive(Debug)]
187pub struct ToolCall {
188 pub id: acp::ToolCallId,
189 pub label: Entity<Markdown>,
190 pub kind: acp::ToolKind,
191 pub content: Vec<ToolCallContent>,
192 pub status: ToolCallStatus,
193 pub locations: Vec<acp::ToolCallLocation>,
194 pub resolved_locations: Vec<Option<AgentLocation>>,
195 pub raw_input: Option<serde_json::Value>,
196 pub raw_input_markdown: Option<Entity<Markdown>>,
197 pub raw_output: Option<serde_json::Value>,
198}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer.update(cx, |buffer, _| {
379 let snapshot = buffer.snapshot();
380 if let Some(row) = location.line {
381 let column = snapshot.indent_size_for_line(row).len;
382 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
383 snapshot.anchor_before(point)
384 } else {
385 Anchor::min_for_buffer(snapshot.remote_id())
386 }
387 });
388
389 Some(ResolvedLocation { buffer, position })
390 }
391
392 fn resolve_locations(
393 &self,
394 project: Entity<Project>,
395 cx: &mut App,
396 ) -> Task<Vec<Option<ResolvedLocation>>> {
397 let locations = self.locations.clone();
398 project.update(cx, |_, cx| {
399 cx.spawn(async move |project, cx| {
400 let mut new_locations = Vec::new();
401 for location in locations {
402 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
403 }
404 new_locations
405 })
406 })
407 }
408}
409
410// Separate so we can hold a strong reference to the buffer
411// for saving on the thread
412#[derive(Clone, Debug, PartialEq, Eq)]
413struct ResolvedLocation {
414 buffer: Entity<Buffer>,
415 position: Anchor,
416}
417
418impl From<&ResolvedLocation> for AgentLocation {
419 fn from(value: &ResolvedLocation) -> Self {
420 Self {
421 buffer: value.buffer.downgrade(),
422 position: value.position,
423 }
424 }
425}
426
427#[derive(Debug)]
428pub enum ToolCallStatus {
429 /// The tool call hasn't started running yet, but we start showing it to
430 /// the user.
431 Pending,
432 /// The tool call is waiting for confirmation from the user.
433 WaitingForConfirmation {
434 options: Vec<acp::PermissionOption>,
435 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
436 },
437 /// The tool call is currently running.
438 InProgress,
439 /// The tool call completed successfully.
440 Completed,
441 /// The tool call failed.
442 Failed,
443 /// The user rejected the tool call.
444 Rejected,
445 /// The user canceled generation so the tool call was canceled.
446 Canceled,
447}
448
449impl From<acp::ToolCallStatus> for ToolCallStatus {
450 fn from(status: acp::ToolCallStatus) -> Self {
451 match status {
452 acp::ToolCallStatus::Pending => Self::Pending,
453 acp::ToolCallStatus::InProgress => Self::InProgress,
454 acp::ToolCallStatus::Completed => Self::Completed,
455 acp::ToolCallStatus::Failed => Self::Failed,
456 _ => Self::Pending,
457 }
458 }
459}
460
461impl Display for ToolCallStatus {
462 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
463 write!(
464 f,
465 "{}",
466 match self {
467 ToolCallStatus::Pending => "Pending",
468 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
469 ToolCallStatus::InProgress => "In Progress",
470 ToolCallStatus::Completed => "Completed",
471 ToolCallStatus::Failed => "Failed",
472 ToolCallStatus::Rejected => "Rejected",
473 ToolCallStatus::Canceled => "Canceled",
474 }
475 )
476 }
477}
478
479#[derive(Debug, PartialEq, Clone)]
480pub enum ContentBlock {
481 Empty,
482 Markdown { markdown: Entity<Markdown> },
483 ResourceLink { resource_link: acp::ResourceLink },
484 Image { image: Arc<gpui::Image> },
485}
486
487impl ContentBlock {
488 pub fn new(
489 block: acp::ContentBlock,
490 language_registry: &Arc<LanguageRegistry>,
491 path_style: PathStyle,
492 cx: &mut App,
493 ) -> Self {
494 let mut this = Self::Empty;
495 this.append(block, language_registry, path_style, cx);
496 this
497 }
498
499 pub fn new_combined(
500 blocks: impl IntoIterator<Item = acp::ContentBlock>,
501 language_registry: Arc<LanguageRegistry>,
502 path_style: PathStyle,
503 cx: &mut App,
504 ) -> Self {
505 let mut this = Self::Empty;
506 for block in blocks {
507 this.append(block, &language_registry, path_style, cx);
508 }
509 this
510 }
511
512 pub fn append(
513 &mut self,
514 block: acp::ContentBlock,
515 language_registry: &Arc<LanguageRegistry>,
516 path_style: PathStyle,
517 cx: &mut App,
518 ) {
519 match (&mut *self, &block) {
520 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
521 *self = ContentBlock::ResourceLink {
522 resource_link: resource_link.clone(),
523 };
524 }
525 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
526 if let Some(image) = Self::decode_image(image_content) {
527 *self = ContentBlock::Image { image };
528 } else {
529 let new_content = Self::image_md(image_content);
530 *self = Self::create_markdown_block(new_content, language_registry, cx);
531 }
532 }
533 (ContentBlock::Empty, _) => {
534 let new_content = Self::block_string_contents(&block, path_style);
535 *self = Self::create_markdown_block(new_content, language_registry, cx);
536 }
537 (ContentBlock::Markdown { markdown }, _) => {
538 let new_content = Self::block_string_contents(&block, path_style);
539 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
540 }
541 (ContentBlock::ResourceLink { resource_link }, _) => {
542 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
543 let new_content = Self::block_string_contents(&block, path_style);
544 let combined = format!("{}\n{}", existing_content, new_content);
545 *self = Self::create_markdown_block(combined, language_registry, cx);
546 }
547 (ContentBlock::Image { .. }, _) => {
548 let new_content = Self::block_string_contents(&block, path_style);
549 let combined = format!("`Image`\n{}", new_content);
550 *self = Self::create_markdown_block(combined, language_registry, cx);
551 }
552 }
553 }
554
555 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
556 use base64::Engine as _;
557
558 let bytes = base64::engine::general_purpose::STANDARD
559 .decode(image_content.data.as_bytes())
560 .ok()?;
561 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
562 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
563 }
564
565 fn create_markdown_block(
566 content: String,
567 language_registry: &Arc<LanguageRegistry>,
568 cx: &mut App,
569 ) -> ContentBlock {
570 ContentBlock::Markdown {
571 markdown: cx
572 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
573 }
574 }
575
576 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
577 match block {
578 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
579 acp::ContentBlock::ResourceLink(resource_link) => {
580 Self::resource_link_md(&resource_link.uri, path_style)
581 }
582 acp::ContentBlock::Resource(acp::EmbeddedResource {
583 resource:
584 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
585 uri,
586 ..
587 }),
588 ..
589 }) => Self::resource_link_md(uri, path_style),
590 acp::ContentBlock::Image(image) => Self::image_md(image),
591 _ => String::new(),
592 }
593 }
594
595 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
596 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
597 uri.as_link().to_string()
598 } else {
599 uri.to_string()
600 }
601 }
602
603 fn image_md(_image: &acp::ImageContent) -> String {
604 "`Image`".into()
605 }
606
607 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
608 match self {
609 ContentBlock::Empty => "",
610 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
611 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
612 ContentBlock::Image { .. } => "`Image`",
613 }
614 }
615
616 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
617 match self {
618 ContentBlock::Empty => None,
619 ContentBlock::Markdown { markdown } => Some(markdown),
620 ContentBlock::ResourceLink { .. } => None,
621 ContentBlock::Image { .. } => None,
622 }
623 }
624
625 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
626 match self {
627 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
628 _ => None,
629 }
630 }
631
632 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
633 match self {
634 ContentBlock::Image { image } => Some(image),
635 _ => None,
636 }
637 }
638}
639
640#[derive(Debug)]
641pub enum ToolCallContent {
642 ContentBlock(ContentBlock),
643 Diff(Entity<Diff>),
644 Terminal(Entity<Terminal>),
645}
646
647impl ToolCallContent {
648 pub fn from_acp(
649 content: acp::ToolCallContent,
650 language_registry: Arc<LanguageRegistry>,
651 path_style: PathStyle,
652 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
653 cx: &mut App,
654 ) -> Result<Option<Self>> {
655 match content {
656 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
657 Ok(Some(Self::ContentBlock(ContentBlock::new(
658 content,
659 &language_registry,
660 path_style,
661 cx,
662 ))))
663 }
664 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
665 Diff::finalized(
666 diff.path.to_string_lossy().into_owned(),
667 diff.old_text,
668 diff.new_text,
669 language_registry,
670 cx,
671 )
672 })))),
673 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
674 .get(&terminal_id)
675 .cloned()
676 .map(|terminal| Some(Self::Terminal(terminal)))
677 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
678 _ => Ok(None),
679 }
680 }
681
682 pub fn update_from_acp(
683 &mut self,
684 new: acp::ToolCallContent,
685 language_registry: Arc<LanguageRegistry>,
686 path_style: PathStyle,
687 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
688 cx: &mut App,
689 ) -> Result<bool> {
690 let needs_update = match (&self, &new) {
691 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
692 old_diff.read(cx).needs_update(
693 new_diff.old_text.as_deref().unwrap_or(""),
694 &new_diff.new_text,
695 cx,
696 )
697 }
698 _ => true,
699 };
700
701 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
702 if needs_update {
703 *self = update;
704 }
705 Ok(true)
706 } else {
707 Ok(false)
708 }
709 }
710
711 pub fn to_markdown(&self, cx: &App) -> String {
712 match self {
713 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
714 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
715 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
716 }
717 }
718
719 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
720 match self {
721 Self::ContentBlock(content) => content.image(),
722 _ => None,
723 }
724 }
725}
726
727#[derive(Debug, PartialEq)]
728pub enum ToolCallUpdate {
729 UpdateFields(acp::ToolCallUpdate),
730 UpdateDiff(ToolCallUpdateDiff),
731 UpdateTerminal(ToolCallUpdateTerminal),
732}
733
734impl ToolCallUpdate {
735 fn id(&self) -> &acp::ToolCallId {
736 match self {
737 Self::UpdateFields(update) => &update.tool_call_id,
738 Self::UpdateDiff(diff) => &diff.id,
739 Self::UpdateTerminal(terminal) => &terminal.id,
740 }
741 }
742}
743
744impl From<acp::ToolCallUpdate> for ToolCallUpdate {
745 fn from(update: acp::ToolCallUpdate) -> Self {
746 Self::UpdateFields(update)
747 }
748}
749
750impl From<ToolCallUpdateDiff> for ToolCallUpdate {
751 fn from(diff: ToolCallUpdateDiff) -> Self {
752 Self::UpdateDiff(diff)
753 }
754}
755
756#[derive(Debug, PartialEq)]
757pub struct ToolCallUpdateDiff {
758 pub id: acp::ToolCallId,
759 pub diff: Entity<Diff>,
760}
761
762impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
763 fn from(terminal: ToolCallUpdateTerminal) -> Self {
764 Self::UpdateTerminal(terminal)
765 }
766}
767
768#[derive(Debug, PartialEq)]
769pub struct ToolCallUpdateTerminal {
770 pub id: acp::ToolCallId,
771 pub terminal: Entity<Terminal>,
772}
773
774#[derive(Debug, Default)]
775pub struct Plan {
776 pub entries: Vec<PlanEntry>,
777}
778
779#[derive(Debug)]
780pub struct PlanStats<'a> {
781 pub in_progress_entry: Option<&'a PlanEntry>,
782 pub pending: u32,
783 pub completed: u32,
784}
785
786impl Plan {
787 pub fn is_empty(&self) -> bool {
788 self.entries.is_empty()
789 }
790
791 pub fn stats(&self) -> PlanStats<'_> {
792 let mut stats = PlanStats {
793 in_progress_entry: None,
794 pending: 0,
795 completed: 0,
796 };
797
798 for entry in &self.entries {
799 match &entry.status {
800 acp::PlanEntryStatus::Pending => {
801 stats.pending += 1;
802 }
803 acp::PlanEntryStatus::InProgress => {
804 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
805 }
806 acp::PlanEntryStatus::Completed => {
807 stats.completed += 1;
808 }
809 _ => {}
810 }
811 }
812
813 stats
814 }
815}
816
817#[derive(Debug)]
818pub struct PlanEntry {
819 pub content: Entity<Markdown>,
820 pub priority: acp::PlanEntryPriority,
821 pub status: acp::PlanEntryStatus,
822}
823
824impl PlanEntry {
825 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
826 Self {
827 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
828 priority: entry.priority,
829 status: entry.status,
830 }
831 }
832}
833
834#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
835pub struct TokenUsage {
836 pub max_tokens: u64,
837 pub used_tokens: u64,
838 pub output_tokens: u64,
839}
840
841impl TokenUsage {
842 pub fn ratio(&self) -> TokenUsageRatio {
843 #[cfg(debug_assertions)]
844 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
845 .unwrap_or("0.8".to_string())
846 .parse()
847 .unwrap();
848 #[cfg(not(debug_assertions))]
849 let warning_threshold: f32 = 0.8;
850
851 // When the maximum is unknown because there is no selected model,
852 // avoid showing the token limit warning.
853 if self.max_tokens == 0 {
854 TokenUsageRatio::Normal
855 } else if self.used_tokens >= self.max_tokens {
856 TokenUsageRatio::Exceeded
857 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
858 TokenUsageRatio::Warning
859 } else {
860 TokenUsageRatio::Normal
861 }
862 }
863}
864
/// Coarse classification of token usage, as computed by `TokenUsage::ratio`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but still below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
871
872#[derive(Debug, Clone)]
873pub struct RetryStatus {
874 pub last_error: SharedString,
875 pub attempt: usize,
876 pub max_attempts: usize,
877 pub started_at: Instant,
878 pub duration: Duration,
879}
880
881pub struct AcpThread {
882 title: SharedString,
883 entries: Vec<AgentThreadEntry>,
884 plan: Plan,
885 project: Entity<Project>,
886 action_log: Entity<ActionLog>,
887 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
888 send_task: Option<Task<()>>,
889 connection: Rc<dyn AgentConnection>,
890 session_id: acp::SessionId,
891 token_usage: Option<TokenUsage>,
892 prompt_capabilities: acp::PromptCapabilities,
893 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
894 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
895 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
896 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
897}
898
899impl From<&AcpThread> for ActionLogTelemetry {
900 fn from(value: &AcpThread) -> Self {
901 Self {
902 agent_telemetry_id: value.connection().telemetry_id(),
903 session_id: value.session_id.0.clone(),
904 }
905 }
906}
907
908#[derive(Debug)]
909pub enum AcpThreadEvent {
910 NewEntry,
911 TitleUpdated,
912 TokenUsageUpdated,
913 EntryUpdated(usize),
914 EntriesRemoved(Range<usize>),
915 ToolAuthorizationRequired,
916 Retry(RetryStatus),
917 Stopped,
918 Error,
919 LoadError(LoadError),
920 PromptCapabilitiesUpdated,
921 Refusal,
922 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
923 ModeUpdated(acp::SessionModeId),
924 ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
925}
926
927impl EventEmitter<AcpThreadEvent> for AcpThread {}
928
929#[derive(Debug, Clone)]
930pub enum TerminalProviderEvent {
931 Created {
932 terminal_id: acp::TerminalId,
933 label: String,
934 cwd: Option<PathBuf>,
935 output_byte_limit: Option<u64>,
936 terminal: Entity<::terminal::Terminal>,
937 },
938 Output {
939 terminal_id: acp::TerminalId,
940 data: Vec<u8>,
941 },
942 TitleChanged {
943 terminal_id: acp::TerminalId,
944 title: String,
945 },
946 Exit {
947 terminal_id: acp::TerminalId,
948 status: acp::TerminalExitStatus,
949 },
950}
951
952#[derive(Debug, Clone)]
953pub enum TerminalProviderCommand {
954 WriteInput {
955 terminal_id: acp::TerminalId,
956 bytes: Vec<u8>,
957 },
958 Resize {
959 terminal_id: acp::TerminalId,
960 cols: u16,
961 rows: u16,
962 },
963 Close {
964 terminal_id: acp::TerminalId,
965 },
966}
967
968impl AcpThread {
969 pub fn on_terminal_provider_event(
970 &mut self,
971 event: TerminalProviderEvent,
972 cx: &mut Context<Self>,
973 ) {
974 match event {
975 TerminalProviderEvent::Created {
976 terminal_id,
977 label,
978 cwd,
979 output_byte_limit,
980 terminal,
981 } => {
982 let entity = self.register_terminal_created(
983 terminal_id.clone(),
984 label,
985 cwd,
986 output_byte_limit,
987 terminal,
988 cx,
989 );
990
991 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
992 for data in chunks.drain(..) {
993 entity.update(cx, |term, cx| {
994 term.inner().update(cx, |inner, cx| {
995 inner.write_output(&data, cx);
996 })
997 });
998 }
999 }
1000
1001 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1002 entity.update(cx, |_term, cx| {
1003 cx.notify();
1004 });
1005 }
1006
1007 cx.notify();
1008 }
1009 TerminalProviderEvent::Output { terminal_id, data } => {
1010 if let Some(entity) = self.terminals.get(&terminal_id) {
1011 entity.update(cx, |term, cx| {
1012 term.inner().update(cx, |inner, cx| {
1013 inner.write_output(&data, cx);
1014 })
1015 });
1016 } else {
1017 self.pending_terminal_output
1018 .entry(terminal_id)
1019 .or_default()
1020 .push(data);
1021 }
1022 }
1023 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1024 if let Some(entity) = self.terminals.get(&terminal_id) {
1025 entity.update(cx, |term, cx| {
1026 term.inner().update(cx, |inner, cx| {
1027 inner.breadcrumb_text = title;
1028 cx.emit(::terminal::Event::BreadcrumbsChanged);
1029 })
1030 });
1031 }
1032 }
1033 TerminalProviderEvent::Exit {
1034 terminal_id,
1035 status,
1036 } => {
1037 if let Some(entity) = self.terminals.get(&terminal_id) {
1038 entity.update(cx, |_term, cx| {
1039 cx.notify();
1040 });
1041 } else {
1042 self.pending_terminal_exit.insert(terminal_id, status);
1043 }
1044 }
1045 }
1046 }
1047}
1048
/// Whether a thread is currently generating a response.
#[derive(Debug, Eq, PartialEq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in flight.
    Generating,
}
1054
1055#[derive(Debug, Clone)]
1056pub enum LoadError {
1057 Unsupported {
1058 command: SharedString,
1059 current_version: SharedString,
1060 minimum_version: SharedString,
1061 },
1062 FailedToInstall(SharedString),
1063 Exited {
1064 status: ExitStatus,
1065 },
1066 Other(SharedString),
1067}
1068
1069impl Display for LoadError {
1070 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1071 match self {
1072 LoadError::Unsupported {
1073 command: path,
1074 current_version,
1075 minimum_version,
1076 } => {
1077 write!(
1078 f,
1079 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1080 )
1081 }
1082 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1083 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1084 LoadError::Other(msg) => write!(f, "{msg}"),
1085 }
1086 }
1087}
1088
1089impl Error for LoadError {}
1090
1091impl AcpThread {
1092 pub fn new(
1093 title: impl Into<SharedString>,
1094 connection: Rc<dyn AgentConnection>,
1095 project: Entity<Project>,
1096 action_log: Entity<ActionLog>,
1097 session_id: acp::SessionId,
1098 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1099 cx: &mut Context<Self>,
1100 ) -> Self {
1101 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1102 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1103 loop {
1104 let caps = prompt_capabilities_rx.recv().await?;
1105 this.update(cx, |this, cx| {
1106 this.prompt_capabilities = caps;
1107 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1108 })?;
1109 }
1110 });
1111
1112 Self {
1113 action_log,
1114 shared_buffers: Default::default(),
1115 entries: Default::default(),
1116 plan: Default::default(),
1117 title: title.into(),
1118 project,
1119 send_task: None,
1120 connection,
1121 session_id,
1122 token_usage: None,
1123 prompt_capabilities,
1124 _observe_prompt_capabilities: task,
1125 terminals: HashMap::default(),
1126 pending_terminal_output: HashMap::default(),
1127 pending_terminal_exit: HashMap::default(),
1128 }
1129 }
1130
1131 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1132 self.prompt_capabilities.clone()
1133 }
1134
1135 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1136 &self.connection
1137 }
1138
1139 pub fn action_log(&self) -> &Entity<ActionLog> {
1140 &self.action_log
1141 }
1142
1143 pub fn project(&self) -> &Entity<Project> {
1144 &self.project
1145 }
1146
1147 pub fn title(&self) -> SharedString {
1148 self.title.clone()
1149 }
1150
1151 pub fn entries(&self) -> &[AgentThreadEntry] {
1152 &self.entries
1153 }
1154
1155 pub fn session_id(&self) -> &acp::SessionId {
1156 &self.session_id
1157 }
1158
1159 pub fn status(&self) -> ThreadStatus {
1160 if self.send_task.is_some() {
1161 ThreadStatus::Generating
1162 } else {
1163 ThreadStatus::Idle
1164 }
1165 }
1166
1167 pub fn token_usage(&self) -> Option<&TokenUsage> {
1168 self.token_usage.as_ref()
1169 }
1170
1171 pub fn has_pending_edit_tool_calls(&self) -> bool {
1172 for entry in self.entries.iter().rev() {
1173 match entry {
1174 AgentThreadEntry::UserMessage(_) => return false,
1175 AgentThreadEntry::ToolCall(
1176 call @ ToolCall {
1177 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1178 ..
1179 },
1180 ) if call.diffs().next().is_some() => {
1181 return true;
1182 }
1183 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1184 }
1185 }
1186
1187 false
1188 }
1189
1190 pub fn has_in_progress_tool_calls(&self) -> bool {
1191 for entry in self.entries.iter().rev() {
1192 match entry {
1193 AgentThreadEntry::UserMessage(_) => return false,
1194 AgentThreadEntry::ToolCall(ToolCall {
1195 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1196 ..
1197 }) => {
1198 return true;
1199 }
1200 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1201 }
1202 }
1203
1204 false
1205 }
1206
1207 pub fn used_tools_since_last_user_message(&self) -> bool {
1208 for entry in self.entries.iter().rev() {
1209 match entry {
1210 AgentThreadEntry::UserMessage(..) => return false,
1211 AgentThreadEntry::AssistantMessage(..) => continue,
1212 AgentThreadEntry::ToolCall(..) => return true,
1213 }
1214 }
1215
1216 false
1217 }
1218
1219 pub fn handle_session_update(
1220 &mut self,
1221 update: acp::SessionUpdate,
1222 cx: &mut Context<Self>,
1223 ) -> Result<(), acp::Error> {
1224 match update {
1225 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1226 self.push_user_content_block(None, content, cx);
1227 }
1228 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1229 self.push_assistant_content_block(content, false, cx);
1230 }
1231 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1232 self.push_assistant_content_block(content, true, cx);
1233 }
1234 acp::SessionUpdate::ToolCall(tool_call) => {
1235 self.upsert_tool_call(tool_call, cx)?;
1236 }
1237 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1238 self.update_tool_call(tool_call_update, cx)?;
1239 }
1240 acp::SessionUpdate::Plan(plan) => {
1241 self.update_plan(plan, cx);
1242 }
1243 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1244 available_commands,
1245 ..
1246 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1247 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1248 current_mode_id,
1249 ..
1250 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1251 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1252 config_options,
1253 ..
1254 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1255 _ => {}
1256 }
1257 Ok(())
1258 }
1259
1260 pub fn push_user_content_block(
1261 &mut self,
1262 message_id: Option<UserMessageId>,
1263 chunk: acp::ContentBlock,
1264 cx: &mut Context<Self>,
1265 ) {
1266 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1267 }
1268
1269 pub fn push_user_content_block_with_indent(
1270 &mut self,
1271 message_id: Option<UserMessageId>,
1272 chunk: acp::ContentBlock,
1273 indented: bool,
1274 cx: &mut Context<Self>,
1275 ) {
1276 let language_registry = self.project.read(cx).languages().clone();
1277 let path_style = self.project.read(cx).path_style(cx);
1278 let entries_len = self.entries.len();
1279
1280 if let Some(last_entry) = self.entries.last_mut()
1281 && let AgentThreadEntry::UserMessage(UserMessage {
1282 id,
1283 content,
1284 chunks,
1285 indented: existing_indented,
1286 ..
1287 }) = last_entry
1288 && *existing_indented == indented
1289 {
1290 *id = message_id.or(id.take());
1291 content.append(chunk.clone(), &language_registry, path_style, cx);
1292 chunks.push(chunk);
1293 let idx = entries_len - 1;
1294 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1295 } else {
1296 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1297 self.push_entry(
1298 AgentThreadEntry::UserMessage(UserMessage {
1299 id: message_id,
1300 content,
1301 chunks: vec![chunk],
1302 checkpoint: None,
1303 indented,
1304 }),
1305 cx,
1306 );
1307 }
1308 }
1309
1310 pub fn push_assistant_content_block(
1311 &mut self,
1312 chunk: acp::ContentBlock,
1313 is_thought: bool,
1314 cx: &mut Context<Self>,
1315 ) {
1316 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1317 }
1318
1319 pub fn push_assistant_content_block_with_indent(
1320 &mut self,
1321 chunk: acp::ContentBlock,
1322 is_thought: bool,
1323 indented: bool,
1324 cx: &mut Context<Self>,
1325 ) {
1326 let language_registry = self.project.read(cx).languages().clone();
1327 let path_style = self.project.read(cx).path_style(cx);
1328 let entries_len = self.entries.len();
1329 if let Some(last_entry) = self.entries.last_mut()
1330 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1331 chunks,
1332 indented: existing_indented,
1333 }) = last_entry
1334 && *existing_indented == indented
1335 {
1336 let idx = entries_len - 1;
1337 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1338 match (chunks.last_mut(), is_thought) {
1339 (Some(AssistantMessageChunk::Message { block }), false)
1340 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1341 block.append(chunk, &language_registry, path_style, cx)
1342 }
1343 _ => {
1344 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1345 if is_thought {
1346 chunks.push(AssistantMessageChunk::Thought { block })
1347 } else {
1348 chunks.push(AssistantMessageChunk::Message { block })
1349 }
1350 }
1351 }
1352 } else {
1353 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1354 let chunk = if is_thought {
1355 AssistantMessageChunk::Thought { block }
1356 } else {
1357 AssistantMessageChunk::Message { block }
1358 };
1359
1360 self.push_entry(
1361 AgentThreadEntry::AssistantMessage(AssistantMessage {
1362 chunks: vec![chunk],
1363 indented,
1364 }),
1365 cx,
1366 );
1367 }
1368 }
1369
1370 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1371 self.entries.push(entry);
1372 cx.emit(AcpThreadEvent::NewEntry);
1373 }
1374
1375 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1376 self.connection.set_title(&self.session_id, cx).is_some()
1377 }
1378
1379 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1380 if title != self.title {
1381 self.title = title.clone();
1382 cx.emit(AcpThreadEvent::TitleUpdated);
1383 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1384 return set_title.run(title, cx);
1385 }
1386 }
1387 Task::ready(Ok(()))
1388 }
1389
1390 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1391 self.token_usage = usage;
1392 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1393 }
1394
1395 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1396 cx.emit(AcpThreadEvent::Retry(status));
1397 }
1398
1399 pub fn update_tool_call(
1400 &mut self,
1401 update: impl Into<ToolCallUpdate>,
1402 cx: &mut Context<Self>,
1403 ) -> Result<()> {
1404 let update = update.into();
1405 let languages = self.project.read(cx).languages().clone();
1406 let path_style = self.project.read(cx).path_style(cx);
1407
1408 let ix = match self.index_for_tool_call(update.id()) {
1409 Some(ix) => ix,
1410 None => {
1411 // Tool call not found - create a failed tool call entry
1412 let failed_tool_call = ToolCall {
1413 id: update.id().clone(),
1414 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1415 kind: acp::ToolKind::Fetch,
1416 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1417 "Tool call not found".into(),
1418 &languages,
1419 path_style,
1420 cx,
1421 ))],
1422 status: ToolCallStatus::Failed,
1423 locations: Vec::new(),
1424 resolved_locations: Vec::new(),
1425 raw_input: None,
1426 raw_input_markdown: None,
1427 raw_output: None,
1428 };
1429 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1430 return Ok(());
1431 }
1432 };
1433 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1434 unreachable!()
1435 };
1436
1437 match update {
1438 ToolCallUpdate::UpdateFields(update) => {
1439 let location_updated = update.fields.locations.is_some();
1440 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1441 if location_updated {
1442 self.resolve_locations(update.tool_call_id, cx);
1443 }
1444 }
1445 ToolCallUpdate::UpdateDiff(update) => {
1446 call.content.clear();
1447 call.content.push(ToolCallContent::Diff(update.diff));
1448 }
1449 ToolCallUpdate::UpdateTerminal(update) => {
1450 call.content.clear();
1451 call.content
1452 .push(ToolCallContent::Terminal(update.terminal));
1453 }
1454 }
1455
1456 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1457
1458 Ok(())
1459 }
1460
1461 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1462 pub fn upsert_tool_call(
1463 &mut self,
1464 tool_call: acp::ToolCall,
1465 cx: &mut Context<Self>,
1466 ) -> Result<(), acp::Error> {
1467 let status = tool_call.status.into();
1468 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1469 }
1470
1471 /// Fails if id does not match an existing entry.
1472 pub fn upsert_tool_call_inner(
1473 &mut self,
1474 update: acp::ToolCallUpdate,
1475 status: ToolCallStatus,
1476 cx: &mut Context<Self>,
1477 ) -> Result<(), acp::Error> {
1478 let language_registry = self.project.read(cx).languages().clone();
1479 let path_style = self.project.read(cx).path_style(cx);
1480 let id = update.tool_call_id.clone();
1481
1482 let agent_telemetry_id = self.connection().telemetry_id();
1483 let session = self.session_id();
1484 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1485 let status = if matches!(status, ToolCallStatus::Completed) {
1486 "completed"
1487 } else {
1488 "failed"
1489 };
1490 telemetry::event!(
1491 "Agent Tool Call Completed",
1492 agent_telemetry_id,
1493 session,
1494 status
1495 );
1496 }
1497
1498 if let Some(ix) = self.index_for_tool_call(&id) {
1499 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1500 unreachable!()
1501 };
1502
1503 call.update_fields(
1504 update.fields,
1505 language_registry,
1506 path_style,
1507 &self.terminals,
1508 cx,
1509 )?;
1510 call.status = status;
1511
1512 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1513 } else {
1514 let call = ToolCall::from_acp(
1515 update.try_into()?,
1516 status,
1517 language_registry,
1518 self.project.read(cx).path_style(cx),
1519 &self.terminals,
1520 cx,
1521 )?;
1522 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1523 };
1524
1525 self.resolve_locations(id, cx);
1526 Ok(())
1527 }
1528
1529 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1530 self.entries
1531 .iter()
1532 .enumerate()
1533 .rev()
1534 .find_map(|(index, entry)| {
1535 if let AgentThreadEntry::ToolCall(tool_call) = entry
1536 && &tool_call.id == id
1537 {
1538 Some(index)
1539 } else {
1540 None
1541 }
1542 })
1543 }
1544
1545 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1546 // The tool call we are looking for is typically the last one, or very close to the end.
1547 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1548 self.entries
1549 .iter_mut()
1550 .enumerate()
1551 .rev()
1552 .find_map(|(index, tool_call)| {
1553 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1554 && &tool_call.id == id
1555 {
1556 Some((index, tool_call))
1557 } else {
1558 None
1559 }
1560 })
1561 }
1562
1563 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1564 self.entries
1565 .iter()
1566 .enumerate()
1567 .rev()
1568 .find_map(|(index, tool_call)| {
1569 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1570 && &tool_call.id == id
1571 {
1572 Some((index, tool_call))
1573 } else {
1574 None
1575 }
1576 })
1577 }
1578
    /// Resolves the locations reported by tool call `id` into buffer positions and records
    /// them on the tool call.
    ///
    /// Resolution runs in a detached task. When it finishes:
    /// - each resolved buffer's current snapshot is stored in `shared_buffers`;
    /// - the project's agent location moves to the last resolved location, unless that would
    ///   only move it to an earlier column on the same row of the same buffer;
    /// - `AcpThreadEvent::EntryUpdated` is emitted if the tool call's resolved locations
    ///   changed.
    ///
    /// Does nothing if no tool call with `id` exists. Because the task is detached, a failing
    /// `this.update` (presumably when the thread was dropped) is silently ignored.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache snapshots of the resolved buffers so `read_text_file` /
                // `write_text_file` can reuse them.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The entry may have been removed (e.g. by a rewind) while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                // Follow the agent to the most recently reported location, if it resolved.
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify observers when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1635
1636 pub fn request_tool_call_authorization(
1637 &mut self,
1638 tool_call: acp::ToolCallUpdate,
1639 options: Vec<acp::PermissionOption>,
1640 respect_always_allow_setting: bool,
1641 cx: &mut Context<Self>,
1642 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1643 let (tx, rx) = oneshot::channel();
1644
1645 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1646 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1647 // some tools would (incorrectly) continue to auto-accept.
1648 if let Some(allow_once_option) = options.iter().find_map(|option| {
1649 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1650 Some(option.option_id.clone())
1651 } else {
1652 None
1653 }
1654 }) {
1655 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1656 return Ok(async {
1657 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1658 allow_once_option,
1659 ))
1660 }
1661 .boxed());
1662 }
1663 }
1664
1665 let status = ToolCallStatus::WaitingForConfirmation {
1666 options,
1667 respond_tx: tx,
1668 };
1669
1670 self.upsert_tool_call_inner(tool_call, status, cx)?;
1671 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1672
1673 let fut = async {
1674 match rx.await {
1675 Ok(option) => acp::RequestPermissionOutcome::Selected(
1676 acp::SelectedPermissionOutcome::new(option),
1677 ),
1678 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1679 }
1680 }
1681 .boxed();
1682
1683 Ok(fut)
1684 }
1685
1686 pub fn authorize_tool_call(
1687 &mut self,
1688 id: acp::ToolCallId,
1689 option_id: acp::PermissionOptionId,
1690 option_kind: acp::PermissionOptionKind,
1691 cx: &mut Context<Self>,
1692 ) {
1693 let Some((ix, call)) = self.tool_call_mut(&id) else {
1694 return;
1695 };
1696
1697 let new_status = match option_kind {
1698 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1699 ToolCallStatus::Rejected
1700 }
1701 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1702 ToolCallStatus::InProgress
1703 }
1704 _ => ToolCallStatus::InProgress,
1705 };
1706
1707 let curr_status = mem::replace(&mut call.status, new_status);
1708
1709 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1710 respond_tx.send(option_id).log_err();
1711 } else if cfg!(debug_assertions) {
1712 panic!("tried to authorize an already authorized tool call");
1713 }
1714
1715 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1716 }
1717
1718 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1719 let mut first_tool_call = None;
1720
1721 for entry in self.entries.iter().rev() {
1722 match &entry {
1723 AgentThreadEntry::ToolCall(call) => {
1724 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1725 first_tool_call = Some(call);
1726 } else {
1727 continue;
1728 }
1729 }
1730 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1731 // Reached the beginning of the turn.
1732 // If we had pending permission requests in the previous turn, they have been cancelled.
1733 break;
1734 }
1735 }
1736 }
1737
1738 first_tool_call
1739 }
1740
1741 pub fn plan(&self) -> &Plan {
1742 &self.plan
1743 }
1744
1745 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1746 let new_entries_len = request.entries.len();
1747 let mut new_entries = request.entries.into_iter();
1748
1749 // Reuse existing markdown to prevent flickering
1750 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1751 let PlanEntry {
1752 content,
1753 priority,
1754 status,
1755 } = old;
1756 content.update(cx, |old, cx| {
1757 old.replace(new.content, cx);
1758 });
1759 *priority = new.priority;
1760 *status = new.status;
1761 }
1762 for new in new_entries {
1763 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1764 }
1765 self.plan.entries.truncate(new_entries_len);
1766
1767 cx.notify();
1768 }
1769
1770 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1771 self.plan
1772 .entries
1773 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1774 cx.notify();
1775 }
1776
1777 #[cfg(any(test, feature = "test-support"))]
1778 pub fn send_raw(
1779 &mut self,
1780 message: &str,
1781 cx: &mut Context<Self>,
1782 ) -> BoxFuture<'static, Result<()>> {
1783 self.send(vec![message.into()], cx)
1784 }
1785
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The turn runs through `run_turn`, so any in-flight turn is cancelled first. Inside the
    /// turn:
    /// 1. the message is pushed as a new `UserMessage` entry;
    /// 2. a git checkpoint is taken and attached (hidden) to the last user message;
    /// 3. the prompt is forwarded to the connection.
    ///
    /// A checkpoint failure is logged and does not abort the turn.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Only allocate an id when the connection can truncate: the id is what
        // `rewind` / `restore_checkpoint` use to find this message later.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the working tree before the agent acts so `restore_checkpoint` can
            // return to it. It stays hidden (`show: false`) until `update_last_checkpoint`
            // finds that the tree changed.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1838
1839 pub fn can_resume(&self, cx: &App) -> bool {
1840 self.connection.resume(&self.session_id, cx).is_some()
1841 }
1842
1843 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1844 self.run_turn(cx, async move |this, cx| {
1845 this.update(cx, |this, cx| {
1846 this.connection
1847 .resume(&this.session_id, cx)
1848 .map(|resume| resume.run(cx))
1849 })?
1850 .context("resuming a session is not supported")?
1851 .await
1852 })
1853 }
1854
    /// Runs one agent turn produced by `f`, storing it as the thread's `send_task`.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight turn is
    /// cancelled (and awaited). When `f` finishes, the last user message's checkpoint
    /// visibility is refreshed and the project's agent location is cleared. The outcome is
    /// then reported:
    /// - `f` returned `Err`: emits `Error` and returns that error.
    /// - `StopReason::Refusal`: emits `Refusal`. If no completed tool call with output follows
    ///   the last user message, the entries from that message onward are removed first
    ///   (emitting `EntriesRemoved`).
    /// - Every non-error outcome, including refusals and cancellations, then emits `Stopped`
    ///   and returns `Ok(())`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The new turn only starts once the previous one has finished cancelling.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `rx` yields `Err(Canceled)` if the send task is dropped before sending a result.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1950
1951 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1952 let Some(send_task) = self.send_task.take() else {
1953 return Task::ready(());
1954 };
1955
1956 for entry in self.entries.iter_mut() {
1957 if let AgentThreadEntry::ToolCall(call) = entry {
1958 let cancel = matches!(
1959 call.status,
1960 ToolCallStatus::Pending
1961 | ToolCallStatus::WaitingForConfirmation { .. }
1962 | ToolCallStatus::InProgress
1963 );
1964
1965 if cancel {
1966 call.status = ToolCallStatus::Canceled;
1967 }
1968 }
1969 }
1970
1971 self.connection.cancel(&self.session_id, cx);
1972
1973 // Wait for the send task to complete
1974 cx.foreground_executor().spawn(send_task)
1975 }
1976
1977 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1978 pub fn restore_checkpoint(
1979 &mut self,
1980 id: UserMessageId,
1981 cx: &mut Context<Self>,
1982 ) -> Task<Result<()>> {
1983 let Some((_, message)) = self.user_message_mut(&id) else {
1984 return Task::ready(Err(anyhow!("message not found")));
1985 };
1986
1987 let checkpoint = message
1988 .checkpoint
1989 .as_ref()
1990 .map(|c| c.git_checkpoint.clone());
1991
1992 // Cancel any in-progress generation before restoring
1993 let cancel_task = self.cancel(cx);
1994 let rewind = self.rewind(id.clone(), cx);
1995 let git_store = self.project.read(cx).git_store().clone();
1996
1997 cx.spawn(async move |_, cx| {
1998 cancel_task.await;
1999 rewind.await?;
2000 if let Some(checkpoint) = checkpoint {
2001 git_store
2002 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
2003 .await?;
2004 }
2005
2006 Ok(())
2007 })
2008 }
2009
2010 /// Rewinds this thread to before the entry at `index`, removing it and all
2011 /// subsequent entries while rejecting any action_log changes made from that point.
2012 /// Unlike `restore_checkpoint`, this method does not restore from git.
2013 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
2014 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
2015 return Task::ready(Err(anyhow!("not supported")));
2016 };
2017
2018 let telemetry = ActionLogTelemetry::from(&*self);
2019 cx.spawn(async move |this, cx| {
2020 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2021 this.update(cx, |this, cx| {
2022 if let Some((ix, _)) = this.user_message_mut(&id) {
2023 // Collect all terminals from entries that will be removed
2024 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2025 .iter()
2026 .flat_map(|entry| entry.terminals())
2027 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2028 .collect();
2029
2030 let range = ix..this.entries.len();
2031 this.entries.truncate(ix);
2032 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2033
2034 // Kill and remove the terminals
2035 for terminal_id in terminals_to_remove {
2036 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2037 terminal.update(cx, |terminal, cx| {
2038 terminal.kill(cx);
2039 });
2040 }
2041 }
2042 }
2043 this.action_log().update(cx, |action_log, cx| {
2044 action_log.reject_all_edits(Some(telemetry), cx)
2045 })
2046 })?
2047 .await;
2048 Ok(())
2049 })
2050 }
2051
2052 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2053 let git_store = self.project.read(cx).git_store().clone();
2054
2055 let Some((_, message)) = self.last_user_message() else {
2056 return Task::ready(Ok(()));
2057 };
2058 let Some(user_message_id) = message.id.clone() else {
2059 return Task::ready(Ok(()));
2060 };
2061 let Some(checkpoint) = message.checkpoint.as_ref() else {
2062 return Task::ready(Ok(()));
2063 };
2064 let old_checkpoint = checkpoint.git_checkpoint.clone();
2065
2066 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2067 cx.spawn(async move |this, cx| {
2068 let Some(new_checkpoint) = new_checkpoint
2069 .await
2070 .context("failed to get new checkpoint")
2071 .log_err()
2072 else {
2073 return Ok(());
2074 };
2075
2076 let equal = git_store
2077 .update(cx, |git, cx| {
2078 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2079 })
2080 .await
2081 .unwrap_or(true);
2082
2083 this.update(cx, |this, cx| {
2084 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2085 if let Some(checkpoint) = message.checkpoint.as_mut() {
2086 checkpoint.show = !equal;
2087 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2088 }
2089 }
2090 })?;
2091
2092 Ok(())
2093 })
2094 }
2095
2096 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2097 self.entries
2098 .iter_mut()
2099 .enumerate()
2100 .rev()
2101 .find_map(|(ix, entry)| {
2102 if let AgentThreadEntry::UserMessage(message) = entry {
2103 Some((ix, message))
2104 } else {
2105 None
2106 }
2107 })
2108 }
2109
2110 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2111 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2112 if let AgentThreadEntry::UserMessage(message) = entry {
2113 if message.id.as_ref() == Some(id) {
2114 Some((ix, message))
2115 } else {
2116 None
2117 }
2118 } else {
2119 None
2120 }
2121 })
2122 }
2123
2124 pub fn read_text_file(
2125 &self,
2126 path: PathBuf,
2127 line: Option<u32>,
2128 limit: Option<u32>,
2129 reuse_shared_snapshot: bool,
2130 cx: &mut Context<Self>,
2131 ) -> Task<Result<String, acp::Error>> {
2132 // Args are 1-based, move to 0-based
2133 let line = line.unwrap_or_default().saturating_sub(1);
2134 let limit = limit.unwrap_or(u32::MAX);
2135 let project = self.project.clone();
2136 let action_log = self.action_log.clone();
2137 cx.spawn(async move |this, cx| {
2138 let load = project.update(cx, |project, cx| {
2139 let path = project
2140 .project_path_for_absolute_path(&path, cx)
2141 .ok_or_else(|| {
2142 acp::Error::resource_not_found(Some(path.display().to_string()))
2143 })?;
2144 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2145 })?;
2146
2147 let buffer = load.await?;
2148
2149 let snapshot = if reuse_shared_snapshot {
2150 this.read_with(cx, |this, _| {
2151 this.shared_buffers.get(&buffer.clone()).cloned()
2152 })
2153 .log_err()
2154 .flatten()
2155 } else {
2156 None
2157 };
2158
2159 let snapshot = if let Some(snapshot) = snapshot {
2160 snapshot
2161 } else {
2162 action_log.update(cx, |action_log, cx| {
2163 action_log.buffer_read(buffer.clone(), cx);
2164 });
2165
2166 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2167 this.update(cx, |this, _| {
2168 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2169 })?;
2170 snapshot
2171 };
2172
2173 let max_point = snapshot.max_point();
2174 let start_position = Point::new(line, 0);
2175
2176 if start_position > max_point {
2177 return Err(acp::Error::invalid_params().data(format!(
2178 "Attempting to read beyond the end of the file, line {}:{}",
2179 max_point.row + 1,
2180 max_point.column
2181 )));
2182 }
2183
2184 let start = snapshot.anchor_before(start_position);
2185 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2186
2187 project.update(cx, |project, cx| {
2188 project.set_agent_location(
2189 Some(AgentLocation {
2190 buffer: buffer.downgrade(),
2191 position: start,
2192 }),
2193 cx,
2194 );
2195 });
2196
2197 Ok(snapshot.text_for_range(start..end).collect::<String>())
2198 })
2199 }
2200
    /// Replaces the contents of the project file at absolute `path` with `content` and saves it.
    ///
    /// The steps are:
    /// 1. Diff `content` against the snapshot in `shared_buffers` (falling back to the buffer's
    ///    current snapshot) on the background executor.
    /// 2. Move the project's agent location to the end of the last edit, or to the start of the
    ///    buffer when there are no edits.
    /// 3. Tell the action log the buffer was read, apply the edits, then tell it the buffer was
    ///    edited.
    /// 4. If the buffer's `format_on_save` setting is not `Off`, format the buffer (format
    ///    errors are only logged) and tell the action log about the edit again.
    /// 5. Save the buffer.
    ///
    /// # Errors
    /// Fails with "invalid path" if `path` does not map to a project path. Also propagates
    /// errors from opening the buffer, updating the thread, or saving.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against the snapshot the agent last saw, if the thread recorded one.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Edits are expressed as anchors into `snapshot` and applied to the live buffer below.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            // Record the read before editing, apply the edits, then record the edit; the result
            // says whether the buffer's language settings request formatting on save.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: failures are logged and the save still happens.
                format_task.await.log_err();

                // Formatting may have changed the buffer, so report the edit again.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2297
2298 pub fn create_terminal(
2299 &self,
2300 command: String,
2301 args: Vec<String>,
2302 extra_env: Vec<acp::EnvVariable>,
2303 cwd: Option<PathBuf>,
2304 output_byte_limit: Option<u64>,
2305 cx: &mut Context<Self>,
2306 ) -> Task<Result<Entity<Terminal>>> {
2307 let env = match &cwd {
2308 Some(dir) => self.project.update(cx, |project, cx| {
2309 project.environment().update(cx, |env, cx| {
2310 env.directory_environment(dir.as_path().into(), cx)
2311 })
2312 }),
2313 None => Task::ready(None).shared(),
2314 };
2315 let env = cx.spawn(async move |_, _| {
2316 let mut env = env.await.unwrap_or_default();
2317 // Disables paging for `git` and hopefully other commands
2318 env.insert("PAGER".into(), "".into());
2319 for var in extra_env {
2320 env.insert(var.name, var.value);
2321 }
2322 env
2323 });
2324
2325 let project = self.project.clone();
2326 let language_registry = project.read(cx).languages().clone();
2327 let is_windows = project.read(cx).path_style(cx).is_windows();
2328
2329 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2330 let terminal_task = cx.spawn({
2331 let terminal_id = terminal_id.clone();
2332 async move |_this, cx| {
2333 let env = env.await;
2334 let shell = project
2335 .update(cx, |project, cx| {
2336 project
2337 .remote_client()
2338 .and_then(|r| r.read(cx).default_system_shell())
2339 })
2340 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2341 let (task_command, task_args) =
2342 ShellBuilder::new(&Shell::Program(shell), is_windows)
2343 .redirect_stdin_to_dev_null()
2344 .build(Some(command.clone()), &args);
2345 let terminal = project
2346 .update(cx, |project, cx| {
2347 project.create_terminal_task(
2348 task::SpawnInTerminal {
2349 command: Some(task_command),
2350 args: task_args,
2351 cwd: cwd.clone(),
2352 env,
2353 ..Default::default()
2354 },
2355 cx,
2356 )
2357 })
2358 .await?;
2359
2360 anyhow::Ok(cx.new(|cx| {
2361 Terminal::new(
2362 terminal_id,
2363 &format!("{} {}", command, args.join(" ")),
2364 cwd,
2365 output_byte_limit.map(|l| l as usize),
2366 terminal,
2367 language_registry,
2368 cx,
2369 )
2370 }))
2371 }
2372 });
2373
2374 cx.spawn(async move |this, cx| {
2375 let terminal = terminal_task.await?;
2376 this.update(cx, |this, _cx| {
2377 this.terminals.insert(terminal_id, terminal.clone());
2378 terminal
2379 })
2380 })
2381 }
2382
2383 pub fn kill_terminal(
2384 &mut self,
2385 terminal_id: acp::TerminalId,
2386 cx: &mut Context<Self>,
2387 ) -> Result<()> {
2388 self.terminals
2389 .get(&terminal_id)
2390 .context("Terminal not found")?
2391 .update(cx, |terminal, cx| {
2392 terminal.kill(cx);
2393 });
2394
2395 Ok(())
2396 }
2397
2398 pub fn release_terminal(
2399 &mut self,
2400 terminal_id: acp::TerminalId,
2401 cx: &mut Context<Self>,
2402 ) -> Result<()> {
2403 self.terminals
2404 .remove(&terminal_id)
2405 .context("Terminal not found")?
2406 .update(cx, |terminal, cx| {
2407 terminal.kill(cx);
2408 });
2409
2410 Ok(())
2411 }
2412
2413 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2414 self.terminals
2415 .get(&terminal_id)
2416 .context("Terminal not found")
2417 .cloned()
2418 }
2419
2420 pub fn to_markdown(&self, cx: &App) -> String {
2421 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2422 }
2423
2424 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2425 cx.emit(AcpThreadEvent::LoadError(error));
2426 }
2427
2428 pub fn register_terminal_created(
2429 &mut self,
2430 terminal_id: acp::TerminalId,
2431 command_label: String,
2432 working_dir: Option<PathBuf>,
2433 output_byte_limit: Option<u64>,
2434 terminal: Entity<::terminal::Terminal>,
2435 cx: &mut Context<Self>,
2436 ) -> Entity<Terminal> {
2437 let language_registry = self.project.read(cx).languages().clone();
2438
2439 let entity = cx.new(|cx| {
2440 Terminal::new(
2441 terminal_id.clone(),
2442 &command_label,
2443 working_dir.clone(),
2444 output_byte_limit.map(|l| l as usize),
2445 terminal,
2446 language_registry,
2447 cx,
2448 )
2449 });
2450 self.terminals.insert(terminal_id.clone(), entity.clone());
2451 entity
2452 }
2453}
2454
2455fn markdown_for_raw_output(
2456 raw_output: &serde_json::Value,
2457 language_registry: &Arc<LanguageRegistry>,
2458 cx: &mut App,
2459) -> Option<Entity<Markdown>> {
2460 match raw_output {
2461 serde_json::Value::Null => None,
2462 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2463 Markdown::new(
2464 value.to_string().into(),
2465 Some(language_registry.clone()),
2466 None,
2467 cx,
2468 )
2469 })),
2470 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2471 Markdown::new(
2472 value.to_string().into(),
2473 Some(language_registry.clone()),
2474 None,
2475 cx,
2476 )
2477 })),
2478 serde_json::Value::String(value) => Some(cx.new(|cx| {
2479 Markdown::new(
2480 value.clone().into(),
2481 Some(language_registry.clone()),
2482 None,
2483 cx,
2484 )
2485 })),
2486 value => Some(cx.new(|cx| {
2487 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2488
2489 Markdown::new(
2490 format!("```json\n{}\n```", pretty_json).into(),
2491 Some(language_registry.clone()),
2492 None,
2493 cx,
2494 )
2495 })),
2496 }
2497}
2498
2499#[cfg(test)]
2500mod tests {
2501 use super::*;
2502 use anyhow::anyhow;
2503 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2504 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2505 use indoc::indoc;
2506 use project::{FakeFs, Fs};
2507 use rand::{distr, prelude::*};
2508 use serde_json::json;
2509 use settings::SettingsStore;
2510 use smol::stream::StreamExt as _;
2511 use std::{
2512 any::Any,
2513 cell::RefCell,
2514 path::Path,
2515 rc::Rc,
2516 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2517 time::Duration,
2518 };
2519 use util::path;
2520
2521 fn init_test(cx: &mut TestAppContext) {
2522 env_logger::try_init().ok();
2523 cx.update(|cx| {
2524 let settings_store = SettingsStore::test(cx);
2525 cx.set_global(settings_store);
2526 });
2527 }
2528
    /// An `Output` event that arrives before the terminal's `Created` event must be
    /// held by the thread. Once `Created` registers the terminal, that output must
    /// show up in the rendered terminal content.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2590
    /// Both an `Output` and an `Exit` event may arrive before `Created`. The
    /// buffered output must still be rendered once the terminal is created. Only
    /// the output is asserted; the exit status itself is not checked here.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2663
    /// Test that killing a terminal via Terminal::kill properly:
    /// 1. Causes wait_for_exit to complete (doesn't hang forever)
    /// 2. The underlying terminal still has the output that was written before the kill
    ///
    /// This test verifies that the fix to kill_active_task (which now also kills
    /// the shell process in addition to the foreground process) properly allows
    /// wait_for_exit to complete instead of hanging indefinitely.
    ///
    /// Gated to Unix: it spawns a real PTY running a shell command that uses
    /// `printf` and `sleep`.
    #[cfg(unix)]
    #[gpui::test]
    async fn test_terminal_kill_allows_wait_for_exit_to_complete(cx: &mut gpui::TestAppContext) {
        use std::collections::HashMap;
        use task::Shell;
        use util::shell_builder::ShellBuilder;

        init_test(cx);
        // NOTE(review): presumably needed because the real PTY produces events from
        // outside the deterministic test executor — confirm.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project.clone(), Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Create a real PTY terminal that runs a command which prints output then sleeps
        // We use printf instead of echo and chain with && sleep to ensure proper execution
        let (completion_tx, _completion_rx) = smol::channel::unbounded();
        let (program, args) = ShellBuilder::new(&Shell::System, false).build(
            Some("printf 'output_before_kill\\n' && sleep 60".to_owned()),
            &[],
        );

        let builder = cx
            .update(|cx| {
                ::terminal::TerminalBuilder::new(
                    None,
                    None,
                    task::Shell::WithArguments {
                        program,
                        args,
                        title_override: None,
                    },
                    HashMap::default(),
                    ::terminal::terminal_settings::CursorShape::default(),
                    ::terminal::terminal_settings::AlternateScroll::On,
                    None,
                    vec![],
                    0,
                    false,
                    0,
                    Some(completion_tx),
                    cx,
                    vec![],
                )
            })
            .await
            .unwrap();

        let lower_terminal = cx.new(|cx| builder.subscribe(cx));

        // Create the acp_thread Terminal wrapper
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "printf output_before_kill && sleep 60".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower_terminal.clone(),
                },
                cx,
            );
        });

        // Wait for the printf command to execute and produce output
        // NOTE(review): fixed delay; could be flaky on a slow machine — consider
        // polling the terminal content against a deadline instead.
        smol::Timer::after(Duration::from_millis(500)).await;

        // Get the acp_thread Terminal and kill it
        let wait_for_exit = thread.update(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            // Obtain the exit future before killing so the exit cannot be missed.
            let wait_for_exit = term.read(cx).wait_for_exit();
            term.update(cx, |term, cx| {
                term.kill(cx);
            });
            wait_for_exit
        });

        // KEY ASSERTION: wait_for_exit should complete within a reasonable time (not hang).
        // Before the fix to kill_active_task, this would hang forever because
        // only the foreground process was killed, not the shell, so the PTY
        // child never exited and wait_for_completed_task never completed.
        let exit_result = futures::select! {
            result = futures::FutureExt::fuse(wait_for_exit) => Some(result),
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(5))) => None,
        };

        assert!(
            exit_result.is_some(),
            "wait_for_exit should complete after kill, but it timed out. \
            This indicates kill_active_task is not properly killing the shell process."
        );

        // Give the system a chance to process any pending updates
        cx.run_until_parked();

        // Verify that the underlying terminal still has the output that was
        // written before the kill. This verifies that killing doesn't lose output.
        let inner_content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminals.get(&terminal_id).unwrap();
            term.read(cx).inner().read(cx).get_content()
        });

        assert!(
            inner_content.contains("output_before_kill"),
            "Underlying terminal should contain output from before kill, got: {}",
            inner_content
        );
    }
2785
    /// `push_user_content_block` behavior:
    /// - Consecutive user content blocks are merged into a single `UserMessage`.
    /// - The merged message takes the id supplied with the later chunk.
    /// - A user block pushed after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            // Still a single entry; the id from the appended chunk was adopted.
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2853
    /// Two consecutive `AgentThoughtChunk` session updates must render as one
    /// `<thinking>` section containing their concatenated text in the markdown
    /// transcript.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2914
    /// The agent reads `/tmp/foo` and then writes an extended version of it.
    /// Meanwhile, after the read, the user inserts a line into the same open
    /// buffer. Both the buffer and the file on disk must end up with the user's
    /// insertion *and* the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent has finished reading, so the user
        // edit below happens between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulated user edit, racing with the agent's pending write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2991
    /// `read_text_file` with the optional start `line` (1-based, per the assertions
    /// below) and `limit` arguments: whole file, start only, limit only, and both.
    /// A start line past the end of the file yields an invalid-params error.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
3067
    /// `read_text_file` on an empty file:
    /// - The whole file, start line 1, a limit alone, or a start-plus-limit range
    ///   all return an empty string.
    /// - Start line 5 yields an invalid-params error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }

    /// Reading `/foo`, which does not exist and lies outside the `/tmp` worktree,
    /// fails with `acp::ErrorCode::ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3173
    /// A tool call that `cancel()` marked `Canceled` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent starts an in-progress fetch tool call and ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[0] is the user message; entries[1] is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent overrides the canceled status.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3263
    /// An edit tool call that arrives already `Completed` (carrying diff content)
    /// must not be reported by `has_pending_edit_tool_calls` once the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3308
    /// Git checkpoints on user messages, checked through the markdown transcript
    /// (`## User (checkpoint)`) and the fake filesystem contents:
    /// - A user message shows a checkpoint only when its turn changed files.
    /// - Restoring the checkpoint of an earlier message truncates the thread from
    ///   that message onward and restores the files to that point.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each turn creates a new file
        // `/test/file-N`. Every turn replies with the prompt text upper-cased.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" user message.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3486
3487 #[gpui::test]
3488 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3489 use std::sync::atomic::AtomicUsize;
3490 init_test(cx);
3491
3492 let fs = FakeFs::new(cx.executor());
3493 let project = Project::test(fs, None, cx).await;
3494
3495 // Create a connection that simulates refusal after tool result
3496 let prompt_count = Arc::new(AtomicUsize::new(0));
3497 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3498 let prompt_count = prompt_count.clone();
3499 move |_request, thread, mut cx| {
3500 let count = prompt_count.fetch_add(1, SeqCst);
3501 async move {
3502 if count == 0 {
3503 // First prompt: Generate a tool call with result
3504 thread.update(&mut cx, |thread, cx| {
3505 thread
3506 .handle_session_update(
3507 acp::SessionUpdate::ToolCall(
3508 acp::ToolCall::new("tool1", "Test Tool")
3509 .kind(acp::ToolKind::Fetch)
3510 .status(acp::ToolCallStatus::Completed)
3511 .raw_input(serde_json::json!({"query": "test"}))
3512 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3513 ),
3514 cx,
3515 )
3516 .unwrap();
3517 })?;
3518
3519 // Now return refusal because of the tool result
3520 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3521 } else {
3522 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3523 }
3524 }
3525 .boxed_local()
3526 }
3527 }));
3528
3529 let thread = cx
3530 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3531 .await
3532 .unwrap();
3533
3534 // Track if we see a Refusal event
3535 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3536 let saw_refusal_event_captured = saw_refusal_event.clone();
3537 thread.update(cx, |_thread, cx| {
3538 cx.subscribe(
3539 &thread,
3540 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3541 if matches!(event, AcpThreadEvent::Refusal) {
3542 *saw_refusal_event_captured.lock().unwrap() = true;
3543 }
3544 },
3545 )
3546 .detach();
3547 });
3548
3549 // Send a user message - this will trigger tool call and then refusal
3550 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3551 cx.background_executor.spawn(send_task).detach();
3552 cx.run_until_parked();
3553
3554 // Verify that:
3555 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3556 // 2. The user message was NOT truncated
3557 assert!(
3558 *saw_refusal_event.lock().unwrap(),
3559 "Refusal event should be emitted for tool result refusals"
3560 );
3561
3562 thread.read_with(cx, |thread, _| {
3563 let entries = thread.entries();
3564 assert!(entries.len() >= 2, "Should have user message and tool call");
3565
3566 // Verify user message is still there
3567 assert!(
3568 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3569 "User message should not be truncated"
3570 );
3571
3572 // Verify tool call is there with result
3573 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3574 assert!(
3575 tool_call.raw_output.is_some(),
3576 "Tool call should have output"
3577 );
3578 } else {
3579 panic!("Expected tool call at index 1");
3580 }
3581 });
3582 }
3583
3584 #[gpui::test]
3585 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3586 init_test(cx);
3587
3588 let fs = FakeFs::new(cx.executor());
3589 let project = Project::test(fs, None, cx).await;
3590
3591 let refuse_next = Arc::new(AtomicBool::new(false));
3592 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3593 let refuse_next = refuse_next.clone();
3594 move |_request, _thread, _cx| {
3595 if refuse_next.load(SeqCst) {
3596 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3597 .boxed_local()
3598 } else {
3599 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3600 .boxed_local()
3601 }
3602 }
3603 }));
3604
3605 let thread = cx
3606 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3607 .await
3608 .unwrap();
3609
3610 // Track if we see a Refusal event
3611 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3612 let saw_refusal_event_captured = saw_refusal_event.clone();
3613 thread.update(cx, |_thread, cx| {
3614 cx.subscribe(
3615 &thread,
3616 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3617 if matches!(event, AcpThreadEvent::Refusal) {
3618 *saw_refusal_event_captured.lock().unwrap() = true;
3619 }
3620 },
3621 )
3622 .detach();
3623 });
3624
3625 // Send a message that will be refused
3626 refuse_next.store(true, SeqCst);
3627 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3628 .await
3629 .unwrap();
3630
3631 // Verify that a Refusal event WAS emitted for user prompt refusal
3632 assert!(
3633 *saw_refusal_event.lock().unwrap(),
3634 "Refusal event should be emitted for user prompt refusals"
3635 );
3636
3637 // Verify the message was truncated (user prompt refusal)
3638 thread.read_with(cx, |thread, cx| {
3639 assert_eq!(thread.to_markdown(cx), "");
3640 });
3641 }
3642
3643 #[gpui::test]
3644 async fn test_refusal(cx: &mut TestAppContext) {
3645 init_test(cx);
3646 let fs = FakeFs::new(cx.background_executor.clone());
3647 fs.insert_tree(path!("/"), json!({})).await;
3648 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3649
3650 let refuse_next = Arc::new(AtomicBool::new(false));
3651 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3652 let refuse_next = refuse_next.clone();
3653 move |request, thread, mut cx| {
3654 let refuse_next = refuse_next.clone();
3655 async move {
3656 if refuse_next.load(SeqCst) {
3657 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3658 }
3659
3660 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3661 panic!("expected text content block");
3662 };
3663 thread.update(&mut cx, |thread, cx| {
3664 thread
3665 .handle_session_update(
3666 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3667 content.text.to_uppercase().into(),
3668 )),
3669 cx,
3670 )
3671 .unwrap();
3672 })?;
3673 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3674 }
3675 .boxed_local()
3676 }
3677 }));
3678 let thread = cx
3679 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3680 .await
3681 .unwrap();
3682
3683 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3684 .await
3685 .unwrap();
3686 thread.read_with(cx, |thread, cx| {
3687 assert_eq!(
3688 thread.to_markdown(cx),
3689 indoc! {"
3690 ## User
3691
3692 hello
3693
3694 ## Assistant
3695
3696 HELLO
3697
3698 "}
3699 );
3700 });
3701
3702 // Simulate refusing the second message. The message should be truncated
3703 // when a user prompt is refused.
3704 refuse_next.store(true, SeqCst);
3705 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3706 .await
3707 .unwrap();
3708 thread.read_with(cx, |thread, cx| {
3709 assert_eq!(
3710 thread.to_markdown(cx),
3711 indoc! {"
3712 ## User
3713
3714 hello
3715
3716 ## Assistant
3717
3718 HELLO
3719
3720 "}
3721 );
3722 });
3723 }
3724
3725 async fn run_until_first_tool_call(
3726 thread: &Entity<AcpThread>,
3727 cx: &mut TestAppContext,
3728 ) -> usize {
3729 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3730
3731 let subscription = cx.update(|cx| {
3732 cx.subscribe(thread, move |thread, _, cx| {
3733 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3734 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3735 return tx.try_send(ix).unwrap();
3736 }
3737 }
3738 })
3739 });
3740
3741 select! {
3742 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3743 panic!("Timeout waiting for tool call")
3744 }
3745 ix = rx.next().fuse() => {
3746 drop(subscription);
3747 ix.unwrap()
3748 }
3749 }
3750 }
3751
3752 #[derive(Clone, Default)]
3753 struct FakeAgentConnection {
3754 auth_methods: Vec<acp::AuthMethod>,
3755 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3756 on_user_message: Option<
3757 Rc<
3758 dyn Fn(
3759 acp::PromptRequest,
3760 WeakEntity<AcpThread>,
3761 AsyncApp,
3762 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3763 + 'static,
3764 >,
3765 >,
3766 }
3767
3768 impl FakeAgentConnection {
3769 fn new() -> Self {
3770 Self {
3771 auth_methods: Vec::new(),
3772 on_user_message: None,
3773 sessions: Arc::default(),
3774 }
3775 }
3776
3777 #[expect(unused)]
3778 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3779 self.auth_methods = auth_methods;
3780 self
3781 }
3782
3783 fn on_user_message(
3784 mut self,
3785 handler: impl Fn(
3786 acp::PromptRequest,
3787 WeakEntity<AcpThread>,
3788 AsyncApp,
3789 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3790 + 'static,
3791 ) -> Self {
3792 self.on_user_message.replace(Rc::new(handler));
3793 self
3794 }
3795 }
3796
3797 impl AgentConnection for FakeAgentConnection {
3798 fn telemetry_id(&self) -> SharedString {
3799 "fake".into()
3800 }
3801
3802 fn auth_methods(&self) -> &[acp::AuthMethod] {
3803 &self.auth_methods
3804 }
3805
3806 fn new_thread(
3807 self: Rc<Self>,
3808 project: Entity<Project>,
3809 _cwd: &Path,
3810 cx: &mut App,
3811 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3812 let session_id = acp::SessionId::new(
3813 rand::rng()
3814 .sample_iter(&distr::Alphanumeric)
3815 .take(7)
3816 .map(char::from)
3817 .collect::<String>(),
3818 );
3819 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3820 let thread = cx.new(|cx| {
3821 AcpThread::new(
3822 "Test",
3823 self.clone(),
3824 project,
3825 action_log,
3826 session_id.clone(),
3827 watch::Receiver::constant(
3828 acp::PromptCapabilities::new()
3829 .image(true)
3830 .audio(true)
3831 .embedded_context(true),
3832 ),
3833 cx,
3834 )
3835 });
3836 self.sessions.lock().insert(session_id, thread.downgrade());
3837 Task::ready(Ok(thread))
3838 }
3839
3840 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3841 if self.auth_methods().iter().any(|m| m.id == method) {
3842 Task::ready(Ok(()))
3843 } else {
3844 Task::ready(Err(anyhow!("Invalid Auth Method")))
3845 }
3846 }
3847
3848 fn prompt(
3849 &self,
3850 _id: Option<UserMessageId>,
3851 params: acp::PromptRequest,
3852 cx: &mut App,
3853 ) -> Task<gpui::Result<acp::PromptResponse>> {
3854 let sessions = self.sessions.lock();
3855 let thread = sessions.get(¶ms.session_id).unwrap();
3856 if let Some(handler) = &self.on_user_message {
3857 let handler = handler.clone();
3858 let thread = thread.clone();
3859 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3860 } else {
3861 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3862 }
3863 }
3864
3865 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3866 let sessions = self.sessions.lock();
3867 let thread = sessions.get(session_id).unwrap().clone();
3868
3869 cx.spawn(async move |cx| {
3870 thread
3871 .update(cx, |thread, cx| thread.cancel(cx))
3872 .unwrap()
3873 .await
3874 })
3875 .detach();
3876 }
3877
3878 fn truncate(
3879 &self,
3880 session_id: &acp::SessionId,
3881 _cx: &App,
3882 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3883 Some(Rc::new(FakeAgentSessionEditor {
3884 _session_id: session_id.clone(),
3885 }))
3886 }
3887
3888 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3889 self
3890 }
3891 }
3892
3893 struct FakeAgentSessionEditor {
3894 _session_id: acp::SessionId,
3895 }
3896
3897 impl AgentSessionTruncate for FakeAgentSessionEditor {
3898 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3899 Task::ready(Ok(()))
3900 }
3901 }
3902
3903 #[gpui::test]
3904 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3905 init_test(cx);
3906
3907 let fs = FakeFs::new(cx.executor());
3908 let project = Project::test(fs, [], cx).await;
3909 let connection = Rc::new(FakeAgentConnection::new());
3910 let thread = cx
3911 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3912 .await
3913 .unwrap();
3914
3915 // Try to update a tool call that doesn't exist
3916 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3917 thread.update(cx, |thread, cx| {
3918 let result = thread.handle_session_update(
3919 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3920 nonexistent_id.clone(),
3921 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3922 )),
3923 cx,
3924 );
3925
3926 // The update should succeed (not return an error)
3927 assert!(result.is_ok());
3928
3929 // There should now be exactly one entry in the thread
3930 assert_eq!(thread.entries.len(), 1);
3931
3932 // The entry should be a failed tool call
3933 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3934 assert_eq!(tool_call.id, nonexistent_id);
3935 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3936 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3937
3938 // Check that the content contains the error message
3939 assert_eq!(tool_call.content.len(), 1);
3940 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3941 match content_block {
3942 ContentBlock::Markdown { markdown } => {
3943 let markdown_text = markdown.read(cx).source();
3944 assert!(markdown_text.contains("Tool call not found"));
3945 }
3946 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3947 ContentBlock::ResourceLink { .. } => {
3948 panic!("Expected markdown content, got resource link")
3949 }
3950 ContentBlock::Image { .. } => {
3951 panic!("Expected markdown content, got image")
3952 }
3953 }
3954 } else {
3955 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3956 }
3957 } else {
3958 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3959 }
3960 });
3961 }
3962
3963 /// Tests that restoring a checkpoint properly cleans up terminals that were
3964 /// created after that checkpoint, and cancels any in-progress generation.
3965 ///
3966 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3967 /// that were started after that checkpoint should be terminated, and any in-progress
3968 /// AI generation should be canceled.
3969 #[gpui::test]
3970 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3971 init_test(cx);
3972
3973 let fs = FakeFs::new(cx.executor());
3974 let project = Project::test(fs, [], cx).await;
3975 let connection = Rc::new(FakeAgentConnection::new());
3976 let thread = cx
3977 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3978 .await
3979 .unwrap();
3980
3981 // Send first user message to create a checkpoint
3982 cx.update(|cx| {
3983 thread.update(cx, |thread, cx| {
3984 thread.send(vec!["first message".into()], cx)
3985 })
3986 })
3987 .await
3988 .unwrap();
3989
3990 // Send second message (creates another checkpoint) - we'll restore to this one
3991 cx.update(|cx| {
3992 thread.update(cx, |thread, cx| {
3993 thread.send(vec!["second message".into()], cx)
3994 })
3995 })
3996 .await
3997 .unwrap();
3998
3999 // Create 2 terminals BEFORE the checkpoint that have completed running
4000 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4001 let mock_terminal_1 = cx.new(|cx| {
4002 let builder = ::terminal::TerminalBuilder::new_display_only(
4003 ::terminal::terminal_settings::CursorShape::default(),
4004 ::terminal::terminal_settings::AlternateScroll::On,
4005 None,
4006 0,
4007 )
4008 .unwrap();
4009 builder.subscribe(cx)
4010 });
4011
4012 thread.update(cx, |thread, cx| {
4013 thread.on_terminal_provider_event(
4014 TerminalProviderEvent::Created {
4015 terminal_id: terminal_id_1.clone(),
4016 label: "echo 'first'".to_string(),
4017 cwd: Some(PathBuf::from("/test")),
4018 output_byte_limit: None,
4019 terminal: mock_terminal_1.clone(),
4020 },
4021 cx,
4022 );
4023 });
4024
4025 thread.update(cx, |thread, cx| {
4026 thread.on_terminal_provider_event(
4027 TerminalProviderEvent::Output {
4028 terminal_id: terminal_id_1.clone(),
4029 data: b"first\n".to_vec(),
4030 },
4031 cx,
4032 );
4033 });
4034
4035 thread.update(cx, |thread, cx| {
4036 thread.on_terminal_provider_event(
4037 TerminalProviderEvent::Exit {
4038 terminal_id: terminal_id_1.clone(),
4039 status: acp::TerminalExitStatus::new().exit_code(0),
4040 },
4041 cx,
4042 );
4043 });
4044
4045 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4046 let mock_terminal_2 = cx.new(|cx| {
4047 let builder = ::terminal::TerminalBuilder::new_display_only(
4048 ::terminal::terminal_settings::CursorShape::default(),
4049 ::terminal::terminal_settings::AlternateScroll::On,
4050 None,
4051 0,
4052 )
4053 .unwrap();
4054 builder.subscribe(cx)
4055 });
4056
4057 thread.update(cx, |thread, cx| {
4058 thread.on_terminal_provider_event(
4059 TerminalProviderEvent::Created {
4060 terminal_id: terminal_id_2.clone(),
4061 label: "echo 'second'".to_string(),
4062 cwd: Some(PathBuf::from("/test")),
4063 output_byte_limit: None,
4064 terminal: mock_terminal_2.clone(),
4065 },
4066 cx,
4067 );
4068 });
4069
4070 thread.update(cx, |thread, cx| {
4071 thread.on_terminal_provider_event(
4072 TerminalProviderEvent::Output {
4073 terminal_id: terminal_id_2.clone(),
4074 data: b"second\n".to_vec(),
4075 },
4076 cx,
4077 );
4078 });
4079
4080 thread.update(cx, |thread, cx| {
4081 thread.on_terminal_provider_event(
4082 TerminalProviderEvent::Exit {
4083 terminal_id: terminal_id_2.clone(),
4084 status: acp::TerminalExitStatus::new().exit_code(0),
4085 },
4086 cx,
4087 );
4088 });
4089
4090 // Get the second message ID to restore to
4091 let second_message_id = thread.read_with(cx, |thread, _| {
4092 // At this point we have:
4093 // - Index 0: First user message (with checkpoint)
4094 // - Index 1: Second user message (with checkpoint)
4095 // No assistant responses because FakeAgentConnection just returns EndTurn
4096 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
4097 panic!("expected user message at index 1");
4098 };
4099 message.id.clone().unwrap()
4100 });
4101
4102 // Create a terminal AFTER the checkpoint we'll restore to.
4103 // This simulates the AI agent starting a long-running terminal command.
4104 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
4105 let mock_terminal = cx.new(|cx| {
4106 let builder = ::terminal::TerminalBuilder::new_display_only(
4107 ::terminal::terminal_settings::CursorShape::default(),
4108 ::terminal::terminal_settings::AlternateScroll::On,
4109 None,
4110 0,
4111 )
4112 .unwrap();
4113 builder.subscribe(cx)
4114 });
4115
4116 // Register the terminal as created
4117 thread.update(cx, |thread, cx| {
4118 thread.on_terminal_provider_event(
4119 TerminalProviderEvent::Created {
4120 terminal_id: terminal_id.clone(),
4121 label: "sleep 1000".to_string(),
4122 cwd: Some(PathBuf::from("/test")),
4123 output_byte_limit: None,
4124 terminal: mock_terminal.clone(),
4125 },
4126 cx,
4127 );
4128 });
4129
4130 // Simulate the terminal producing output (still running)
4131 thread.update(cx, |thread, cx| {
4132 thread.on_terminal_provider_event(
4133 TerminalProviderEvent::Output {
4134 terminal_id: terminal_id.clone(),
4135 data: b"terminal is running...\n".to_vec(),
4136 },
4137 cx,
4138 );
4139 });
4140
4141 // Create a tool call entry that references this terminal
4142 // This represents the agent requesting a terminal command
4143 thread.update(cx, |thread, cx| {
4144 thread
4145 .handle_session_update(
4146 acp::SessionUpdate::ToolCall(
4147 acp::ToolCall::new("terminal-tool-1", "Running command")
4148 .kind(acp::ToolKind::Execute)
4149 .status(acp::ToolCallStatus::InProgress)
4150 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
4151 terminal_id.clone(),
4152 ))])
4153 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
4154 ),
4155 cx,
4156 )
4157 .unwrap();
4158 });
4159
4160 // Verify terminal exists and is in the thread
4161 let terminal_exists_before =
4162 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4163 assert!(
4164 terminal_exists_before,
4165 "Terminal should exist before checkpoint restore"
4166 );
4167
4168 // Verify the terminal's underlying task is still running (not completed)
4169 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
4170 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
4171 terminal_entity.read_with(cx, |term, _cx| {
4172 term.output().is_none() // output is None means it's still running
4173 })
4174 });
4175 assert!(
4176 terminal_running_before,
4177 "Terminal should be running before checkpoint restore"
4178 );
4179
4180 // Verify we have the expected entries before restore
4181 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4182 assert!(
4183 entry_count_before > 1,
4184 "Should have multiple entries before restore"
4185 );
4186
4187 // Restore the checkpoint to the second message.
4188 // This should:
4189 // 1. Cancel any in-progress generation (via the cancel() call)
4190 // 2. Remove the terminal that was created after that point
4191 thread
4192 .update(cx, |thread, cx| {
4193 thread.restore_checkpoint(second_message_id, cx)
4194 })
4195 .await
4196 .unwrap();
4197
4198 // Verify that no send_task is in progress after restore
4199 // (cancel() clears the send_task)
4200 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4201 assert!(
4202 !has_send_task_after,
4203 "Should not have a send_task after restore (cancel should have cleared it)"
4204 );
4205
4206 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4207 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4208 assert_eq!(
4209 entry_count, 1,
4210 "Should have 1 entry after restore (only the first user message)"
4211 );
4212
4213 // Verify the 2 completed terminals from before the checkpoint still exist
4214 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4215 thread.terminals.contains_key(&terminal_id_1)
4216 });
4217 assert!(
4218 terminal_1_exists,
4219 "Terminal 1 (from before checkpoint) should still exist"
4220 );
4221
4222 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4223 thread.terminals.contains_key(&terminal_id_2)
4224 });
4225 assert!(
4226 terminal_2_exists,
4227 "Terminal 2 (from before checkpoint) should still exist"
4228 );
4229
4230 // Verify they're still in completed state
4231 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4232 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4233 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4234 });
4235 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4236
4237 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4238 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4239 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4240 });
4241 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4242
4243 // Verify the running terminal (created after checkpoint) was removed
4244 let terminal_3_exists =
4245 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4246 assert!(
4247 !terminal_3_exists,
4248 "Terminal 3 (created after checkpoint) should have been removed"
4249 );
4250
4251 // Verify total count is 2 (the two from before the checkpoint)
4252 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4253 assert_eq!(
4254 terminal_count, 2,
4255 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4256 );
4257 }
4258
4259 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4260 /// even when a new user message is added while the async checkpoint comparison is in progress.
4261 ///
4262 /// This is a regression test for a bug where update_last_checkpoint would fail with
4263 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4264 /// update_last_checkpoint started and when its async closure ran.
4265 #[gpui::test]
4266 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4267 init_test(cx);
4268
4269 let fs = FakeFs::new(cx.executor());
4270 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4271 .await;
4272 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4273
4274 let handler_done = Arc::new(AtomicBool::new(false));
4275 let handler_done_clone = handler_done.clone();
4276 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4277 move |_, _thread, _cx| {
4278 handler_done_clone.store(true, SeqCst);
4279 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4280 },
4281 ));
4282
4283 let thread = cx
4284 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
4285 .await
4286 .unwrap();
4287
4288 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4289 let send_task = cx.background_executor.spawn(send_future);
4290
4291 // Tick until handler completes, then a few more to let update_last_checkpoint start
4292 while !handler_done.load(SeqCst) {
4293 cx.executor().tick();
4294 }
4295 for _ in 0..5 {
4296 cx.executor().tick();
4297 }
4298
4299 thread.update(cx, |thread, cx| {
4300 thread.push_entry(
4301 AgentThreadEntry::UserMessage(UserMessage {
4302 id: Some(UserMessageId::new()),
4303 content: ContentBlock::Empty,
4304 chunks: vec!["Injected message (no checkpoint)".into()],
4305 checkpoint: None,
4306 indented: false,
4307 }),
4308 cx,
4309 );
4310 });
4311
4312 cx.run_until_parked();
4313 let result = send_task.await;
4314
4315 assert!(
4316 result.is_ok(),
4317 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4318 result.err()
4319 );
4320 }
4321}