1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15use task::{Shell, ShellBuilder};
16pub use terminal::*;
17
18use action_log::{ActionLog, ActionLogTelemetry};
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
38use uuid::Uuid;
39
/// A user-authored entry in an [`AcpThread`], built up from one or more streamed chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if known. A later chunk that carries an id fills it in.
    pub id: Option<UserMessageId>,
    /// Rendered content accumulated from every chunk appended so far.
    pub content: ContentBlock,
    /// The raw ACP content blocks received for this message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether the message is shown indented; new chunks only merge into a trailing
    /// user message with the same indentation.
    pub indented: bool,
}
48
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when true, the markdown export labels the
    /// message heading as `## User (checkpoint)`.
    pub show: bool,
}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
/// An agent-authored entry in an [`AcpThread`], made of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; consecutive chunks of the same kind are merged.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Whether the message is shown indented; new chunks only merge into a trailing
    /// assistant message with the same indentation.
    pub indented: bool,
}
79
80impl AssistantMessage {
81 pub fn to_markdown(&self, cx: &App) -> String {
82 format!(
83 "## Assistant\n\n{}\n\n",
84 self.chunks
85 .iter()
86 .map(|chunk| chunk.to_markdown(cx))
87 .join("\n\n")
88 )
89 }
90}
91
/// One run of content within an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular response text.
    Message { block: ContentBlock },
    /// Agent "thinking" output; exported inside `<thinking>` tags by `to_markdown`.
    Thought { block: ContentBlock },
}
97
98impl AssistantMessageChunk {
99 pub fn from_str(
100 chunk: &str,
101 language_registry: &Arc<LanguageRegistry>,
102 path_style: PathStyle,
103 cx: &mut App,
104 ) -> Self {
105 Self::Message {
106 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
107 }
108 }
109
110 fn to_markdown(&self, cx: &App) -> String {
111 match self {
112 Self::Message { block } => block.to_markdown(cx).to_string(),
113 Self::Thought { block } => {
114 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
115 }
116 }
117 }
118}
119
/// A single entry in an [`AcpThread`]'s transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message (and/or thoughts) from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
126
127impl AgentThreadEntry {
128 pub fn is_indented(&self) -> bool {
129 match self {
130 Self::UserMessage(message) => message.indented,
131 Self::AssistantMessage(message) => message.indented,
132 Self::ToolCall(_) => false,
133 }
134 }
135
136 pub fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
141 }
142 }
143
144 pub fn user_message(&self) -> Option<&UserMessage> {
145 if let AgentThreadEntry::UserMessage(message) = self {
146 Some(message)
147 } else {
148 None
149 }
150 }
151
152 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
153 if let AgentThreadEntry::ToolCall(call) = self {
154 itertools::Either::Left(call.diffs())
155 } else {
156 itertools::Either::Right(std::iter::empty())
157 }
158 }
159
160 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
161 if let AgentThreadEntry::ToolCall(call) = self {
162 itertools::Either::Left(call.terminals())
163 } else {
164 itertools::Either::Right(std::iter::empty())
165 }
166 }
167
168 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
169 if let AgentThreadEntry::ToolCall(ToolCall {
170 locations,
171 resolved_locations,
172 ..
173 }) = self
174 {
175 Some((
176 locations.get(ix)?.clone(),
177 resolved_locations.get(ix)?.clone()?,
178 ))
179 } else {
180 None
181 }
182 }
183}
184
/// A tool invocation reported by the agent, together with its rendered state.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Rendered title; multi-line titles are cut to their first line plus "…".
    pub label: Entity<Markdown>,
    /// The kind of tool, as reported by the agent.
    pub kind: acp::ToolKind,
    /// Rendered output (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved forms of `locations`, indexed in parallel with it; `None` marks a
    /// location that did not resolve.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw tool input as sent by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw tool output as sent by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
197
198impl ToolCall {
199 fn from_acp(
200 tool_call: acp::ToolCall,
201 status: ToolCallStatus,
202 language_registry: Arc<LanguageRegistry>,
203 path_style: PathStyle,
204 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
205 cx: &mut App,
206 ) -> Result<Self> {
207 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
208 first_line.to_owned() + "…"
209 } else {
210 tool_call.title
211 };
212 let mut content = Vec::with_capacity(tool_call.content.len());
213 for item in tool_call.content {
214 if let Some(item) = ToolCallContent::from_acp(
215 item,
216 language_registry.clone(),
217 path_style,
218 terminals,
219 cx,
220 )? {
221 content.push(item);
222 }
223 }
224
225 let result = Self {
226 id: tool_call.tool_call_id,
227 label: cx
228 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
229 kind: tool_call.kind,
230 content,
231 locations: tool_call.locations,
232 resolved_locations: Vec::default(),
233 status,
234 raw_input: tool_call.raw_input,
235 raw_output: tool_call.raw_output,
236 };
237 Ok(result)
238 }
239
240 fn update_fields(
241 &mut self,
242 fields: acp::ToolCallUpdateFields,
243 language_registry: Arc<LanguageRegistry>,
244 path_style: PathStyle,
245 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
246 cx: &mut App,
247 ) -> Result<()> {
248 let acp::ToolCallUpdateFields {
249 kind,
250 status,
251 title,
252 content,
253 locations,
254 raw_input,
255 raw_output,
256 ..
257 } = fields;
258
259 if let Some(kind) = kind {
260 self.kind = kind;
261 }
262
263 if let Some(status) = status {
264 self.status = status.into();
265 }
266
267 if let Some(title) = title {
268 self.label.update(cx, |label, cx| {
269 if let Some((first_line, _)) = title.split_once("\n") {
270 label.replace(first_line.to_owned() + "…", cx)
271 } else {
272 label.replace(title, cx);
273 }
274 });
275 }
276
277 if let Some(content) = content {
278 let mut new_content_len = content.len();
279 let mut content = content.into_iter();
280
281 // Reuse existing content if we can
282 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
283 let valid_content =
284 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
285 if !valid_content {
286 new_content_len -= 1;
287 }
288 }
289 for new in content {
290 if let Some(new) = ToolCallContent::from_acp(
291 new,
292 language_registry.clone(),
293 path_style,
294 terminals,
295 cx,
296 )? {
297 self.content.push(new);
298 } else {
299 new_content_len -= 1;
300 }
301 }
302 self.content.truncate(new_content_len);
303 }
304
305 if let Some(locations) = locations {
306 self.locations = locations;
307 }
308
309 if let Some(raw_input) = raw_input {
310 self.raw_input = Some(raw_input);
311 }
312
313 if let Some(raw_output) = raw_output {
314 if self.content.is_empty()
315 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
316 {
317 self.content
318 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
319 markdown,
320 }));
321 }
322 self.raw_output = Some(raw_output);
323 }
324 Ok(())
325 }
326
327 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
328 self.content.iter().filter_map(|content| match content {
329 ToolCallContent::Diff(diff) => Some(diff),
330 ToolCallContent::ContentBlock(_) => None,
331 ToolCallContent::Terminal(_) => None,
332 })
333 }
334
335 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
336 self.content.iter().filter_map(|content| match content {
337 ToolCallContent::Terminal(terminal) => Some(terminal),
338 ToolCallContent::ContentBlock(_) => None,
339 ToolCallContent::Diff(_) => None,
340 })
341 }
342
343 fn to_markdown(&self, cx: &App) -> String {
344 let mut markdown = format!(
345 "**Tool Call: {}**\nStatus: {}\n\n",
346 self.label.read(cx).source(),
347 self.status
348 );
349 for content in &self.content {
350 markdown.push_str(content.to_markdown(cx).as_str());
351 markdown.push_str("\n\n");
352 }
353 markdown
354 }
355
356 async fn resolve_location(
357 location: acp::ToolCallLocation,
358 project: WeakEntity<Project>,
359 cx: &mut AsyncApp,
360 ) -> Option<ResolvedLocation> {
361 let buffer = project
362 .update(cx, |project, cx| {
363 project
364 .project_path_for_absolute_path(&location.path, cx)
365 .map(|path| project.open_buffer(path, cx))
366 })
367 .ok()??;
368 let buffer = buffer.await.log_err()?;
369 let position = buffer
370 .update(cx, |buffer, _| {
371 let snapshot = buffer.snapshot();
372 if let Some(row) = location.line {
373 let column = snapshot.indent_size_for_line(row).len;
374 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
375 snapshot.anchor_before(point)
376 } else {
377 Anchor::min_for_buffer(snapshot.remote_id())
378 }
379 })
380 .ok()?;
381
382 Some(ResolvedLocation { buffer, position })
383 }
384
385 fn resolve_locations(
386 &self,
387 project: Entity<Project>,
388 cx: &mut App,
389 ) -> Task<Vec<Option<ResolvedLocation>>> {
390 let locations = self.locations.clone();
391 project.update(cx, |_, cx| {
392 cx.spawn(async move |project, cx| {
393 let mut new_locations = Vec::new();
394 for location in locations {
395 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
396 }
397 new_locations
398 })
399 })
400 }
401}
402
/// A tool-call location resolved to a position in an open buffer.
///
/// Kept separate from [`AgentLocation`], which the `From` impl converts to with a weak
/// (downgraded) buffer handle. This type holds a strong reference so the buffer stays
/// alive for saving on the thread.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer containing the location.
    buffer: Entity<Buffer>,
    /// Anchor for the resolved position within `buffer`.
    position: Anchor,
}
410
411impl From<&ResolvedLocation> for AgentLocation {
412 fn from(value: &ResolvedLocation) -> Self {
413 Self {
414 buffer: value.buffer.downgrade(),
415 position: value.position,
416 }
417 }
418}
419
/// Lifecycle state of a [`ToolCall`] as tracked by this client.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// Sender for the `PermissionOptionId` that answers the request.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
441
442impl From<acp::ToolCallStatus> for ToolCallStatus {
443 fn from(status: acp::ToolCallStatus) -> Self {
444 match status {
445 acp::ToolCallStatus::Pending => Self::Pending,
446 acp::ToolCallStatus::InProgress => Self::InProgress,
447 acp::ToolCallStatus::Completed => Self::Completed,
448 acp::ToolCallStatus::Failed => Self::Failed,
449 _ => Self::Pending,
450 }
451 }
452}
453
454impl Display for ToolCallStatus {
455 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
456 write!(
457 f,
458 "{}",
459 match self {
460 ToolCallStatus::Pending => "Pending",
461 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
462 ToolCallStatus::InProgress => "In Progress",
463 ToolCallStatus::Completed => "Completed",
464 ToolCallStatus::Failed => "Failed",
465 ToolCallStatus::Rejected => "Rejected",
466 ToolCallStatus::Canceled => "Canceled",
467 }
468 )
469 }
470}
471
/// Renderable content of a message or tool call, built up by appending ACP blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered as markdown; further blocks are appended to its source.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link that has not been merged with other content.
    ResourceLink { resource_link: acp::ResourceLink },
}
478
479impl ContentBlock {
480 pub fn new(
481 block: acp::ContentBlock,
482 language_registry: &Arc<LanguageRegistry>,
483 path_style: PathStyle,
484 cx: &mut App,
485 ) -> Self {
486 let mut this = Self::Empty;
487 this.append(block, language_registry, path_style, cx);
488 this
489 }
490
491 pub fn new_combined(
492 blocks: impl IntoIterator<Item = acp::ContentBlock>,
493 language_registry: Arc<LanguageRegistry>,
494 path_style: PathStyle,
495 cx: &mut App,
496 ) -> Self {
497 let mut this = Self::Empty;
498 for block in blocks {
499 this.append(block, &language_registry, path_style, cx);
500 }
501 this
502 }
503
504 pub fn append(
505 &mut self,
506 block: acp::ContentBlock,
507 language_registry: &Arc<LanguageRegistry>,
508 path_style: PathStyle,
509 cx: &mut App,
510 ) {
511 if matches!(self, ContentBlock::Empty)
512 && let acp::ContentBlock::ResourceLink(resource_link) = block
513 {
514 *self = ContentBlock::ResourceLink { resource_link };
515 return;
516 }
517
518 let new_content = self.block_string_contents(block, path_style);
519
520 match self {
521 ContentBlock::Empty => {
522 *self = Self::create_markdown_block(new_content, language_registry, cx);
523 }
524 ContentBlock::Markdown { markdown } => {
525 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
526 }
527 ContentBlock::ResourceLink { resource_link } => {
528 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
529 let combined = format!("{}\n{}", existing_content, new_content);
530
531 *self = Self::create_markdown_block(combined, language_registry, cx);
532 }
533 }
534 }
535
536 fn create_markdown_block(
537 content: String,
538 language_registry: &Arc<LanguageRegistry>,
539 cx: &mut App,
540 ) -> ContentBlock {
541 ContentBlock::Markdown {
542 markdown: cx
543 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
544 }
545 }
546
547 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
548 match block {
549 acp::ContentBlock::Text(text_content) => text_content.text,
550 acp::ContentBlock::ResourceLink(resource_link) => {
551 Self::resource_link_md(&resource_link.uri, path_style)
552 }
553 acp::ContentBlock::Resource(acp::EmbeddedResource {
554 resource:
555 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
556 uri,
557 ..
558 }),
559 ..
560 }) => Self::resource_link_md(&uri, path_style),
561 acp::ContentBlock::Image(image) => Self::image_md(&image),
562 _ => String::new(),
563 }
564 }
565
566 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
567 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
568 uri.as_link().to_string()
569 } else {
570 uri.to_string()
571 }
572 }
573
574 fn image_md(_image: &acp::ImageContent) -> String {
575 "`Image`".into()
576 }
577
578 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
579 match self {
580 ContentBlock::Empty => "",
581 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
582 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
583 }
584 }
585
586 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
587 match self {
588 ContentBlock::Empty => None,
589 ContentBlock::Markdown { markdown } => Some(markdown),
590 ContentBlock::ResourceLink { .. } => None,
591 }
592 }
593
594 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
595 match self {
596 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
597 _ => None,
598 }
599 }
600}
601
/// One rendered item of a [`ToolCall`]'s output.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, image, or resource content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal, looked up by id among the thread's terminals.
    Terminal(Entity<Terminal>),
}
608
609impl ToolCallContent {
610 pub fn from_acp(
611 content: acp::ToolCallContent,
612 language_registry: Arc<LanguageRegistry>,
613 path_style: PathStyle,
614 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
615 cx: &mut App,
616 ) -> Result<Option<Self>> {
617 match content {
618 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
619 Ok(Some(Self::ContentBlock(ContentBlock::new(
620 content,
621 &language_registry,
622 path_style,
623 cx,
624 ))))
625 }
626 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
627 Diff::finalized(
628 diff.path.to_string_lossy().into_owned(),
629 diff.old_text,
630 diff.new_text,
631 language_registry,
632 cx,
633 )
634 })))),
635 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
636 .get(&terminal_id)
637 .cloned()
638 .map(|terminal| Some(Self::Terminal(terminal)))
639 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
640 _ => Ok(None),
641 }
642 }
643
644 pub fn update_from_acp(
645 &mut self,
646 new: acp::ToolCallContent,
647 language_registry: Arc<LanguageRegistry>,
648 path_style: PathStyle,
649 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
650 cx: &mut App,
651 ) -> Result<bool> {
652 let needs_update = match (&self, &new) {
653 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
654 old_diff.read(cx).needs_update(
655 new_diff.old_text.as_deref().unwrap_or(""),
656 &new_diff.new_text,
657 cx,
658 )
659 }
660 _ => true,
661 };
662
663 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
664 if needs_update {
665 *self = update;
666 }
667 Ok(true)
668 } else {
669 Ok(false)
670 }
671 }
672
673 pub fn to_markdown(&self, cx: &App) -> String {
674 match self {
675 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
676 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
677 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
678 }
679 }
680}
681
/// An update targeting an existing tool call, identified by its tool call id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field changes received from the agent over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Attach a locally created diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// Attach a locally created terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
688
689impl ToolCallUpdate {
690 fn id(&self) -> &acp::ToolCallId {
691 match self {
692 Self::UpdateFields(update) => &update.tool_call_id,
693 Self::UpdateDiff(diff) => &diff.id,
694 Self::UpdateTerminal(terminal) => &terminal.id,
695 }
696 }
697}
698
699impl From<acp::ToolCallUpdate> for ToolCallUpdate {
700 fn from(update: acp::ToolCallUpdate) -> Self {
701 Self::UpdateFields(update)
702 }
703}
704
705impl From<ToolCallUpdateDiff> for ToolCallUpdate {
706 fn from(diff: ToolCallUpdateDiff) -> Self {
707 Self::UpdateDiff(diff)
708 }
709}
710
/// Attaches a diff entity to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The diff to attach.
    pub diff: Entity<Diff>,
}
716
717impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
718 fn from(terminal: ToolCallUpdateTerminal) -> Self {
719 Self::UpdateTerminal(terminal)
720 }
721}
722
/// Attaches a terminal entity to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// The terminal to attach.
    pub terminal: Entity<Terminal>,
}
728
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in the order the agent reported them.
    pub entries: Vec<PlanEntry>,
}
733
/// Summary counts for a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
740
741impl Plan {
742 pub fn is_empty(&self) -> bool {
743 self.entries.is_empty()
744 }
745
746 pub fn stats(&self) -> PlanStats<'_> {
747 let mut stats = PlanStats {
748 in_progress_entry: None,
749 pending: 0,
750 completed: 0,
751 };
752
753 for entry in &self.entries {
754 match &entry.status {
755 acp::PlanEntryStatus::Pending => {
756 stats.pending += 1;
757 }
758 acp::PlanEntryStatus::InProgress => {
759 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
760 }
761 acp::PlanEntryStatus::Completed => {
762 stats.completed += 1;
763 }
764 _ => {}
765 }
766 }
767
768 stats
769 }
770}
771
/// A single step of the agent's [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The step's text, rendered as markdown (created without a language registry).
    pub content: Entity<Markdown>,
    /// Priority as reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported by the agent.
    pub status: acp::PlanEntryStatus,
}
778
779impl PlanEntry {
780 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
781 Self {
782 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
783 priority: entry.priority,
784 status: entry.status,
785 }
786 }
787}
788
/// Token consumption of a thread relative to the model's context window.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Context window size; 0 means the maximum is unknown (see `ratio`).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
794
795impl TokenUsage {
796 pub fn ratio(&self) -> TokenUsageRatio {
797 #[cfg(debug_assertions)]
798 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
799 .unwrap_or("0.8".to_string())
800 .parse()
801 .unwrap();
802 #[cfg(not(debug_assertions))]
803 let warning_threshold: f32 = 0.8;
804
805 // When the maximum is unknown because there is no selected model,
806 // avoid showing the token limit warning.
807 if self.max_tokens == 0 {
808 TokenUsageRatio::Normal
809 } else if self.used_tokens >= self.max_tokens {
810 TokenUsageRatio::Exceeded
811 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
812 TokenUsageRatio::Warning
813 } else {
814 TokenUsageRatio::Normal
815 }
816 }
817}
818
/// Coarse classification of [`TokenUsage`], as returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown.
    Normal,
    /// At or above the warning threshold but below the maximum.
    Warning,
    /// Used tokens have reached or passed the maximum.
    Exceeded,
}
825
/// Payload of `AcpThreadEvent::Retry`.
///
/// NOTE(review): field meanings below follow their names; they are not established in
/// this part of the file.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number, out of `max_attempts`.
    pub attempt: usize,
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    /// Presumably the delay before the retry is attempted — confirm with the emitter.
    pub duration: Duration,
}
834
/// A conversation with an agent over ACP: transcript entries, plan, terminals, and
/// session metadata.
pub struct AcpThread {
    /// Thread title, returned by `title()`.
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// Project the thread operates in; supplies the language registry and path style.
    project: Entity<Project>,
    /// Action log for the thread.
    action_log: Entity<ActionLog>,
    /// Buffers shared with the agent, with snapshots (NOTE(review): usage not visible here).
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; while `Some`, `status()` reports `Generating`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// ACP session identifier.
    session_id: acp::SessionId,
    /// Latest reported token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the watch channel.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that mirrors prompt-capability changes into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Registered terminals by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals not yet registered; replayed on creation.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals not yet registered.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
852
853impl From<&AcpThread> for ActionLogTelemetry {
854 fn from(value: &AcpThread) -> Self {
855 Self {
856 agent_telemetry_id: value.connection().telemetry_id(),
857 session_id: value.session_id.0.clone(),
858 }
859 }
860}
861
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed in place.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Prompt capabilities changed (emitted by the observer task created in `new`).
    PromptCapabilitiesUpdated,
    Refusal,
    /// The agent sent an updated list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The agent switched to the given session mode.
    ModeUpdated(acp::SessionModeId),
}
879
/// Lets `AcpThread` entities emit [`AcpThreadEvent`]s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
881
/// Events from an external terminal provider, handled by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered with the thread; any output or
    /// exit buffered for its id is then replayed.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title for a terminal, shown as its breadcrumb text; dropped if the terminal
    /// is unknown.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// A terminal exited; the status is buffered if the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
904
/// Commands addressed to a terminal of an external terminal provider, identified by
/// `terminal_id`.
///
/// NOTE(review): no consumer is visible in this part of the file; variant descriptions
/// follow the field names — confirm against the provider implementation.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Send `bytes` as terminal input.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to `cols` x `rows`.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
920
921impl AcpThread {
922 pub fn on_terminal_provider_event(
923 &mut self,
924 event: TerminalProviderEvent,
925 cx: &mut Context<Self>,
926 ) {
927 match event {
928 TerminalProviderEvent::Created {
929 terminal_id,
930 label,
931 cwd,
932 output_byte_limit,
933 terminal,
934 } => {
935 let entity = self.register_terminal_created(
936 terminal_id.clone(),
937 label,
938 cwd,
939 output_byte_limit,
940 terminal,
941 cx,
942 );
943
944 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
945 for data in chunks.drain(..) {
946 entity.update(cx, |term, cx| {
947 term.inner().update(cx, |inner, cx| {
948 inner.write_output(&data, cx);
949 })
950 });
951 }
952 }
953
954 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
955 entity.update(cx, |_term, cx| {
956 cx.notify();
957 });
958 }
959
960 cx.notify();
961 }
962 TerminalProviderEvent::Output { terminal_id, data } => {
963 if let Some(entity) = self.terminals.get(&terminal_id) {
964 entity.update(cx, |term, cx| {
965 term.inner().update(cx, |inner, cx| {
966 inner.write_output(&data, cx);
967 })
968 });
969 } else {
970 self.pending_terminal_output
971 .entry(terminal_id)
972 .or_default()
973 .push(data);
974 }
975 }
976 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
977 if let Some(entity) = self.terminals.get(&terminal_id) {
978 entity.update(cx, |term, cx| {
979 term.inner().update(cx, |inner, cx| {
980 inner.breadcrumb_text = title;
981 cx.emit(::terminal::Event::BreadcrumbsChanged);
982 })
983 });
984 }
985 }
986 TerminalProviderEvent::Exit {
987 terminal_id,
988 status,
989 } => {
990 if let Some(entity) = self.terminals.get(&terminal_id) {
991 entity.update(cx, |_term, cx| {
992 cx.notify();
993 });
994 } else {
995 self.pending_terminal_exit.insert(terminal_id, status);
996 }
997 }
998 }
999 }
1000}
1001
/// Whether a thread is currently generating, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in flight.
    Generating,
}
1007
/// Reasons an agent could not be loaded; user-facing messages come from `Display`.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent's version is older than the minimum supported one.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The agent server process exited with `status`.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by the message.
    Other(SharedString),
}
1021
1022impl Display for LoadError {
1023 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1024 match self {
1025 LoadError::Unsupported {
1026 command: path,
1027 current_version,
1028 minimum_version,
1029 } => {
1030 write!(
1031 f,
1032 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1033 )
1034 }
1035 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1036 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1037 LoadError::Other(msg) => write!(f, "{msg}"),
1038 }
1039 }
1040}
1041
/// `LoadError` is a standard error; its message comes from the `Display` impl.
impl Error for LoadError {}
1043
1044impl AcpThread {
1045 pub fn new(
1046 title: impl Into<SharedString>,
1047 connection: Rc<dyn AgentConnection>,
1048 project: Entity<Project>,
1049 action_log: Entity<ActionLog>,
1050 session_id: acp::SessionId,
1051 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1052 cx: &mut Context<Self>,
1053 ) -> Self {
1054 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1055 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1056 loop {
1057 let caps = prompt_capabilities_rx.recv().await?;
1058 this.update(cx, |this, cx| {
1059 this.prompt_capabilities = caps;
1060 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1061 })?;
1062 }
1063 });
1064
1065 Self {
1066 action_log,
1067 shared_buffers: Default::default(),
1068 entries: Default::default(),
1069 plan: Default::default(),
1070 title: title.into(),
1071 project,
1072 send_task: None,
1073 connection,
1074 session_id,
1075 token_usage: None,
1076 prompt_capabilities,
1077 _observe_prompt_capabilities: task,
1078 terminals: HashMap::default(),
1079 pending_terminal_output: HashMap::default(),
1080 pending_terminal_exit: HashMap::default(),
1081 }
1082 }
1083
1084 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1085 self.prompt_capabilities.clone()
1086 }
1087
1088 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1089 &self.connection
1090 }
1091
1092 pub fn action_log(&self) -> &Entity<ActionLog> {
1093 &self.action_log
1094 }
1095
1096 pub fn project(&self) -> &Entity<Project> {
1097 &self.project
1098 }
1099
1100 pub fn title(&self) -> SharedString {
1101 self.title.clone()
1102 }
1103
1104 pub fn entries(&self) -> &[AgentThreadEntry] {
1105 &self.entries
1106 }
1107
1108 pub fn session_id(&self) -> &acp::SessionId {
1109 &self.session_id
1110 }
1111
1112 pub fn status(&self) -> ThreadStatus {
1113 if self.send_task.is_some() {
1114 ThreadStatus::Generating
1115 } else {
1116 ThreadStatus::Idle
1117 }
1118 }
1119
1120 pub fn token_usage(&self) -> Option<&TokenUsage> {
1121 self.token_usage.as_ref()
1122 }
1123
1124 pub fn has_pending_edit_tool_calls(&self) -> bool {
1125 for entry in self.entries.iter().rev() {
1126 match entry {
1127 AgentThreadEntry::UserMessage(_) => return false,
1128 AgentThreadEntry::ToolCall(
1129 call @ ToolCall {
1130 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1131 ..
1132 },
1133 ) if call.diffs().next().is_some() => {
1134 return true;
1135 }
1136 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1137 }
1138 }
1139
1140 false
1141 }
1142
1143 pub fn used_tools_since_last_user_message(&self) -> bool {
1144 for entry in self.entries.iter().rev() {
1145 match entry {
1146 AgentThreadEntry::UserMessage(..) => return false,
1147 AgentThreadEntry::AssistantMessage(..) => continue,
1148 AgentThreadEntry::ToolCall(..) => return true,
1149 }
1150 }
1151
1152 false
1153 }
1154
1155 pub fn handle_session_update(
1156 &mut self,
1157 update: acp::SessionUpdate,
1158 cx: &mut Context<Self>,
1159 ) -> Result<(), acp::Error> {
1160 match update {
1161 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1162 self.push_user_content_block(None, content, cx);
1163 }
1164 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1165 self.push_assistant_content_block(content, false, cx);
1166 }
1167 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1168 self.push_assistant_content_block(content, true, cx);
1169 }
1170 acp::SessionUpdate::ToolCall(tool_call) => {
1171 self.upsert_tool_call(tool_call, cx)?;
1172 }
1173 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1174 self.update_tool_call(tool_call_update, cx)?;
1175 }
1176 acp::SessionUpdate::Plan(plan) => {
1177 self.update_plan(plan, cx);
1178 }
1179 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1180 available_commands,
1181 ..
1182 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1183 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1184 current_mode_id,
1185 ..
1186 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1187 _ => {}
1188 }
1189 Ok(())
1190 }
1191
1192 pub fn push_user_content_block(
1193 &mut self,
1194 message_id: Option<UserMessageId>,
1195 chunk: acp::ContentBlock,
1196 cx: &mut Context<Self>,
1197 ) {
1198 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1199 }
1200
1201 pub fn push_user_content_block_with_indent(
1202 &mut self,
1203 message_id: Option<UserMessageId>,
1204 chunk: acp::ContentBlock,
1205 indented: bool,
1206 cx: &mut Context<Self>,
1207 ) {
1208 let language_registry = self.project.read(cx).languages().clone();
1209 let path_style = self.project.read(cx).path_style(cx);
1210 let entries_len = self.entries.len();
1211
1212 if let Some(last_entry) = self.entries.last_mut()
1213 && let AgentThreadEntry::UserMessage(UserMessage {
1214 id,
1215 content,
1216 chunks,
1217 indented: existing_indented,
1218 ..
1219 }) = last_entry
1220 && *existing_indented == indented
1221 {
1222 *id = message_id.or(id.take());
1223 content.append(chunk.clone(), &language_registry, path_style, cx);
1224 chunks.push(chunk);
1225 let idx = entries_len - 1;
1226 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1227 } else {
1228 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1229 self.push_entry(
1230 AgentThreadEntry::UserMessage(UserMessage {
1231 id: message_id,
1232 content,
1233 chunks: vec![chunk],
1234 checkpoint: None,
1235 indented,
1236 }),
1237 cx,
1238 );
1239 }
1240 }
1241
1242 pub fn push_assistant_content_block(
1243 &mut self,
1244 chunk: acp::ContentBlock,
1245 is_thought: bool,
1246 cx: &mut Context<Self>,
1247 ) {
1248 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1249 }
1250
1251 pub fn push_assistant_content_block_with_indent(
1252 &mut self,
1253 chunk: acp::ContentBlock,
1254 is_thought: bool,
1255 indented: bool,
1256 cx: &mut Context<Self>,
1257 ) {
1258 let language_registry = self.project.read(cx).languages().clone();
1259 let path_style = self.project.read(cx).path_style(cx);
1260 let entries_len = self.entries.len();
1261 if let Some(last_entry) = self.entries.last_mut()
1262 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1263 chunks,
1264 indented: existing_indented,
1265 }) = last_entry
1266 && *existing_indented == indented
1267 {
1268 let idx = entries_len - 1;
1269 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1270 match (chunks.last_mut(), is_thought) {
1271 (Some(AssistantMessageChunk::Message { block }), false)
1272 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1273 block.append(chunk, &language_registry, path_style, cx)
1274 }
1275 _ => {
1276 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1277 if is_thought {
1278 chunks.push(AssistantMessageChunk::Thought { block })
1279 } else {
1280 chunks.push(AssistantMessageChunk::Message { block })
1281 }
1282 }
1283 }
1284 } else {
1285 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1286 let chunk = if is_thought {
1287 AssistantMessageChunk::Thought { block }
1288 } else {
1289 AssistantMessageChunk::Message { block }
1290 };
1291
1292 self.push_entry(
1293 AgentThreadEntry::AssistantMessage(AssistantMessage {
1294 chunks: vec![chunk],
1295 indented,
1296 }),
1297 cx,
1298 );
1299 }
1300 }
1301
1302 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1303 self.entries.push(entry);
1304 cx.emit(AcpThreadEvent::NewEntry);
1305 }
1306
1307 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1308 self.connection.set_title(&self.session_id, cx).is_some()
1309 }
1310
1311 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1312 if title != self.title {
1313 self.title = title.clone();
1314 cx.emit(AcpThreadEvent::TitleUpdated);
1315 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1316 return set_title.run(title, cx);
1317 }
1318 }
1319 Task::ready(Ok(()))
1320 }
1321
1322 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1323 self.token_usage = usage;
1324 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1325 }
1326
1327 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1328 cx.emit(AcpThreadEvent::Retry(status));
1329 }
1330
1331 pub fn update_tool_call(
1332 &mut self,
1333 update: impl Into<ToolCallUpdate>,
1334 cx: &mut Context<Self>,
1335 ) -> Result<()> {
1336 let update = update.into();
1337 let languages = self.project.read(cx).languages().clone();
1338 let path_style = self.project.read(cx).path_style(cx);
1339
1340 let ix = match self.index_for_tool_call(update.id()) {
1341 Some(ix) => ix,
1342 None => {
1343 // Tool call not found - create a failed tool call entry
1344 let failed_tool_call = ToolCall {
1345 id: update.id().clone(),
1346 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1347 kind: acp::ToolKind::Fetch,
1348 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1349 "Tool call not found".into(),
1350 &languages,
1351 path_style,
1352 cx,
1353 ))],
1354 status: ToolCallStatus::Failed,
1355 locations: Vec::new(),
1356 resolved_locations: Vec::new(),
1357 raw_input: None,
1358 raw_output: None,
1359 };
1360 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1361 return Ok(());
1362 }
1363 };
1364 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1365 unreachable!()
1366 };
1367
1368 match update {
1369 ToolCallUpdate::UpdateFields(update) => {
1370 let location_updated = update.fields.locations.is_some();
1371 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1372 if location_updated {
1373 self.resolve_locations(update.tool_call_id, cx);
1374 }
1375 }
1376 ToolCallUpdate::UpdateDiff(update) => {
1377 call.content.clear();
1378 call.content.push(ToolCallContent::Diff(update.diff));
1379 }
1380 ToolCallUpdate::UpdateTerminal(update) => {
1381 call.content.clear();
1382 call.content
1383 .push(ToolCallContent::Terminal(update.terminal));
1384 }
1385 }
1386
1387 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1388
1389 Ok(())
1390 }
1391
1392 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1393 pub fn upsert_tool_call(
1394 &mut self,
1395 tool_call: acp::ToolCall,
1396 cx: &mut Context<Self>,
1397 ) -> Result<(), acp::Error> {
1398 let status = tool_call.status.into();
1399 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1400 }
1401
1402 /// Fails if id does not match an existing entry.
1403 pub fn upsert_tool_call_inner(
1404 &mut self,
1405 update: acp::ToolCallUpdate,
1406 status: ToolCallStatus,
1407 cx: &mut Context<Self>,
1408 ) -> Result<(), acp::Error> {
1409 let language_registry = self.project.read(cx).languages().clone();
1410 let path_style = self.project.read(cx).path_style(cx);
1411 let id = update.tool_call_id.clone();
1412
1413 let agent_telemetry_id = self.connection().telemetry_id();
1414 let session = self.session_id();
1415 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1416 let status = if matches!(status, ToolCallStatus::Completed) {
1417 "completed"
1418 } else {
1419 "failed"
1420 };
1421 telemetry::event!(
1422 "Agent Tool Call Completed",
1423 agent_telemetry_id,
1424 session,
1425 status
1426 );
1427 }
1428
1429 if let Some(ix) = self.index_for_tool_call(&id) {
1430 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1431 unreachable!()
1432 };
1433
1434 call.update_fields(
1435 update.fields,
1436 language_registry,
1437 path_style,
1438 &self.terminals,
1439 cx,
1440 )?;
1441 call.status = status;
1442
1443 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1444 } else {
1445 let call = ToolCall::from_acp(
1446 update.try_into()?,
1447 status,
1448 language_registry,
1449 self.project.read(cx).path_style(cx),
1450 &self.terminals,
1451 cx,
1452 )?;
1453 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1454 };
1455
1456 self.resolve_locations(id, cx);
1457 Ok(())
1458 }
1459
1460 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1461 self.entries
1462 .iter()
1463 .enumerate()
1464 .rev()
1465 .find_map(|(index, entry)| {
1466 if let AgentThreadEntry::ToolCall(tool_call) = entry
1467 && &tool_call.id == id
1468 {
1469 Some(index)
1470 } else {
1471 None
1472 }
1473 })
1474 }
1475
1476 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1477 // The tool call we are looking for is typically the last one, or very close to the end.
1478 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1479 self.entries
1480 .iter_mut()
1481 .enumerate()
1482 .rev()
1483 .find_map(|(index, tool_call)| {
1484 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1485 && &tool_call.id == id
1486 {
1487 Some((index, tool_call))
1488 } else {
1489 None
1490 }
1491 })
1492 }
1493
1494 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1495 self.entries
1496 .iter()
1497 .enumerate()
1498 .rev()
1499 .find_map(|(index, tool_call)| {
1500 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1501 && &tool_call.id == id
1502 {
1503 Some((index, tool_call))
1504 } else {
1505 None
1506 }
1507 })
1508 }
1509
/// Resolves the locations of the tool call `id` and records the result.
///
/// Does nothing if no tool call with `id` exists. Otherwise the future returned by
/// `ToolCall::resolve_locations` is awaited on a detached task. When it finishes:
/// - a snapshot of every resolved buffer is cached in `shared_buffers`;
/// - if the last location resolved, the project's agent location moves there,
///   unless the agent is already in that buffer on the same row at a later column;
/// - the resolved locations are stored on the tool call, and
///   `AcpThreadEvent::EntryUpdated` is emitted only if they changed.
///
/// If the tool call no longer exists by then, only the snapshot cache is updated.
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
    let project = self.project.clone();
    let Some((_, tool_call)) = self.tool_call_mut(&id) else {
        return;
    };
    let task = tool_call.resolve_locations(project, cx);
    cx.spawn(async move |this, cx| {
        let resolved_locations = task.await;

        this.update(cx, |this, cx| {
            let project = this.project.clone();

            // Cache snapshots so `read_text_file`/`write_text_file` can reuse the
            // content as it was when the locations were resolved.
            for location in resolved_locations.iter().flatten() {
                this.shared_buffers
                    .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
            }
            // The tool call may have been removed (e.g. by a rewind) while resolving.
            let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                return;
            };

            if let Some(Some(location)) = resolved_locations.last() {
                project.update(cx, |project, cx| {
                    let should_ignore = if let Some(agent_location) = project
                        .agent_location()
                        .filter(|agent_location| agent_location.buffer == location.buffer)
                    {
                        let snapshot = location.buffer.read(cx).snapshot();
                        let old_position = agent_location.position.to_point(&snapshot);
                        let new_position = location.position.to_point(&snapshot);

                        // ignore this so that when we get updates from the edit tool
                        // the position doesn't reset to the start of line
                        old_position.row == new_position.row
                            && old_position.column > new_position.column
                    } else {
                        false
                    };
                    if !should_ignore {
                        project.set_agent_location(Some(location.into()), cx);
                    }
                });
            }

            let resolved_locations = resolved_locations
                .iter()
                .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                .collect::<Vec<_>>();

            // Only notify observers when something actually changed.
            if tool_call.resolved_locations != resolved_locations {
                tool_call.resolved_locations = resolved_locations;
                cx.emit(AcpThreadEvent::EntryUpdated(ix));
            }
        })
    })
    .detach();
}
1566
1567 pub fn request_tool_call_authorization(
1568 &mut self,
1569 tool_call: acp::ToolCallUpdate,
1570 options: Vec<acp::PermissionOption>,
1571 respect_always_allow_setting: bool,
1572 cx: &mut Context<Self>,
1573 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1574 let (tx, rx) = oneshot::channel();
1575
1576 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1577 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1578 // some tools would (incorrectly) continue to auto-accept.
1579 if let Some(allow_once_option) = options.iter().find_map(|option| {
1580 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1581 Some(option.option_id.clone())
1582 } else {
1583 None
1584 }
1585 }) {
1586 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1587 return Ok(async {
1588 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1589 allow_once_option,
1590 ))
1591 }
1592 .boxed());
1593 }
1594 }
1595
1596 let status = ToolCallStatus::WaitingForConfirmation {
1597 options,
1598 respond_tx: tx,
1599 };
1600
1601 self.upsert_tool_call_inner(tool_call, status, cx)?;
1602 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1603
1604 let fut = async {
1605 match rx.await {
1606 Ok(option) => acp::RequestPermissionOutcome::Selected(
1607 acp::SelectedPermissionOutcome::new(option),
1608 ),
1609 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1610 }
1611 }
1612 .boxed();
1613
1614 Ok(fut)
1615 }
1616
1617 pub fn authorize_tool_call(
1618 &mut self,
1619 id: acp::ToolCallId,
1620 option_id: acp::PermissionOptionId,
1621 option_kind: acp::PermissionOptionKind,
1622 cx: &mut Context<Self>,
1623 ) {
1624 let Some((ix, call)) = self.tool_call_mut(&id) else {
1625 return;
1626 };
1627
1628 let new_status = match option_kind {
1629 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1630 ToolCallStatus::Rejected
1631 }
1632 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1633 ToolCallStatus::InProgress
1634 }
1635 _ => ToolCallStatus::InProgress,
1636 };
1637
1638 let curr_status = mem::replace(&mut call.status, new_status);
1639
1640 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1641 respond_tx.send(option_id).log_err();
1642 } else if cfg!(debug_assertions) {
1643 panic!("tried to authorize an already authorized tool call");
1644 }
1645
1646 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1647 }
1648
1649 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1650 let mut first_tool_call = None;
1651
1652 for entry in self.entries.iter().rev() {
1653 match &entry {
1654 AgentThreadEntry::ToolCall(call) => {
1655 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1656 first_tool_call = Some(call);
1657 } else {
1658 continue;
1659 }
1660 }
1661 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1662 // Reached the beginning of the turn.
1663 // If we had pending permission requests in the previous turn, they have been cancelled.
1664 break;
1665 }
1666 }
1667 }
1668
1669 first_tool_call
1670 }
1671
/// Returns the thread's current plan.
pub fn plan(&self) -> &Plan {
    &self.plan
}
1675
/// Replaces the current plan with the entries of `request`.
///
/// Entries are matched by position. Where both an old and a new entry exist, the
/// old entry's existing `Markdown` entity is updated in place and its priority and
/// status are overwritten. Remaining new entries are appended, and old entries
/// beyond the new length are removed. Observers are notified via `cx.notify()`.
pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
    let new_entries_len = request.entries.len();
    let mut new_entries = request.entries.into_iter();

    // Reuse existing markdown to prevent flickering.
    // `by_ref` leaves any new entries that `zip` did not consume in `new_entries`
    // for the append loop below (`zip` stops pulling once the old entries run out).
    for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
        let PlanEntry {
            content,
            priority,
            status,
        } = old;
        content.update(cx, |old, cx| {
            old.replace(new.content, cx);
        });
        *priority = new.priority;
        *status = new.status;
    }
    for new in new_entries {
        self.plan.entries.push(PlanEntry::from_acp(new, cx))
    }
    // Drop old entries beyond the new plan's length (no-op if the plan grew).
    self.plan.entries.truncate(new_entries_len);

    cx.notify();
}
1700
1701 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1702 self.plan
1703 .entries
1704 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1705 cx.notify();
1706 }
1707
/// Test helper: sends `message`, converted into a single `acp::ContentBlock`, as a
/// user prompt via `send`.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let block: acp::ContentBlock = message.into();
    self.send(vec![block], cx)
}
1716
1717 pub fn send(
1718 &mut self,
1719 message: Vec<acp::ContentBlock>,
1720 cx: &mut Context<Self>,
1721 ) -> BoxFuture<'static, Result<()>> {
1722 let block = ContentBlock::new_combined(
1723 message.clone(),
1724 self.project.read(cx).languages().clone(),
1725 self.project.read(cx).path_style(cx),
1726 cx,
1727 );
1728 let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
1729 let git_store = self.project.read(cx).git_store().clone();
1730
1731 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1732 Some(UserMessageId::new())
1733 } else {
1734 None
1735 };
1736
1737 self.run_turn(cx, async move |this, cx| {
1738 this.update(cx, |this, cx| {
1739 this.push_entry(
1740 AgentThreadEntry::UserMessage(UserMessage {
1741 id: message_id.clone(),
1742 content: block,
1743 chunks: message,
1744 checkpoint: None,
1745 indented: false,
1746 }),
1747 cx,
1748 );
1749 })
1750 .ok();
1751
1752 let old_checkpoint = git_store
1753 .update(cx, |git, cx| git.checkpoint(cx))?
1754 .await
1755 .context("failed to get old checkpoint")
1756 .log_err();
1757 this.update(cx, |this, cx| {
1758 if let Some((_ix, message)) = this.last_user_message() {
1759 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1760 git_checkpoint,
1761 show: false,
1762 });
1763 }
1764 this.connection.prompt(message_id, request, cx)
1765 })?
1766 .await
1767 })
1768 }
1769
1770 pub fn can_resume(&self, cx: &App) -> bool {
1771 self.connection.resume(&self.session_id, cx).is_some()
1772 }
1773
1774 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1775 self.run_turn(cx, async move |this, cx| {
1776 this.update(cx, |this, cx| {
1777 this.connection
1778 .resume(&this.session_id, cx)
1779 .map(|resume| resume.run(cx))
1780 })?
1781 .context("resuming a session is not supported")?
1782 .await
1783 })
1784 }
1785
/// Runs one prompt turn driven by `f`, returning a future that resolves once the
/// turn's outcome has been applied to the thread.
///
/// Completed plan entries are cleared and any in-flight turn is cancelled first;
/// `f` only starts after that cancellation has finished. `f` runs inside
/// `send_task`, and its result is forwarded over a oneshot channel to the returned
/// future, which then:
/// - refreshes the last user message's checkpoint (`update_last_checkpoint`);
/// - clears the project's agent location;
/// - if `f` returned an error: drops `send_task`, emits `AcpThreadEvent::Error`,
///   and returns the error;
/// - otherwise: handles a `Refusal` stop reason (see the comments below), emits
///   `AcpThreadEvent::Stopped`, and returns `Ok(())`.
///
/// If the channel is cancelled before `f` completes, this is treated like a normal
/// (non-cancelled) stop.
///
/// NOTE(review): if `update_last_checkpoint` fails, the `?` below returns early and
/// neither `Stopped` nor `Error` is emitted, and `send_task` is not taken. Confirm
/// this is intended.
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
    f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
    self.clear_completed_plan_entries(cx);

    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);

    self.send_task = Some(cx.spawn(async move |this, cx| {
        // Let the previous turn wind down before starting this one.
        cancel_task.await;
        tx.send(f(this, cx).await).ok();
    }));

    cx.spawn(async move |this, cx| {
        let response = rx.await;

        this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
            .await?;

        this.update(cx, |this, cx| {
            this.project
                .update(cx, |project, cx| project.set_agent_location(None, cx));
            match response {
                Ok(Err(e)) => {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                result => {
                    let canceled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Cancelled,
                            ..
                        }))
                    );

                    // We only take the task if the current prompt wasn't canceled.
                    //
                    // This prompt may have been canceled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be canceled.
                    if !canceled {
                        this.send_task.take();
                    }

                    // Handle refusal - distinguish between user prompt and tool call refusals
                    if let Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Refusal,
                        ..
                    })) = result
                    {
                        if let Some((user_msg_ix, _)) = this.last_user_message() {
                            // Check if there's a completed tool call with results after the last user message
                            // This indicates the refusal is in response to tool output, not the user's prompt
                            let has_completed_tool_call_after_user_msg =
                                this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                    if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                        // Check if the tool call has completed and has output
                                        matches!(tool_call.status, ToolCallStatus::Completed)
                                            && tool_call.raw_output.is_some()
                                    } else {
                                        false
                                    }
                                });

                            if has_completed_tool_call_after_user_msg {
                                // Refusal is due to tool output - don't truncate, just notify
                                // The model refused based on what the tool returned
                                cx.emit(AcpThreadEvent::Refusal);
                            } else {
                                // User prompt was refused - truncate back to before the user message
                                let range = user_msg_ix..this.entries.len();
                                if range.start < range.end {
                                    this.entries.truncate(user_msg_ix);
                                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                }
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        } else {
                            // No user message found, treat as general refusal
                            cx.emit(AcpThreadEvent::Refusal);
                        }
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1881
1882 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1883 let Some(send_task) = self.send_task.take() else {
1884 return Task::ready(());
1885 };
1886
1887 for entry in self.entries.iter_mut() {
1888 if let AgentThreadEntry::ToolCall(call) = entry {
1889 let cancel = matches!(
1890 call.status,
1891 ToolCallStatus::Pending
1892 | ToolCallStatus::WaitingForConfirmation { .. }
1893 | ToolCallStatus::InProgress
1894 );
1895
1896 if cancel {
1897 call.status = ToolCallStatus::Canceled;
1898 }
1899 }
1900 }
1901
1902 self.connection.cancel(&self.session_id, cx);
1903
1904 // Wait for the send task to complete
1905 cx.foreground_executor().spawn(send_task)
1906 }
1907
1908 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1909 pub fn restore_checkpoint(
1910 &mut self,
1911 id: UserMessageId,
1912 cx: &mut Context<Self>,
1913 ) -> Task<Result<()>> {
1914 let Some((_, message)) = self.user_message_mut(&id) else {
1915 return Task::ready(Err(anyhow!("message not found")));
1916 };
1917
1918 let checkpoint = message
1919 .checkpoint
1920 .as_ref()
1921 .map(|c| c.git_checkpoint.clone());
1922
1923 // Cancel any in-progress generation before restoring
1924 let cancel_task = self.cancel(cx);
1925 let rewind = self.rewind(id.clone(), cx);
1926 let git_store = self.project.read(cx).git_store().clone();
1927
1928 cx.spawn(async move |_, cx| {
1929 cancel_task.await;
1930 rewind.await?;
1931 if let Some(checkpoint) = checkpoint {
1932 git_store
1933 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1934 .await?;
1935 }
1936
1937 Ok(())
1938 })
1939 }
1940
/// Rewinds this thread to just before the user message identified by `id`,
/// removing that message and every entry after it, then calls
/// `ActionLog::reject_all_edits`.
///
/// The connection is first asked to truncate the session at `id`. Terminals
/// referenced by the removed entries are killed and removed from `terminals`.
/// Unlike `restore_checkpoint`, this method does not restore anything from git.
///
/// Returns an error immediately (`"not supported"`) if the connection cannot
/// truncate. The task also returns an error if truncation fails or the thread is
/// released. If no user message with `id` is found after truncating, no entries
/// are removed, but edits are still rejected.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
    let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
        return Task::ready(Err(anyhow!("not supported")));
    };

    let telemetry = ActionLogTelemetry::from(&*self);
    cx.spawn(async move |this, cx| {
        cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
        this.update(cx, |this, cx| {
            if let Some((ix, _)) = this.user_message_mut(&id) {
                // Collect all terminals from entries that will be removed
                let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                    .iter()
                    .flat_map(|entry| entry.terminals())
                    .filter_map(|terminal| terminal.read(cx).id().clone().into())
                    .collect();

                let range = ix..this.entries.len();
                this.entries.truncate(ix);
                cx.emit(AcpThreadEvent::EntriesRemoved(range));

                // Kill and remove the terminals
                for terminal_id in terminals_to_remove {
                    if let Some(terminal) = this.terminals.remove(&terminal_id) {
                        terminal.update(cx, |terminal, cx| {
                            terminal.kill(cx);
                        });
                    }
                }
            }
            this.action_log().update(cx, |action_log, cx| {
                action_log.reject_all_edits(Some(telemetry), cx)
            })
        })?
        // The output of the rejection future is not inspected.
        .await;
        Ok(())
    })
}
1982
1983 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1984 let git_store = self.project.read(cx).git_store().clone();
1985
1986 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1987 if let Some(checkpoint) = message.checkpoint.as_ref() {
1988 checkpoint.git_checkpoint.clone()
1989 } else {
1990 return Task::ready(Ok(()));
1991 }
1992 } else {
1993 return Task::ready(Ok(()));
1994 };
1995
1996 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1997 cx.spawn(async move |this, cx| {
1998 let new_checkpoint = new_checkpoint
1999 .await
2000 .context("failed to get new checkpoint")
2001 .log_err();
2002 if let Some(new_checkpoint) = new_checkpoint {
2003 let equal = git_store
2004 .update(cx, |git, cx| {
2005 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2006 })?
2007 .await
2008 .unwrap_or(true);
2009 this.update(cx, |this, cx| {
2010 let (ix, message) = this.last_user_message().context("no user message")?;
2011 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
2012 checkpoint.show = !equal;
2013 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2014 anyhow::Ok(())
2015 })??;
2016 }
2017
2018 Ok(())
2019 })
2020 }
2021
2022 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2023 self.entries
2024 .iter_mut()
2025 .enumerate()
2026 .rev()
2027 .find_map(|(ix, entry)| {
2028 if let AgentThreadEntry::UserMessage(message) = entry {
2029 Some((ix, message))
2030 } else {
2031 None
2032 }
2033 })
2034 }
2035
2036 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2037 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2038 if let AgentThreadEntry::UserMessage(message) = entry {
2039 if message.id.as_ref() == Some(id) {
2040 Some((ix, message))
2041 } else {
2042 None
2043 }
2044 } else {
2045 None
2046 }
2047 })
2048 }
2049
/// Reads text from the file at absolute `path` on behalf of the agent.
///
/// `line` is 1-based; `None`, 0 and 1 all start at the first line. `limit` caps the
/// number of lines returned (`None` means no cap).
///
/// The snapshot that is read depends on `reuse_shared_snapshot`:
/// - If it is true and a snapshot of the buffer is cached in `shared_buffers`, that
///   snapshot is read.
/// - Otherwise the read is reported to the action log, and the buffer's current
///   snapshot is read and cached.
///
/// The project's agent location is moved to the start of the returned range.
///
/// # Errors
///
/// - `resource_not_found` if `path` does not map to a project path;
/// - `invalid_params` if the requested start line is past the end of the file;
/// - failures from opening the buffer or from the entity updates are propagated as
///   `acp::Error`s.
pub fn read_text_file(
    &self,
    path: PathBuf,
    line: Option<u32>,
    limit: Option<u32>,
    reuse_shared_snapshot: bool,
    cx: &mut Context<Self>,
) -> Task<Result<String, acp::Error>> {
    // Args are 1-based, move to 0-based
    let line = line.unwrap_or_default().saturating_sub(1);
    let limit = limit.unwrap_or(u32::MAX);
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project
            .update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .ok_or_else(|| {
                        acp::Error::resource_not_found(Some(path.display().to_string()))
                    })?;
                Ok(project.open_buffer(path, cx))
            })
            .map_err(|e| acp::Error::internal_error().data(e.to_string()))
            .flatten()?;

        let buffer = load.await?;

        let snapshot = if reuse_shared_snapshot {
            this.read_with(cx, |this, _| {
                this.shared_buffers.get(&buffer.clone()).cloned()
            })
            .log_err()
            .flatten()
        } else {
            None
        };

        let snapshot = if let Some(snapshot) = snapshot {
            snapshot
        } else {
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            })?;

            let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
            this.update(cx, |this, _| {
                this.shared_buffers.insert(buffer.clone(), snapshot.clone());
            })?;
            snapshot
        };

        let max_point = snapshot.max_point();
        let start_position = Point::new(line, 0);

        if start_position > max_point {
            return Err(acp::Error::invalid_params().data(format!(
                "Attempting to read beyond the end of the file, line {}:{}",
                max_point.row + 1,
                max_point.column
            )));
        }

        let start = snapshot.anchor_before(start_position);
        // `saturating_add` keeps the default `u32::MAX` limit from overflowing.
        let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: start,
                }),
                cx,
            );
        })?;

        Ok(snapshot.text_for_range(start..end).collect::<String>())
    })
}
2129
/// Replaces the contents of the file at absolute `path` with `content` and saves
/// it, on behalf of the agent.
///
/// 1. `content` is diffed on the background executor against the buffer's snapshot
///    in `shared_buffers`, if there is one, or otherwise against its current
///    snapshot. The diff is turned into anchored edits.
/// 2. The project's agent location is moved to the end of the last edit, or to the
///    start of the buffer when there are no edits.
/// 3. The edits are applied, and the action log is told about the read and the
///    edit.
/// 4. If the buffer's language settings enable format-on-save, the buffer is
///    formatted with `FormatTrigger::Save` and the action log is notified again.
///    Formatting errors are logged, not returned.
/// 5. The buffer is saved.
///
/// # Errors
///
/// Fails with "invalid path" if `path` does not map to a project path. Also
/// propagates failures from opening or saving the buffer and from the entity
/// updates.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Prefer the snapshot the agent last saw (if cached) as the diff base.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                }),
                cx,
            );
        })?;

        let format_on_save = cx.update(|cx| {
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        })?;

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            })?;
            format_task.await.log_err();

            // Record the formatter's changes as agent edits too.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            })?;
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
2226
/// Spawns `command` with `args` in a new terminal and registers it in `terminals`.
///
/// The command is wrapped by `ShellBuilder` for the remote client's default shell,
/// or the local default shell preferring bash, with
/// `redirect_stdin_to_dev_null` applied.
///
/// The environment is built as follows:
/// - start from the project's directory environment for `cwd`, or an empty
///   environment when `cwd` is `None`;
/// - set `PAGER` to an empty string;
/// - apply `extra_env` last, so it overrides both.
///
/// `output_byte_limit` is forwarded to `Terminal::new`. The returned task resolves
/// to the new terminal entity after it has been inserted into `terminals`.
///
/// NOTE(review): the display label is formatted as `"{command} {args}"`, which
/// leaves a trailing space when `args` is empty.
pub fn create_terminal(
    &self,
    command: String,
    args: Vec<String>,
    extra_env: Vec<acp::EnvVariable>,
    cwd: Option<PathBuf>,
    output_byte_limit: Option<u64>,
    cx: &mut Context<Self>,
) -> Task<Result<Entity<Terminal>>> {
    let env = match &cwd {
        Some(dir) => self.project.update(cx, |project, cx| {
            project.environment().update(cx, |env, cx| {
                env.directory_environment(dir.as_path().into(), cx)
            })
        }),
        None => Task::ready(None).shared(),
    };
    let env = cx.spawn(async move |_, _| {
        let mut env = env.await.unwrap_or_default();
        // Disables paging for `git` and hopefully other commands
        env.insert("PAGER".into(), "".into());
        for var in extra_env {
            env.insert(var.name, var.value);
        }
        env
    });

    let project = self.project.clone();
    let language_registry = project.read(cx).languages().clone();
    let is_windows = project.read(cx).path_style(cx).is_windows();

    let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
    let terminal_task = cx.spawn({
        let terminal_id = terminal_id.clone();
        async move |_this, cx| {
            let env = env.await;
            // Prefer the remote host's shell when the project is remote.
            let shell = project
                .update(cx, |project, cx| {
                    project
                        .remote_client()
                        .and_then(|r| r.read(cx).default_system_shell())
                })?
                .unwrap_or_else(|| get_default_system_shell_preferring_bash());
            let (task_command, task_args) =
                ShellBuilder::new(&Shell::Program(shell), is_windows)
                    .redirect_stdin_to_dev_null()
                    .build(Some(command.clone()), &args);
            let terminal = project
                .update(cx, |project, cx| {
                    project.create_terminal_task(
                        task::SpawnInTerminal {
                            command: Some(task_command),
                            args: task_args,
                            cwd: cwd.clone(),
                            env,
                            ..Default::default()
                        },
                        cx,
                    )
                })?
                .await?;

            cx.new(|cx| {
                Terminal::new(
                    terminal_id,
                    &format!("{} {}", command, args.join(" ")),
                    cwd,
                    output_byte_limit.map(|l| l as usize),
                    terminal,
                    language_registry,
                    cx,
                )
            })
        }
    });

    cx.spawn(async move |this, cx| {
        let terminal = terminal_task.await?;
        this.update(cx, |this, _cx| {
            this.terminals.insert(terminal_id, terminal.clone());
            terminal
        })
    })
}
2311
2312 pub fn kill_terminal(
2313 &mut self,
2314 terminal_id: acp::TerminalId,
2315 cx: &mut Context<Self>,
2316 ) -> Result<()> {
2317 self.terminals
2318 .get(&terminal_id)
2319 .context("Terminal not found")?
2320 .update(cx, |terminal, cx| {
2321 terminal.kill(cx);
2322 });
2323
2324 Ok(())
2325 }
2326
2327 pub fn release_terminal(
2328 &mut self,
2329 terminal_id: acp::TerminalId,
2330 cx: &mut Context<Self>,
2331 ) -> Result<()> {
2332 self.terminals
2333 .remove(&terminal_id)
2334 .context("Terminal not found")?
2335 .update(cx, |terminal, cx| {
2336 terminal.kill(cx);
2337 });
2338
2339 Ok(())
2340 }
2341
2342 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2343 self.terminals
2344 .get(&terminal_id)
2345 .context("Terminal not found")
2346 .cloned()
2347 }
2348
2349 pub fn to_markdown(&self, cx: &App) -> String {
2350 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2351 }
2352
2353 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2354 cx.emit(AcpThreadEvent::LoadError(error));
2355 }
2356
2357 pub fn register_terminal_created(
2358 &mut self,
2359 terminal_id: acp::TerminalId,
2360 command_label: String,
2361 working_dir: Option<PathBuf>,
2362 output_byte_limit: Option<u64>,
2363 terminal: Entity<::terminal::Terminal>,
2364 cx: &mut Context<Self>,
2365 ) -> Entity<Terminal> {
2366 let language_registry = self.project.read(cx).languages().clone();
2367
2368 let entity = cx.new(|cx| {
2369 Terminal::new(
2370 terminal_id.clone(),
2371 &command_label,
2372 working_dir.clone(),
2373 output_byte_limit.map(|l| l as usize),
2374 terminal,
2375 language_registry,
2376 cx,
2377 )
2378 });
2379 self.terminals.insert(terminal_id.clone(), entity.clone());
2380 entity
2381 }
2382}
2383
2384fn markdown_for_raw_output(
2385 raw_output: &serde_json::Value,
2386 language_registry: &Arc<LanguageRegistry>,
2387 cx: &mut App,
2388) -> Option<Entity<Markdown>> {
2389 match raw_output {
2390 serde_json::Value::Null => None,
2391 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2392 Markdown::new(
2393 value.to_string().into(),
2394 Some(language_registry.clone()),
2395 None,
2396 cx,
2397 )
2398 })),
2399 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2400 Markdown::new(
2401 value.to_string().into(),
2402 Some(language_registry.clone()),
2403 None,
2404 cx,
2405 )
2406 })),
2407 serde_json::Value::String(value) => Some(cx.new(|cx| {
2408 Markdown::new(
2409 value.clone().into(),
2410 Some(language_registry.clone()),
2411 None,
2412 cx,
2413 )
2414 })),
2415 value => Some(cx.new(|cx| {
2416 Markdown::new(
2417 format!("```json\n{}\n```", value).into(),
2418 Some(language_registry.clone()),
2419 None,
2420 cx,
2421 )
2422 })),
2423 }
2424}
2425
2426#[cfg(test)]
2427mod tests {
2428 use super::*;
2429 use anyhow::anyhow;
2430 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2431 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2432 use indoc::indoc;
2433 use project::{FakeFs, Fs};
2434 use rand::{distr, prelude::*};
2435 use serde_json::json;
2436 use settings::SettingsStore;
2437 use smol::stream::StreamExt as _;
2438 use std::{
2439 any::Any,
2440 cell::RefCell,
2441 path::Path,
2442 rc::Rc,
2443 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2444 time::Duration,
2445 };
2446 use util::path;
2447
2448 fn init_test(cx: &mut TestAppContext) {
2449 env_logger::try_init().ok();
2450 cx.update(|cx| {
2451 let settings_store = SettingsStore::test(cx);
2452 cx.set_global(settings_store);
2453 });
2454 }
2455
    // Output events that reach the thread before the matching `Created` event must be
    // held back and written into the terminal once `Created` registers it.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2517
    // Both Output and Exit arrive before `Created`; the buffered output must still be
    // rendered once the terminal is created.
    // NOTE(review): only the output is asserted below; the buffered exit status is not
    // checked by this test.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2590
    // `push_user_content_block` checks, in order:
    // - a first block starts a new user message with no id;
    // - a following block is appended to that message and sets its id;
    // - a block pushed after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        // Still a single entry: the block was merged into the existing message,
        // which now carries `message_1_id`.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        // Entries are now: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2658
    // Consecutive `AgentThoughtChunk` updates must merge into a single `<thinking>`
    // section of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2719
    // The user edits the open buffer (prepending "zero") between the agent's read and
    // its write of the same file. The agent's write must keep the user's edit, and the
    // merged text must reach both the buffer and the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file, so the
        // user edit below lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        // Open the file in a buffer so the test can act as the user editing it.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2796
    // `read_text_file(path, line, limit, ..)`: `line` is 1-based (Some(3) starts at
    // "three"), `limit` caps the number of lines returned, and a start line past the
    // end of the file is an invalid-params error.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
2872
    // Reading an empty file returns "" for the whole file and for in-range
    // start/limit combinations. A start line beyond line 1 is an invalid-params error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    // Reading a path that does not exist (and lies outside the /tmp worktree) must
    // fail with `acp::ErrorCode::ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
2978
    // A tool call that `cancel` marked `Canceled` must still accept a later
    // `Completed` status update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent starts one in-progress Fetch tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message, entry 1 the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3068
    // An Edit tool call that is already `Completed` when it arrives must not leave the
    // thread reporting pending edit tool calls after the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3113
    // Checkpoint behavior across several turns:
    // - a user message is marked "(checkpoint)" only when its turn changed files;
    // - `restore_checkpoint` truncates the thread at that message and restores the
    //   files to their state at that point.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt writes a new numbered file.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        // The fake agent echoes the prompt text upper-cased as its reply.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the "ipsum" user message; rewinding to it removes file-1.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3291
3292 #[gpui::test]
3293 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3294 use std::sync::atomic::AtomicUsize;
3295 init_test(cx);
3296
3297 let fs = FakeFs::new(cx.executor());
3298 let project = Project::test(fs, None, cx).await;
3299
3300 // Create a connection that simulates refusal after tool result
3301 let prompt_count = Arc::new(AtomicUsize::new(0));
3302 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3303 let prompt_count = prompt_count.clone();
3304 move |_request, thread, mut cx| {
3305 let count = prompt_count.fetch_add(1, SeqCst);
3306 async move {
3307 if count == 0 {
3308 // First prompt: Generate a tool call with result
3309 thread.update(&mut cx, |thread, cx| {
3310 thread
3311 .handle_session_update(
3312 acp::SessionUpdate::ToolCall(
3313 acp::ToolCall::new("tool1", "Test Tool")
3314 .kind(acp::ToolKind::Fetch)
3315 .status(acp::ToolCallStatus::Completed)
3316 .raw_input(serde_json::json!({"query": "test"}))
3317 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3318 ),
3319 cx,
3320 )
3321 .unwrap();
3322 })?;
3323
3324 // Now return refusal because of the tool result
3325 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3326 } else {
3327 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3328 }
3329 }
3330 .boxed_local()
3331 }
3332 }));
3333
3334 let thread = cx
3335 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3336 .await
3337 .unwrap();
3338
3339 // Track if we see a Refusal event
3340 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3341 let saw_refusal_event_captured = saw_refusal_event.clone();
3342 thread.update(cx, |_thread, cx| {
3343 cx.subscribe(
3344 &thread,
3345 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3346 if matches!(event, AcpThreadEvent::Refusal) {
3347 *saw_refusal_event_captured.lock().unwrap() = true;
3348 }
3349 },
3350 )
3351 .detach();
3352 });
3353
3354 // Send a user message - this will trigger tool call and then refusal
3355 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3356 cx.background_executor.spawn(send_task).detach();
3357 cx.run_until_parked();
3358
3359 // Verify that:
3360 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3361 // 2. The user message was NOT truncated
3362 assert!(
3363 *saw_refusal_event.lock().unwrap(),
3364 "Refusal event should be emitted for tool result refusals"
3365 );
3366
3367 thread.read_with(cx, |thread, _| {
3368 let entries = thread.entries();
3369 assert!(entries.len() >= 2, "Should have user message and tool call");
3370
3371 // Verify user message is still there
3372 assert!(
3373 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3374 "User message should not be truncated"
3375 );
3376
3377 // Verify tool call is there with result
3378 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3379 assert!(
3380 tool_call.raw_output.is_some(),
3381 "Tool call should have output"
3382 );
3383 } else {
3384 panic!("Expected tool call at index 1");
3385 }
3386 });
3387 }
3388
3389 #[gpui::test]
3390 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3391 init_test(cx);
3392
3393 let fs = FakeFs::new(cx.executor());
3394 let project = Project::test(fs, None, cx).await;
3395
3396 let refuse_next = Arc::new(AtomicBool::new(false));
3397 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3398 let refuse_next = refuse_next.clone();
3399 move |_request, _thread, _cx| {
3400 if refuse_next.load(SeqCst) {
3401 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3402 .boxed_local()
3403 } else {
3404 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3405 .boxed_local()
3406 }
3407 }
3408 }));
3409
3410 let thread = cx
3411 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3412 .await
3413 .unwrap();
3414
3415 // Track if we see a Refusal event
3416 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3417 let saw_refusal_event_captured = saw_refusal_event.clone();
3418 thread.update(cx, |_thread, cx| {
3419 cx.subscribe(
3420 &thread,
3421 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3422 if matches!(event, AcpThreadEvent::Refusal) {
3423 *saw_refusal_event_captured.lock().unwrap() = true;
3424 }
3425 },
3426 )
3427 .detach();
3428 });
3429
3430 // Send a message that will be refused
3431 refuse_next.store(true, SeqCst);
3432 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3433 .await
3434 .unwrap();
3435
3436 // Verify that a Refusal event WAS emitted for user prompt refusal
3437 assert!(
3438 *saw_refusal_event.lock().unwrap(),
3439 "Refusal event should be emitted for user prompt refusals"
3440 );
3441
3442 // Verify the message was truncated (user prompt refusal)
3443 thread.read_with(cx, |thread, cx| {
3444 assert_eq!(thread.to_markdown(cx), "");
3445 });
3446 }
3447
    // A turn that completes normally is kept. A later prompt the agent refuses is
    // truncated, leaving the transcript exactly as it was before that prompt.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While `refuse_next` is unset, the fake agent echoes the prompt upper-cased;
        // once set, it refuses without producing any output.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });
    }
3529
3530 async fn run_until_first_tool_call(
3531 thread: &Entity<AcpThread>,
3532 cx: &mut TestAppContext,
3533 ) -> usize {
3534 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3535
3536 let subscription = cx.update(|cx| {
3537 cx.subscribe(thread, move |thread, _, cx| {
3538 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3539 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3540 return tx.try_send(ix).unwrap();
3541 }
3542 }
3543 })
3544 });
3545
3546 select! {
3547 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3548 panic!("Timeout waiting for tool call")
3549 }
3550 ix = rx.next().fuse() => {
3551 drop(subscription);
3552 ix.unwrap()
3553 }
3554 }
3555 }
3556
3557 #[derive(Clone, Default)]
3558 struct FakeAgentConnection {
3559 auth_methods: Vec<acp::AuthMethod>,
3560 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3561 on_user_message: Option<
3562 Rc<
3563 dyn Fn(
3564 acp::PromptRequest,
3565 WeakEntity<AcpThread>,
3566 AsyncApp,
3567 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3568 + 'static,
3569 >,
3570 >,
3571 }
3572
3573 impl FakeAgentConnection {
3574 fn new() -> Self {
3575 Self {
3576 auth_methods: Vec::new(),
3577 on_user_message: None,
3578 sessions: Arc::default(),
3579 }
3580 }
3581
3582 #[expect(unused)]
3583 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3584 self.auth_methods = auth_methods;
3585 self
3586 }
3587
3588 fn on_user_message(
3589 mut self,
3590 handler: impl Fn(
3591 acp::PromptRequest,
3592 WeakEntity<AcpThread>,
3593 AsyncApp,
3594 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3595 + 'static,
3596 ) -> Self {
3597 self.on_user_message.replace(Rc::new(handler));
3598 self
3599 }
3600 }
3601
3602 impl AgentConnection for FakeAgentConnection {
3603 fn telemetry_id(&self) -> SharedString {
3604 "fake".into()
3605 }
3606
3607 fn auth_methods(&self) -> &[acp::AuthMethod] {
3608 &self.auth_methods
3609 }
3610
3611 fn new_thread(
3612 self: Rc<Self>,
3613 project: Entity<Project>,
3614 _cwd: &Path,
3615 cx: &mut App,
3616 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3617 let session_id = acp::SessionId::new(
3618 rand::rng()
3619 .sample_iter(&distr::Alphanumeric)
3620 .take(7)
3621 .map(char::from)
3622 .collect::<String>(),
3623 );
3624 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3625 let thread = cx.new(|cx| {
3626 AcpThread::new(
3627 "Test",
3628 self.clone(),
3629 project,
3630 action_log,
3631 session_id.clone(),
3632 watch::Receiver::constant(
3633 acp::PromptCapabilities::new()
3634 .image(true)
3635 .audio(true)
3636 .embedded_context(true),
3637 ),
3638 cx,
3639 )
3640 });
3641 self.sessions.lock().insert(session_id, thread.downgrade());
3642 Task::ready(Ok(thread))
3643 }
3644
3645 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3646 if self.auth_methods().iter().any(|m| m.id == method) {
3647 Task::ready(Ok(()))
3648 } else {
3649 Task::ready(Err(anyhow!("Invalid Auth Method")))
3650 }
3651 }
3652
3653 fn prompt(
3654 &self,
3655 _id: Option<UserMessageId>,
3656 params: acp::PromptRequest,
3657 cx: &mut App,
3658 ) -> Task<gpui::Result<acp::PromptResponse>> {
3659 let sessions = self.sessions.lock();
3660 let thread = sessions.get(¶ms.session_id).unwrap();
3661 if let Some(handler) = &self.on_user_message {
3662 let handler = handler.clone();
3663 let thread = thread.clone();
3664 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3665 } else {
3666 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3667 }
3668 }
3669
3670 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3671 let sessions = self.sessions.lock();
3672 let thread = sessions.get(session_id).unwrap().clone();
3673
3674 cx.spawn(async move |cx| {
3675 thread
3676 .update(cx, |thread, cx| thread.cancel(cx))
3677 .unwrap()
3678 .await
3679 })
3680 .detach();
3681 }
3682
3683 fn truncate(
3684 &self,
3685 session_id: &acp::SessionId,
3686 _cx: &App,
3687 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3688 Some(Rc::new(FakeAgentSessionEditor {
3689 _session_id: session_id.clone(),
3690 }))
3691 }
3692
3693 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3694 self
3695 }
3696 }
3697
3698 struct FakeAgentSessionEditor {
3699 _session_id: acp::SessionId,
3700 }
3701
3702 impl AgentSessionTruncate for FakeAgentSessionEditor {
3703 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3704 Task::ready(Ok(()))
3705 }
3706 }
3707
3708 #[gpui::test]
3709 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3710 init_test(cx);
3711
3712 let fs = FakeFs::new(cx.executor());
3713 let project = Project::test(fs, [], cx).await;
3714 let connection = Rc::new(FakeAgentConnection::new());
3715 let thread = cx
3716 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3717 .await
3718 .unwrap();
3719
3720 // Try to update a tool call that doesn't exist
3721 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3722 thread.update(cx, |thread, cx| {
3723 let result = thread.handle_session_update(
3724 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3725 nonexistent_id.clone(),
3726 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3727 )),
3728 cx,
3729 );
3730
3731 // The update should succeed (not return an error)
3732 assert!(result.is_ok());
3733
3734 // There should now be exactly one entry in the thread
3735 assert_eq!(thread.entries.len(), 1);
3736
3737 // The entry should be a failed tool call
3738 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3739 assert_eq!(tool_call.id, nonexistent_id);
3740 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3741 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3742
3743 // Check that the content contains the error message
3744 assert_eq!(tool_call.content.len(), 1);
3745 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3746 match content_block {
3747 ContentBlock::Markdown { markdown } => {
3748 let markdown_text = markdown.read(cx).source();
3749 assert!(markdown_text.contains("Tool call not found"));
3750 }
3751 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3752 ContentBlock::ResourceLink { .. } => {
3753 panic!("Expected markdown content, got resource link")
3754 }
3755 }
3756 } else {
3757 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3758 }
3759 } else {
3760 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3761 }
3762 });
3763 }
3764
3765 /// Tests that restoring a checkpoint properly cleans up terminals that were
3766 /// created after that checkpoint, and cancels any in-progress generation.
3767 ///
3768 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3769 /// that were started after that checkpoint should be terminated, and any in-progress
3770 /// AI generation should be canceled.
3771 #[gpui::test]
3772 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3773 init_test(cx);
3774
3775 let fs = FakeFs::new(cx.executor());
3776 let project = Project::test(fs, [], cx).await;
3777 let connection = Rc::new(FakeAgentConnection::new());
3778 let thread = cx
3779 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3780 .await
3781 .unwrap();
3782
3783 // Send first user message to create a checkpoint
3784 cx.update(|cx| {
3785 thread.update(cx, |thread, cx| {
3786 thread.send(vec!["first message".into()], cx)
3787 })
3788 })
3789 .await
3790 .unwrap();
3791
3792 // Send second message (creates another checkpoint) - we'll restore to this one
3793 cx.update(|cx| {
3794 thread.update(cx, |thread, cx| {
3795 thread.send(vec!["second message".into()], cx)
3796 })
3797 })
3798 .await
3799 .unwrap();
3800
3801 // Create 2 terminals BEFORE the checkpoint that have completed running
3802 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3803 let mock_terminal_1 = cx.new(|cx| {
3804 let builder = ::terminal::TerminalBuilder::new_display_only(
3805 ::terminal::terminal_settings::CursorShape::default(),
3806 ::terminal::terminal_settings::AlternateScroll::On,
3807 None,
3808 0,
3809 )
3810 .unwrap();
3811 builder.subscribe(cx)
3812 });
3813
3814 thread.update(cx, |thread, cx| {
3815 thread.on_terminal_provider_event(
3816 TerminalProviderEvent::Created {
3817 terminal_id: terminal_id_1.clone(),
3818 label: "echo 'first'".to_string(),
3819 cwd: Some(PathBuf::from("/test")),
3820 output_byte_limit: None,
3821 terminal: mock_terminal_1.clone(),
3822 },
3823 cx,
3824 );
3825 });
3826
3827 thread.update(cx, |thread, cx| {
3828 thread.on_terminal_provider_event(
3829 TerminalProviderEvent::Output {
3830 terminal_id: terminal_id_1.clone(),
3831 data: b"first\n".to_vec(),
3832 },
3833 cx,
3834 );
3835 });
3836
3837 thread.update(cx, |thread, cx| {
3838 thread.on_terminal_provider_event(
3839 TerminalProviderEvent::Exit {
3840 terminal_id: terminal_id_1.clone(),
3841 status: acp::TerminalExitStatus::new().exit_code(0),
3842 },
3843 cx,
3844 );
3845 });
3846
3847 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3848 let mock_terminal_2 = cx.new(|cx| {
3849 let builder = ::terminal::TerminalBuilder::new_display_only(
3850 ::terminal::terminal_settings::CursorShape::default(),
3851 ::terminal::terminal_settings::AlternateScroll::On,
3852 None,
3853 0,
3854 )
3855 .unwrap();
3856 builder.subscribe(cx)
3857 });
3858
3859 thread.update(cx, |thread, cx| {
3860 thread.on_terminal_provider_event(
3861 TerminalProviderEvent::Created {
3862 terminal_id: terminal_id_2.clone(),
3863 label: "echo 'second'".to_string(),
3864 cwd: Some(PathBuf::from("/test")),
3865 output_byte_limit: None,
3866 terminal: mock_terminal_2.clone(),
3867 },
3868 cx,
3869 );
3870 });
3871
3872 thread.update(cx, |thread, cx| {
3873 thread.on_terminal_provider_event(
3874 TerminalProviderEvent::Output {
3875 terminal_id: terminal_id_2.clone(),
3876 data: b"second\n".to_vec(),
3877 },
3878 cx,
3879 );
3880 });
3881
3882 thread.update(cx, |thread, cx| {
3883 thread.on_terminal_provider_event(
3884 TerminalProviderEvent::Exit {
3885 terminal_id: terminal_id_2.clone(),
3886 status: acp::TerminalExitStatus::new().exit_code(0),
3887 },
3888 cx,
3889 );
3890 });
3891
3892 // Get the second message ID to restore to
3893 let second_message_id = thread.read_with(cx, |thread, _| {
3894 // At this point we have:
3895 // - Index 0: First user message (with checkpoint)
3896 // - Index 1: Second user message (with checkpoint)
3897 // No assistant responses because FakeAgentConnection just returns EndTurn
3898 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3899 panic!("expected user message at index 1");
3900 };
3901 message.id.clone().unwrap()
3902 });
3903
3904 // Create a terminal AFTER the checkpoint we'll restore to.
3905 // This simulates the AI agent starting a long-running terminal command.
3906 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3907 let mock_terminal = cx.new(|cx| {
3908 let builder = ::terminal::TerminalBuilder::new_display_only(
3909 ::terminal::terminal_settings::CursorShape::default(),
3910 ::terminal::terminal_settings::AlternateScroll::On,
3911 None,
3912 0,
3913 )
3914 .unwrap();
3915 builder.subscribe(cx)
3916 });
3917
3918 // Register the terminal as created
3919 thread.update(cx, |thread, cx| {
3920 thread.on_terminal_provider_event(
3921 TerminalProviderEvent::Created {
3922 terminal_id: terminal_id.clone(),
3923 label: "sleep 1000".to_string(),
3924 cwd: Some(PathBuf::from("/test")),
3925 output_byte_limit: None,
3926 terminal: mock_terminal.clone(),
3927 },
3928 cx,
3929 );
3930 });
3931
3932 // Simulate the terminal producing output (still running)
3933 thread.update(cx, |thread, cx| {
3934 thread.on_terminal_provider_event(
3935 TerminalProviderEvent::Output {
3936 terminal_id: terminal_id.clone(),
3937 data: b"terminal is running...\n".to_vec(),
3938 },
3939 cx,
3940 );
3941 });
3942
3943 // Create a tool call entry that references this terminal
3944 // This represents the agent requesting a terminal command
3945 thread.update(cx, |thread, cx| {
3946 thread
3947 .handle_session_update(
3948 acp::SessionUpdate::ToolCall(
3949 acp::ToolCall::new("terminal-tool-1", "Running command")
3950 .kind(acp::ToolKind::Execute)
3951 .status(acp::ToolCallStatus::InProgress)
3952 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
3953 terminal_id.clone(),
3954 ))])
3955 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
3956 ),
3957 cx,
3958 )
3959 .unwrap();
3960 });
3961
3962 // Verify terminal exists and is in the thread
3963 let terminal_exists_before =
3964 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
3965 assert!(
3966 terminal_exists_before,
3967 "Terminal should exist before checkpoint restore"
3968 );
3969
3970 // Verify the terminal's underlying task is still running (not completed)
3971 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
3972 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
3973 terminal_entity.read_with(cx, |term, _cx| {
3974 term.output().is_none() // output is None means it's still running
3975 })
3976 });
3977 assert!(
3978 terminal_running_before,
3979 "Terminal should be running before checkpoint restore"
3980 );
3981
3982 // Verify we have the expected entries before restore
3983 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
3984 assert!(
3985 entry_count_before > 1,
3986 "Should have multiple entries before restore"
3987 );
3988
3989 // Restore the checkpoint to the second message.
3990 // This should:
3991 // 1. Cancel any in-progress generation (via the cancel() call)
3992 // 2. Remove the terminal that was created after that point
3993 thread
3994 .update(cx, |thread, cx| {
3995 thread.restore_checkpoint(second_message_id, cx)
3996 })
3997 .await
3998 .unwrap();
3999
4000 // Verify that no send_task is in progress after restore
4001 // (cancel() clears the send_task)
4002 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4003 assert!(
4004 !has_send_task_after,
4005 "Should not have a send_task after restore (cancel should have cleared it)"
4006 );
4007
4008 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4009 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4010 assert_eq!(
4011 entry_count, 1,
4012 "Should have 1 entry after restore (only the first user message)"
4013 );
4014
4015 // Verify the 2 completed terminals from before the checkpoint still exist
4016 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4017 thread.terminals.contains_key(&terminal_id_1)
4018 });
4019 assert!(
4020 terminal_1_exists,
4021 "Terminal 1 (from before checkpoint) should still exist"
4022 );
4023
4024 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4025 thread.terminals.contains_key(&terminal_id_2)
4026 });
4027 assert!(
4028 terminal_2_exists,
4029 "Terminal 2 (from before checkpoint) should still exist"
4030 );
4031
4032 // Verify they're still in completed state
4033 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4034 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4035 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4036 });
4037 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4038
4039 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4040 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4041 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4042 });
4043 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4044
4045 // Verify the running terminal (created after checkpoint) was removed
4046 let terminal_3_exists =
4047 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4048 assert!(
4049 !terminal_3_exists,
4050 "Terminal 3 (created after checkpoint) should have been removed"
4051 );
4052
4053 // Verify total count is 2 (the two from before the checkpoint)
4054 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4055 assert_eq!(
4056 terminal_count, 2,
4057 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4058 );
4059 }
4060}