1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15use task::{Shell, ShellBuilder};
16pub use terminal::*;
17
18use action_log::{ActionLog, ActionLogTelemetry};
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
38use uuid::Uuid;
39
40#[derive(Debug)]
41pub struct UserMessage {
42 pub id: Option<UserMessageId>,
43 pub content: ContentBlock,
44 pub chunks: Vec<acp::ContentBlock>,
45 pub checkpoint: Option<Checkpoint>,
46}
47
48#[derive(Debug)]
49pub struct Checkpoint {
50 git_checkpoint: GitStoreCheckpoint,
51 pub show: bool,
52}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(
98 chunk: &str,
99 language_registry: &Arc<LanguageRegistry>,
100 path_style: PathStyle,
101 cx: &mut App,
102 ) -> Self {
103 Self::Message {
104 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
105 }
106 }
107
108 fn to_markdown(&self, cx: &App) -> String {
109 match self {
110 Self::Message { block } => block.to_markdown(cx).to_string(),
111 Self::Thought { block } => {
112 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
113 }
114 }
115 }
116}
117
118#[derive(Debug)]
119pub enum AgentThreadEntry {
120 UserMessage(UserMessage),
121 AssistantMessage(AssistantMessage),
122 ToolCall(ToolCall),
123}
124
125impl AgentThreadEntry {
126 pub fn to_markdown(&self, cx: &App) -> String {
127 match self {
128 Self::UserMessage(message) => message.to_markdown(cx),
129 Self::AssistantMessage(message) => message.to_markdown(cx),
130 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
131 }
132 }
133
134 pub fn user_message(&self) -> Option<&UserMessage> {
135 if let AgentThreadEntry::UserMessage(message) = self {
136 Some(message)
137 } else {
138 None
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
151 if let AgentThreadEntry::ToolCall(call) = self {
152 itertools::Either::Left(call.terminals())
153 } else {
154 itertools::Either::Right(std::iter::empty())
155 }
156 }
157
158 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
159 if let AgentThreadEntry::ToolCall(ToolCall {
160 locations,
161 resolved_locations,
162 ..
163 }) = self
164 {
165 Some((
166 locations.get(ix)?.clone(),
167 resolved_locations.get(ix)?.clone()?,
168 ))
169 } else {
170 None
171 }
172 }
173}
174
175#[derive(Debug)]
176pub struct ToolCall {
177 pub id: acp::ToolCallId,
178 pub label: Entity<Markdown>,
179 pub kind: acp::ToolKind,
180 pub content: Vec<ToolCallContent>,
181 pub status: ToolCallStatus,
182 pub locations: Vec<acp::ToolCallLocation>,
183 pub resolved_locations: Vec<Option<AgentLocation>>,
184 pub raw_input: Option<serde_json::Value>,
185 pub raw_output: Option<serde_json::Value>,
186}
187
188impl ToolCall {
189 fn from_acp(
190 tool_call: acp::ToolCall,
191 status: ToolCallStatus,
192 language_registry: Arc<LanguageRegistry>,
193 path_style: PathStyle,
194 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
195 cx: &mut App,
196 ) -> Result<Self> {
197 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
198 first_line.to_owned() + "…"
199 } else {
200 tool_call.title
201 };
202 let mut content = Vec::with_capacity(tool_call.content.len());
203 for item in tool_call.content {
204 if let Some(item) = ToolCallContent::from_acp(
205 item,
206 language_registry.clone(),
207 path_style,
208 terminals,
209 cx,
210 )? {
211 content.push(item);
212 }
213 }
214
215 let result = Self {
216 id: tool_call.tool_call_id,
217 label: cx
218 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
219 kind: tool_call.kind,
220 content,
221 locations: tool_call.locations,
222 resolved_locations: Vec::default(),
223 status,
224 raw_input: tool_call.raw_input,
225 raw_output: tool_call.raw_output,
226 };
227 Ok(result)
228 }
229
230 fn update_fields(
231 &mut self,
232 fields: acp::ToolCallUpdateFields,
233 language_registry: Arc<LanguageRegistry>,
234 path_style: PathStyle,
235 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
236 cx: &mut App,
237 ) -> Result<()> {
238 let acp::ToolCallUpdateFields {
239 kind,
240 status,
241 title,
242 content,
243 locations,
244 raw_input,
245 raw_output,
246 ..
247 } = fields;
248
249 if let Some(kind) = kind {
250 self.kind = kind;
251 }
252
253 if let Some(status) = status {
254 self.status = status.into();
255 }
256
257 if let Some(title) = title {
258 self.label.update(cx, |label, cx| {
259 if let Some((first_line, _)) = title.split_once("\n") {
260 label.replace(first_line.to_owned() + "…", cx)
261 } else {
262 label.replace(title, cx);
263 }
264 });
265 }
266
267 if let Some(content) = content {
268 let mut new_content_len = content.len();
269 let mut content = content.into_iter();
270
271 // Reuse existing content if we can
272 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
273 let valid_content =
274 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
275 if !valid_content {
276 new_content_len -= 1;
277 }
278 }
279 for new in content {
280 if let Some(new) = ToolCallContent::from_acp(
281 new,
282 language_registry.clone(),
283 path_style,
284 terminals,
285 cx,
286 )? {
287 self.content.push(new);
288 } else {
289 new_content_len -= 1;
290 }
291 }
292 self.content.truncate(new_content_len);
293 }
294
295 if let Some(locations) = locations {
296 self.locations = locations;
297 }
298
299 if let Some(raw_input) = raw_input {
300 self.raw_input = Some(raw_input);
301 }
302
303 if let Some(raw_output) = raw_output {
304 if self.content.is_empty()
305 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
306 {
307 self.content
308 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
309 markdown,
310 }));
311 }
312 self.raw_output = Some(raw_output);
313 }
314 Ok(())
315 }
316
317 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
318 self.content.iter().filter_map(|content| match content {
319 ToolCallContent::Diff(diff) => Some(diff),
320 ToolCallContent::ContentBlock(_) => None,
321 ToolCallContent::Terminal(_) => None,
322 })
323 }
324
325 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
326 self.content.iter().filter_map(|content| match content {
327 ToolCallContent::Terminal(terminal) => Some(terminal),
328 ToolCallContent::ContentBlock(_) => None,
329 ToolCallContent::Diff(_) => None,
330 })
331 }
332
333 fn to_markdown(&self, cx: &App) -> String {
334 let mut markdown = format!(
335 "**Tool Call: {}**\nStatus: {}\n\n",
336 self.label.read(cx).source(),
337 self.status
338 );
339 for content in &self.content {
340 markdown.push_str(content.to_markdown(cx).as_str());
341 markdown.push_str("\n\n");
342 }
343 markdown
344 }
345
346 async fn resolve_location(
347 location: acp::ToolCallLocation,
348 project: WeakEntity<Project>,
349 cx: &mut AsyncApp,
350 ) -> Option<ResolvedLocation> {
351 let buffer = project
352 .update(cx, |project, cx| {
353 project
354 .project_path_for_absolute_path(&location.path, cx)
355 .map(|path| project.open_buffer(path, cx))
356 })
357 .ok()??;
358 let buffer = buffer.await.log_err()?;
359 let position = buffer
360 .update(cx, |buffer, _| {
361 let snapshot = buffer.snapshot();
362 if let Some(row) = location.line {
363 let column = snapshot.indent_size_for_line(row).len;
364 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
365 snapshot.anchor_before(point)
366 } else {
367 Anchor::min_for_buffer(snapshot.remote_id())
368 }
369 })
370 .ok()?;
371
372 Some(ResolvedLocation { buffer, position })
373 }
374
375 fn resolve_locations(
376 &self,
377 project: Entity<Project>,
378 cx: &mut App,
379 ) -> Task<Vec<Option<ResolvedLocation>>> {
380 let locations = self.locations.clone();
381 project.update(cx, |_, cx| {
382 cx.spawn(async move |project, cx| {
383 let mut new_locations = Vec::new();
384 for location in locations {
385 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
386 }
387 new_locations
388 })
389 })
390 }
391}
392
393// Separate so we can hold a strong reference to the buffer
394// for saving on the thread
395#[derive(Clone, Debug, PartialEq, Eq)]
396struct ResolvedLocation {
397 buffer: Entity<Buffer>,
398 position: Anchor,
399}
400
401impl From<&ResolvedLocation> for AgentLocation {
402 fn from(value: &ResolvedLocation) -> Self {
403 Self {
404 buffer: value.buffer.downgrade(),
405 position: value.position,
406 }
407 }
408}
409
410#[derive(Debug)]
411pub enum ToolCallStatus {
412 /// The tool call hasn't started running yet, but we start showing it to
413 /// the user.
414 Pending,
415 /// The tool call is waiting for confirmation from the user.
416 WaitingForConfirmation {
417 options: Vec<acp::PermissionOption>,
418 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
419 },
420 /// The tool call is currently running.
421 InProgress,
422 /// The tool call completed successfully.
423 Completed,
424 /// The tool call failed.
425 Failed,
426 /// The user rejected the tool call.
427 Rejected,
428 /// The user canceled generation so the tool call was canceled.
429 Canceled,
430}
431
432impl From<acp::ToolCallStatus> for ToolCallStatus {
433 fn from(status: acp::ToolCallStatus) -> Self {
434 match status {
435 acp::ToolCallStatus::Pending => Self::Pending,
436 acp::ToolCallStatus::InProgress => Self::InProgress,
437 acp::ToolCallStatus::Completed => Self::Completed,
438 acp::ToolCallStatus::Failed => Self::Failed,
439 _ => Self::Pending,
440 }
441 }
442}
443
444impl Display for ToolCallStatus {
445 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
446 write!(
447 f,
448 "{}",
449 match self {
450 ToolCallStatus::Pending => "Pending",
451 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
452 ToolCallStatus::InProgress => "In Progress",
453 ToolCallStatus::Completed => "Completed",
454 ToolCallStatus::Failed => "Failed",
455 ToolCallStatus::Rejected => "Rejected",
456 ToolCallStatus::Canceled => "Canceled",
457 }
458 )
459 }
460}
461
462#[derive(Debug, PartialEq, Clone)]
463pub enum ContentBlock {
464 Empty,
465 Markdown { markdown: Entity<Markdown> },
466 ResourceLink { resource_link: acp::ResourceLink },
467}
468
469impl ContentBlock {
470 pub fn new(
471 block: acp::ContentBlock,
472 language_registry: &Arc<LanguageRegistry>,
473 path_style: PathStyle,
474 cx: &mut App,
475 ) -> Self {
476 let mut this = Self::Empty;
477 this.append(block, language_registry, path_style, cx);
478 this
479 }
480
481 pub fn new_combined(
482 blocks: impl IntoIterator<Item = acp::ContentBlock>,
483 language_registry: Arc<LanguageRegistry>,
484 path_style: PathStyle,
485 cx: &mut App,
486 ) -> Self {
487 let mut this = Self::Empty;
488 for block in blocks {
489 this.append(block, &language_registry, path_style, cx);
490 }
491 this
492 }
493
494 pub fn append(
495 &mut self,
496 block: acp::ContentBlock,
497 language_registry: &Arc<LanguageRegistry>,
498 path_style: PathStyle,
499 cx: &mut App,
500 ) {
501 if matches!(self, ContentBlock::Empty)
502 && let acp::ContentBlock::ResourceLink(resource_link) = block
503 {
504 *self = ContentBlock::ResourceLink { resource_link };
505 return;
506 }
507
508 let new_content = self.block_string_contents(block, path_style);
509
510 match self {
511 ContentBlock::Empty => {
512 *self = Self::create_markdown_block(new_content, language_registry, cx);
513 }
514 ContentBlock::Markdown { markdown } => {
515 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
516 }
517 ContentBlock::ResourceLink { resource_link } => {
518 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
519 let combined = format!("{}\n{}", existing_content, new_content);
520
521 *self = Self::create_markdown_block(combined, language_registry, cx);
522 }
523 }
524 }
525
526 fn create_markdown_block(
527 content: String,
528 language_registry: &Arc<LanguageRegistry>,
529 cx: &mut App,
530 ) -> ContentBlock {
531 ContentBlock::Markdown {
532 markdown: cx
533 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
534 }
535 }
536
537 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
538 match block {
539 acp::ContentBlock::Text(text_content) => text_content.text,
540 acp::ContentBlock::ResourceLink(resource_link) => {
541 Self::resource_link_md(&resource_link.uri, path_style)
542 }
543 acp::ContentBlock::Resource(acp::EmbeddedResource {
544 resource:
545 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
546 uri,
547 ..
548 }),
549 ..
550 }) => Self::resource_link_md(&uri, path_style),
551 acp::ContentBlock::Image(image) => Self::image_md(&image),
552 _ => String::new(),
553 }
554 }
555
556 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
557 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
558 uri.as_link().to_string()
559 } else {
560 uri.to_string()
561 }
562 }
563
564 fn image_md(_image: &acp::ImageContent) -> String {
565 "`Image`".into()
566 }
567
568 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
569 match self {
570 ContentBlock::Empty => "",
571 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
572 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
573 }
574 }
575
576 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
577 match self {
578 ContentBlock::Empty => None,
579 ContentBlock::Markdown { markdown } => Some(markdown),
580 ContentBlock::ResourceLink { .. } => None,
581 }
582 }
583
584 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
585 match self {
586 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
587 _ => None,
588 }
589 }
590}
591
592#[derive(Debug)]
593pub enum ToolCallContent {
594 ContentBlock(ContentBlock),
595 Diff(Entity<Diff>),
596 Terminal(Entity<Terminal>),
597}
598
599impl ToolCallContent {
600 pub fn from_acp(
601 content: acp::ToolCallContent,
602 language_registry: Arc<LanguageRegistry>,
603 path_style: PathStyle,
604 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
605 cx: &mut App,
606 ) -> Result<Option<Self>> {
607 match content {
608 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
609 Ok(Some(Self::ContentBlock(ContentBlock::new(
610 content,
611 &language_registry,
612 path_style,
613 cx,
614 ))))
615 }
616 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
617 Diff::finalized(
618 diff.path.to_string_lossy().into_owned(),
619 diff.old_text,
620 diff.new_text,
621 language_registry,
622 cx,
623 )
624 })))),
625 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
626 .get(&terminal_id)
627 .cloned()
628 .map(|terminal| Some(Self::Terminal(terminal)))
629 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
630 _ => Ok(None),
631 }
632 }
633
634 pub fn update_from_acp(
635 &mut self,
636 new: acp::ToolCallContent,
637 language_registry: Arc<LanguageRegistry>,
638 path_style: PathStyle,
639 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
640 cx: &mut App,
641 ) -> Result<bool> {
642 let needs_update = match (&self, &new) {
643 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
644 old_diff.read(cx).needs_update(
645 new_diff.old_text.as_deref().unwrap_or(""),
646 &new_diff.new_text,
647 cx,
648 )
649 }
650 _ => true,
651 };
652
653 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
654 if needs_update {
655 *self = update;
656 }
657 Ok(true)
658 } else {
659 Ok(false)
660 }
661 }
662
663 pub fn to_markdown(&self, cx: &App) -> String {
664 match self {
665 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
666 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
667 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
668 }
669 }
670}
671
672#[derive(Debug, PartialEq)]
673pub enum ToolCallUpdate {
674 UpdateFields(acp::ToolCallUpdate),
675 UpdateDiff(ToolCallUpdateDiff),
676 UpdateTerminal(ToolCallUpdateTerminal),
677}
678
679impl ToolCallUpdate {
680 fn id(&self) -> &acp::ToolCallId {
681 match self {
682 Self::UpdateFields(update) => &update.tool_call_id,
683 Self::UpdateDiff(diff) => &diff.id,
684 Self::UpdateTerminal(terminal) => &terminal.id,
685 }
686 }
687}
688
689impl From<acp::ToolCallUpdate> for ToolCallUpdate {
690 fn from(update: acp::ToolCallUpdate) -> Self {
691 Self::UpdateFields(update)
692 }
693}
694
695impl From<ToolCallUpdateDiff> for ToolCallUpdate {
696 fn from(diff: ToolCallUpdateDiff) -> Self {
697 Self::UpdateDiff(diff)
698 }
699}
700
701#[derive(Debug, PartialEq)]
702pub struct ToolCallUpdateDiff {
703 pub id: acp::ToolCallId,
704 pub diff: Entity<Diff>,
705}
706
707impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
708 fn from(terminal: ToolCallUpdateTerminal) -> Self {
709 Self::UpdateTerminal(terminal)
710 }
711}
712
713#[derive(Debug, PartialEq)]
714pub struct ToolCallUpdateTerminal {
715 pub id: acp::ToolCallId,
716 pub terminal: Entity<Terminal>,
717}
718
719#[derive(Debug, Default)]
720pub struct Plan {
721 pub entries: Vec<PlanEntry>,
722}
723
724#[derive(Debug)]
725pub struct PlanStats<'a> {
726 pub in_progress_entry: Option<&'a PlanEntry>,
727 pub pending: u32,
728 pub completed: u32,
729}
730
731impl Plan {
732 pub fn is_empty(&self) -> bool {
733 self.entries.is_empty()
734 }
735
736 pub fn stats(&self) -> PlanStats<'_> {
737 let mut stats = PlanStats {
738 in_progress_entry: None,
739 pending: 0,
740 completed: 0,
741 };
742
743 for entry in &self.entries {
744 match &entry.status {
745 acp::PlanEntryStatus::Pending => {
746 stats.pending += 1;
747 }
748 acp::PlanEntryStatus::InProgress => {
749 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
750 }
751 acp::PlanEntryStatus::Completed => {
752 stats.completed += 1;
753 }
754 _ => {}
755 }
756 }
757
758 stats
759 }
760}
761
762#[derive(Debug)]
763pub struct PlanEntry {
764 pub content: Entity<Markdown>,
765 pub priority: acp::PlanEntryPriority,
766 pub status: acp::PlanEntryStatus,
767}
768
769impl PlanEntry {
770 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
771 Self {
772 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
773 priority: entry.priority,
774 status: entry.status,
775 }
776 }
777}
778
779#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
780pub struct TokenUsage {
781 pub max_tokens: u64,
782 pub used_tokens: u64,
783}
784
785impl TokenUsage {
786 pub fn ratio(&self) -> TokenUsageRatio {
787 #[cfg(debug_assertions)]
788 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
789 .unwrap_or("0.8".to_string())
790 .parse()
791 .unwrap();
792 #[cfg(not(debug_assertions))]
793 let warning_threshold: f32 = 0.8;
794
795 // When the maximum is unknown because there is no selected model,
796 // avoid showing the token limit warning.
797 if self.max_tokens == 0 {
798 TokenUsageRatio::Normal
799 } else if self.used_tokens >= self.max_tokens {
800 TokenUsageRatio::Exceeded
801 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
802 TokenUsageRatio::Warning
803 } else {
804 TokenUsageRatio::Normal
805 }
806 }
807}
808
809#[derive(Debug, Clone, PartialEq, Eq)]
810pub enum TokenUsageRatio {
811 Normal,
812 Warning,
813 Exceeded,
814}
815
816#[derive(Debug, Clone)]
817pub struct RetryStatus {
818 pub last_error: SharedString,
819 pub attempt: usize,
820 pub max_attempts: usize,
821 pub started_at: Instant,
822 pub duration: Duration,
823}
824
825pub struct AcpThread {
826 title: SharedString,
827 entries: Vec<AgentThreadEntry>,
828 plan: Plan,
829 project: Entity<Project>,
830 action_log: Entity<ActionLog>,
831 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
832 send_task: Option<Task<()>>,
833 connection: Rc<dyn AgentConnection>,
834 session_id: acp::SessionId,
835 token_usage: Option<TokenUsage>,
836 prompt_capabilities: acp::PromptCapabilities,
837 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
838 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
839 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
840 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
841}
842
843impl From<&AcpThread> for ActionLogTelemetry {
844 fn from(value: &AcpThread) -> Self {
845 Self {
846 agent_telemetry_id: value.connection().telemetry_id(),
847 session_id: value.session_id.0.clone(),
848 }
849 }
850}
851
852#[derive(Debug)]
853pub enum AcpThreadEvent {
854 NewEntry,
855 TitleUpdated,
856 TokenUsageUpdated,
857 EntryUpdated(usize),
858 EntriesRemoved(Range<usize>),
859 ToolAuthorizationRequired,
860 Retry(RetryStatus),
861 Stopped,
862 Error,
863 LoadError(LoadError),
864 PromptCapabilitiesUpdated,
865 Refusal,
866 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
867 ModeUpdated(acp::SessionModeId),
868}
869
870impl EventEmitter<AcpThreadEvent> for AcpThread {}
871
872#[derive(Debug, Clone)]
873pub enum TerminalProviderEvent {
874 Created {
875 terminal_id: acp::TerminalId,
876 label: String,
877 cwd: Option<PathBuf>,
878 output_byte_limit: Option<u64>,
879 terminal: Entity<::terminal::Terminal>,
880 },
881 Output {
882 terminal_id: acp::TerminalId,
883 data: Vec<u8>,
884 },
885 TitleChanged {
886 terminal_id: acp::TerminalId,
887 title: String,
888 },
889 Exit {
890 terminal_id: acp::TerminalId,
891 status: acp::TerminalExitStatus,
892 },
893}
894
895#[derive(Debug, Clone)]
896pub enum TerminalProviderCommand {
897 WriteInput {
898 terminal_id: acp::TerminalId,
899 bytes: Vec<u8>,
900 },
901 Resize {
902 terminal_id: acp::TerminalId,
903 cols: u16,
904 rows: u16,
905 },
906 Close {
907 terminal_id: acp::TerminalId,
908 },
909}
910
911impl AcpThread {
912 pub fn on_terminal_provider_event(
913 &mut self,
914 event: TerminalProviderEvent,
915 cx: &mut Context<Self>,
916 ) {
917 match event {
918 TerminalProviderEvent::Created {
919 terminal_id,
920 label,
921 cwd,
922 output_byte_limit,
923 terminal,
924 } => {
925 let entity = self.register_terminal_created(
926 terminal_id.clone(),
927 label,
928 cwd,
929 output_byte_limit,
930 terminal,
931 cx,
932 );
933
934 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
935 for data in chunks.drain(..) {
936 entity.update(cx, |term, cx| {
937 term.inner().update(cx, |inner, cx| {
938 inner.write_output(&data, cx);
939 })
940 });
941 }
942 }
943
944 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
945 entity.update(cx, |_term, cx| {
946 cx.notify();
947 });
948 }
949
950 cx.notify();
951 }
952 TerminalProviderEvent::Output { terminal_id, data } => {
953 if let Some(entity) = self.terminals.get(&terminal_id) {
954 entity.update(cx, |term, cx| {
955 term.inner().update(cx, |inner, cx| {
956 inner.write_output(&data, cx);
957 })
958 });
959 } else {
960 self.pending_terminal_output
961 .entry(terminal_id)
962 .or_default()
963 .push(data);
964 }
965 }
966 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
967 if let Some(entity) = self.terminals.get(&terminal_id) {
968 entity.update(cx, |term, cx| {
969 term.inner().update(cx, |inner, cx| {
970 inner.breadcrumb_text = title;
971 cx.emit(::terminal::Event::BreadcrumbsChanged);
972 })
973 });
974 }
975 }
976 TerminalProviderEvent::Exit {
977 terminal_id,
978 status,
979 } => {
980 if let Some(entity) = self.terminals.get(&terminal_id) {
981 entity.update(cx, |_term, cx| {
982 cx.notify();
983 });
984 } else {
985 self.pending_terminal_exit.insert(terminal_id, status);
986 }
987 }
988 }
989 }
990}
991
992#[derive(PartialEq, Eq, Debug)]
993pub enum ThreadStatus {
994 Idle,
995 Generating,
996}
997
998#[derive(Debug, Clone)]
999pub enum LoadError {
1000 Unsupported {
1001 command: SharedString,
1002 current_version: SharedString,
1003 minimum_version: SharedString,
1004 },
1005 FailedToInstall(SharedString),
1006 Exited {
1007 status: ExitStatus,
1008 },
1009 Other(SharedString),
1010}
1011
1012impl Display for LoadError {
1013 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1014 match self {
1015 LoadError::Unsupported {
1016 command: path,
1017 current_version,
1018 minimum_version,
1019 } => {
1020 write!(
1021 f,
1022 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1023 )
1024 }
1025 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1026 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1027 LoadError::Other(msg) => write!(f, "{msg}"),
1028 }
1029 }
1030}
1031
1032impl Error for LoadError {}
1033
1034impl AcpThread {
1035 pub fn new(
1036 title: impl Into<SharedString>,
1037 connection: Rc<dyn AgentConnection>,
1038 project: Entity<Project>,
1039 action_log: Entity<ActionLog>,
1040 session_id: acp::SessionId,
1041 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1042 cx: &mut Context<Self>,
1043 ) -> Self {
1044 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1045 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1046 loop {
1047 let caps = prompt_capabilities_rx.recv().await?;
1048 this.update(cx, |this, cx| {
1049 this.prompt_capabilities = caps;
1050 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1051 })?;
1052 }
1053 });
1054
1055 Self {
1056 action_log,
1057 shared_buffers: Default::default(),
1058 entries: Default::default(),
1059 plan: Default::default(),
1060 title: title.into(),
1061 project,
1062 send_task: None,
1063 connection,
1064 session_id,
1065 token_usage: None,
1066 prompt_capabilities,
1067 _observe_prompt_capabilities: task,
1068 terminals: HashMap::default(),
1069 pending_terminal_output: HashMap::default(),
1070 pending_terminal_exit: HashMap::default(),
1071 }
1072 }
1073
1074 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1075 self.prompt_capabilities.clone()
1076 }
1077
1078 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1079 &self.connection
1080 }
1081
1082 pub fn action_log(&self) -> &Entity<ActionLog> {
1083 &self.action_log
1084 }
1085
1086 pub fn project(&self) -> &Entity<Project> {
1087 &self.project
1088 }
1089
1090 pub fn title(&self) -> SharedString {
1091 self.title.clone()
1092 }
1093
1094 pub fn entries(&self) -> &[AgentThreadEntry] {
1095 &self.entries
1096 }
1097
1098 pub fn session_id(&self) -> &acp::SessionId {
1099 &self.session_id
1100 }
1101
1102 pub fn status(&self) -> ThreadStatus {
1103 if self.send_task.is_some() {
1104 ThreadStatus::Generating
1105 } else {
1106 ThreadStatus::Idle
1107 }
1108 }
1109
1110 pub fn token_usage(&self) -> Option<&TokenUsage> {
1111 self.token_usage.as_ref()
1112 }
1113
1114 pub fn has_pending_edit_tool_calls(&self) -> bool {
1115 for entry in self.entries.iter().rev() {
1116 match entry {
1117 AgentThreadEntry::UserMessage(_) => return false,
1118 AgentThreadEntry::ToolCall(
1119 call @ ToolCall {
1120 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1121 ..
1122 },
1123 ) if call.diffs().next().is_some() => {
1124 return true;
1125 }
1126 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1127 }
1128 }
1129
1130 false
1131 }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for entry in self.entries.iter().rev() {
1135 match entry {
1136 AgentThreadEntry::UserMessage(..) => return false,
1137 AgentThreadEntry::AssistantMessage(..) => continue,
1138 AgentThreadEntry::ToolCall(..) => return true,
1139 }
1140 }
1141
1142 false
1143 }
1144
1145 pub fn handle_session_update(
1146 &mut self,
1147 update: acp::SessionUpdate,
1148 cx: &mut Context<Self>,
1149 ) -> Result<(), acp::Error> {
1150 match update {
1151 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1152 self.push_user_content_block(None, content, cx);
1153 }
1154 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1155 self.push_assistant_content_block(content, false, cx);
1156 }
1157 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1158 self.push_assistant_content_block(content, true, cx);
1159 }
1160 acp::SessionUpdate::ToolCall(tool_call) => {
1161 self.upsert_tool_call(tool_call, cx)?;
1162 }
1163 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1164 self.update_tool_call(tool_call_update, cx)?;
1165 }
1166 acp::SessionUpdate::Plan(plan) => {
1167 self.update_plan(plan, cx);
1168 }
1169 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1170 available_commands,
1171 ..
1172 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1173 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1174 current_mode_id,
1175 ..
1176 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1177 _ => {}
1178 }
1179 Ok(())
1180 }
1181
1182 pub fn push_user_content_block(
1183 &mut self,
1184 message_id: Option<UserMessageId>,
1185 chunk: acp::ContentBlock,
1186 cx: &mut Context<Self>,
1187 ) {
1188 let language_registry = self.project.read(cx).languages().clone();
1189 let path_style = self.project.read(cx).path_style(cx);
1190 let entries_len = self.entries.len();
1191
1192 if let Some(last_entry) = self.entries.last_mut()
1193 && let AgentThreadEntry::UserMessage(UserMessage {
1194 id,
1195 content,
1196 chunks,
1197 ..
1198 }) = last_entry
1199 {
1200 *id = message_id.or(id.take());
1201 content.append(chunk.clone(), &language_registry, path_style, cx);
1202 chunks.push(chunk);
1203 let idx = entries_len - 1;
1204 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1205 } else {
1206 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1207 self.push_entry(
1208 AgentThreadEntry::UserMessage(UserMessage {
1209 id: message_id,
1210 content,
1211 chunks: vec![chunk],
1212 checkpoint: None,
1213 }),
1214 cx,
1215 );
1216 }
1217 }
1218
1219 pub fn push_assistant_content_block(
1220 &mut self,
1221 chunk: acp::ContentBlock,
1222 is_thought: bool,
1223 cx: &mut Context<Self>,
1224 ) {
1225 let language_registry = self.project.read(cx).languages().clone();
1226 let path_style = self.project.read(cx).path_style(cx);
1227 let entries_len = self.entries.len();
1228 if let Some(last_entry) = self.entries.last_mut()
1229 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1230 {
1231 let idx = entries_len - 1;
1232 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1233 match (chunks.last_mut(), is_thought) {
1234 (Some(AssistantMessageChunk::Message { block }), false)
1235 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1236 block.append(chunk, &language_registry, path_style, cx)
1237 }
1238 _ => {
1239 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1240 if is_thought {
1241 chunks.push(AssistantMessageChunk::Thought { block })
1242 } else {
1243 chunks.push(AssistantMessageChunk::Message { block })
1244 }
1245 }
1246 }
1247 } else {
1248 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1249 let chunk = if is_thought {
1250 AssistantMessageChunk::Thought { block }
1251 } else {
1252 AssistantMessageChunk::Message { block }
1253 };
1254
1255 self.push_entry(
1256 AgentThreadEntry::AssistantMessage(AssistantMessage {
1257 chunks: vec![chunk],
1258 }),
1259 cx,
1260 );
1261 }
1262 }
1263
1264 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1265 self.entries.push(entry);
1266 cx.emit(AcpThreadEvent::NewEntry);
1267 }
1268
1269 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1270 self.connection.set_title(&self.session_id, cx).is_some()
1271 }
1272
1273 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1274 if title != self.title {
1275 self.title = title.clone();
1276 cx.emit(AcpThreadEvent::TitleUpdated);
1277 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1278 return set_title.run(title, cx);
1279 }
1280 }
1281 Task::ready(Ok(()))
1282 }
1283
1284 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1285 self.token_usage = usage;
1286 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1287 }
1288
1289 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1290 cx.emit(AcpThreadEvent::Retry(status));
1291 }
1292
    /// Applies an update to the tool call entry whose id matches the update's id.
    ///
    /// - `UpdateFields`: merges the fields into the entry and re-resolves its
    ///   locations if the update carried new locations.
    /// - `UpdateDiff` / `UpdateTerminal`: replaces the entry's content with the single
    ///   diff or terminal.
    ///
    /// `AcpThreadEvent::EntryUpdated` is emitted for the updated entry.
    ///
    /// If no matching tool call exists, a placeholder entry with status `Failed` and the
    /// text "Tool call not found" is pushed instead, and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ToolCall::update_fields`.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                // NOTE(review): the placeholder uses `ToolKind::Fetch`; the reason for this
                // particular kind is not evident here — confirm.
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Remember this before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1353
1354 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1355 pub fn upsert_tool_call(
1356 &mut self,
1357 tool_call: acp::ToolCall,
1358 cx: &mut Context<Self>,
1359 ) -> Result<(), acp::Error> {
1360 let status = tool_call.status.into();
1361 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1362 }
1363
1364 /// Fails if id does not match an existing entry.
1365 pub fn upsert_tool_call_inner(
1366 &mut self,
1367 update: acp::ToolCallUpdate,
1368 status: ToolCallStatus,
1369 cx: &mut Context<Self>,
1370 ) -> Result<(), acp::Error> {
1371 let language_registry = self.project.read(cx).languages().clone();
1372 let path_style = self.project.read(cx).path_style(cx);
1373 let id = update.tool_call_id.clone();
1374
1375 let agent = self.connection().telemetry_id();
1376 let session = self.session_id();
1377 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1378 let status = if matches!(status, ToolCallStatus::Completed) {
1379 "completed"
1380 } else {
1381 "failed"
1382 };
1383 telemetry::event!("Agent Tool Call Completed", agent, session, status);
1384 }
1385
1386 if let Some(ix) = self.index_for_tool_call(&id) {
1387 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1388 unreachable!()
1389 };
1390
1391 call.update_fields(
1392 update.fields,
1393 language_registry,
1394 path_style,
1395 &self.terminals,
1396 cx,
1397 )?;
1398 call.status = status;
1399
1400 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1401 } else {
1402 let call = ToolCall::from_acp(
1403 update.try_into()?,
1404 status,
1405 language_registry,
1406 self.project.read(cx).path_style(cx),
1407 &self.terminals,
1408 cx,
1409 )?;
1410 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1411 };
1412
1413 self.resolve_locations(id, cx);
1414 Ok(())
1415 }
1416
1417 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1418 self.entries
1419 .iter()
1420 .enumerate()
1421 .rev()
1422 .find_map(|(index, entry)| {
1423 if let AgentThreadEntry::ToolCall(tool_call) = entry
1424 && &tool_call.id == id
1425 {
1426 Some(index)
1427 } else {
1428 None
1429 }
1430 })
1431 }
1432
1433 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1434 // The tool call we are looking for is typically the last one, or very close to the end.
1435 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1436 self.entries
1437 .iter_mut()
1438 .enumerate()
1439 .rev()
1440 .find_map(|(index, tool_call)| {
1441 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1442 && &tool_call.id == id
1443 {
1444 Some((index, tool_call))
1445 } else {
1446 None
1447 }
1448 })
1449 }
1450
1451 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1452 self.entries
1453 .iter()
1454 .enumerate()
1455 .rev()
1456 .find_map(|(index, tool_call)| {
1457 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1458 && &tool_call.id == id
1459 {
1460 Some((index, tool_call))
1461 } else {
1462 None
1463 }
1464 })
1465 }
1466
    /// Resolves, in the background, the locations reported by the tool call `id`.
    ///
    /// Does nothing if no tool call with `id` exists. When resolution finishes:
    /// - a snapshot of every resolved buffer is cached in `shared_buffers`;
    /// - if the tool call still exists, the project's agent location is moved to the
    ///   last resolved location (see the exception below);
    /// - the tool call's `resolved_locations` are updated, with
    ///   `AcpThreadEvent::EntryUpdated` emitted only when they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call may have been removed while locations were resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line: keep the
                            // current position when it is on the same row and further right
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify observers when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1523
1524 pub fn request_tool_call_authorization(
1525 &mut self,
1526 tool_call: acp::ToolCallUpdate,
1527 options: Vec<acp::PermissionOption>,
1528 respect_always_allow_setting: bool,
1529 cx: &mut Context<Self>,
1530 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1531 let (tx, rx) = oneshot::channel();
1532
1533 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1534 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1535 // some tools would (incorrectly) continue to auto-accept.
1536 if let Some(allow_once_option) = options.iter().find_map(|option| {
1537 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1538 Some(option.option_id.clone())
1539 } else {
1540 None
1541 }
1542 }) {
1543 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1544 return Ok(async {
1545 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1546 allow_once_option,
1547 ))
1548 }
1549 .boxed());
1550 }
1551 }
1552
1553 let status = ToolCallStatus::WaitingForConfirmation {
1554 options,
1555 respond_tx: tx,
1556 };
1557
1558 self.upsert_tool_call_inner(tool_call, status, cx)?;
1559 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1560
1561 let fut = async {
1562 match rx.await {
1563 Ok(option) => acp::RequestPermissionOutcome::Selected(
1564 acp::SelectedPermissionOutcome::new(option),
1565 ),
1566 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1567 }
1568 }
1569 .boxed();
1570
1571 Ok(fut)
1572 }
1573
1574 pub fn authorize_tool_call(
1575 &mut self,
1576 id: acp::ToolCallId,
1577 option_id: acp::PermissionOptionId,
1578 option_kind: acp::PermissionOptionKind,
1579 cx: &mut Context<Self>,
1580 ) {
1581 let Some((ix, call)) = self.tool_call_mut(&id) else {
1582 return;
1583 };
1584
1585 let new_status = match option_kind {
1586 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1587 ToolCallStatus::Rejected
1588 }
1589 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1590 ToolCallStatus::InProgress
1591 }
1592 _ => ToolCallStatus::InProgress,
1593 };
1594
1595 let curr_status = mem::replace(&mut call.status, new_status);
1596
1597 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1598 respond_tx.send(option_id).log_err();
1599 } else if cfg!(debug_assertions) {
1600 panic!("tried to authorize an already authorized tool call");
1601 }
1602
1603 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1604 }
1605
1606 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1607 let mut first_tool_call = None;
1608
1609 for entry in self.entries.iter().rev() {
1610 match &entry {
1611 AgentThreadEntry::ToolCall(call) => {
1612 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1613 first_tool_call = Some(call);
1614 } else {
1615 continue;
1616 }
1617 }
1618 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1619 // Reached the beginning of the turn.
1620 // If we had pending permission requests in the previous turn, they have been cancelled.
1621 break;
1622 }
1623 }
1624 }
1625
1626 first_tool_call
1627 }
1628
1629 pub fn plan(&self) -> &Plan {
1630 &self.plan
1631 }
1632
1633 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1634 let new_entries_len = request.entries.len();
1635 let mut new_entries = request.entries.into_iter();
1636
1637 // Reuse existing markdown to prevent flickering
1638 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1639 let PlanEntry {
1640 content,
1641 priority,
1642 status,
1643 } = old;
1644 content.update(cx, |old, cx| {
1645 old.replace(new.content, cx);
1646 });
1647 *priority = new.priority;
1648 *status = new.status;
1649 }
1650 for new in new_entries {
1651 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1652 }
1653 self.plan.entries.truncate(new_entries_len);
1654
1655 cx.notify();
1656 }
1657
1658 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1659 self.plan
1660 .entries
1661 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1662 cx.notify();
1663 }
1664
1665 #[cfg(any(test, feature = "test-support"))]
1666 pub fn send_raw(
1667 &mut self,
1668 message: &str,
1669 cx: &mut Context<Self>,
1670 ) -> BoxFuture<'static, Result<()>> {
1671 self.send(vec![message.into()], cx)
1672 }
1673
    /// Sends a user prompt made of the `message` content blocks as a new turn.
    ///
    /// The prompt is pushed as a user message. It gets a fresh `UserMessageId` only
    /// when the connection supports truncation, presumably so `rewind` can target it
    /// later. A git checkpoint of the project is captured and attached to the last
    /// user message; a checkpoint failure is logged and leaves it `None`. The prompt
    /// is then forwarded to the connection. The whole sequence runs through
    /// `run_turn`, which first cancels any turn still in flight.
    ///
    /// The returned future resolves once the turn finishes.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Ids are only assigned when the connection can truncate back to a message.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the working tree before the agent acts, so the turn can be restored.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1725
1726 pub fn can_resume(&self, cx: &App) -> bool {
1727 self.connection.resume(&self.session_id, cx).is_some()
1728 }
1729
1730 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1731 self.run_turn(cx, async move |this, cx| {
1732 this.update(cx, |this, cx| {
1733 this.connection
1734 .resume(&this.session_id, cx)
1735 .map(|resume| resume.run(cx))
1736 })?
1737 .context("resuming a session is not supported")?
1738 .await
1739 })
1740 }
1741
    /// Runs `f` as a new turn, after cancelling any turn still in flight.
    ///
    /// Completed plan entries are cleared first. `f` runs inside `send_task` once the
    /// previous turn's cancellation task has finished. When `f`'s result is available,
    /// or its sender was dropped, the returned future does the following:
    ///
    /// 1. Refreshes the last user message's checkpoint. An error here is returned and
    ///    no event is emitted.
    /// 2. Clears the project's agent location.
    /// 3. Reports the outcome:
    ///    - If `f` returned `Err`, `send_task` is dropped, `AcpThreadEvent::Error` is
    ///      emitted, and the error is returned.
    ///    - Otherwise refusals are handled (see below), `AcpThreadEvent::Stopped` is
    ///      emitted, and `Ok(())` is returned.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `f`'s result is handed from `send_task` to the returned future through this channel.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1837
1838 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1839 let Some(send_task) = self.send_task.take() else {
1840 return Task::ready(());
1841 };
1842
1843 for entry in self.entries.iter_mut() {
1844 if let AgentThreadEntry::ToolCall(call) = entry {
1845 let cancel = matches!(
1846 call.status,
1847 ToolCallStatus::Pending
1848 | ToolCallStatus::WaitingForConfirmation { .. }
1849 | ToolCallStatus::InProgress
1850 );
1851
1852 if cancel {
1853 call.status = ToolCallStatus::Canceled;
1854 }
1855 }
1856 }
1857
1858 self.connection.cancel(&self.session_id, cx);
1859
1860 // Wait for the send task to complete
1861 cx.foreground_executor().spawn(send_task)
1862 }
1863
1864 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1865 pub fn restore_checkpoint(
1866 &mut self,
1867 id: UserMessageId,
1868 cx: &mut Context<Self>,
1869 ) -> Task<Result<()>> {
1870 let Some((_, message)) = self.user_message_mut(&id) else {
1871 return Task::ready(Err(anyhow!("message not found")));
1872 };
1873
1874 let checkpoint = message
1875 .checkpoint
1876 .as_ref()
1877 .map(|c| c.git_checkpoint.clone());
1878
1879 // Cancel any in-progress generation before restoring
1880 let cancel_task = self.cancel(cx);
1881 let rewind = self.rewind(id.clone(), cx);
1882 let git_store = self.project.read(cx).git_store().clone();
1883
1884 cx.spawn(async move |_, cx| {
1885 cancel_task.await;
1886 rewind.await?;
1887 if let Some(checkpoint) = checkpoint {
1888 git_store
1889 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1890 .await?;
1891 }
1892
1893 Ok(())
1894 })
1895 }
1896
1897 /// Rewinds this thread to before the entry at `index`, removing it and all
1898 /// subsequent entries while rejecting any action_log changes made from that point.
1899 /// Unlike `restore_checkpoint`, this method does not restore from git.
1900 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1901 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1902 return Task::ready(Err(anyhow!("not supported")));
1903 };
1904
1905 let telemetry = ActionLogTelemetry::from(&*self);
1906 cx.spawn(async move |this, cx| {
1907 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1908 this.update(cx, |this, cx| {
1909 if let Some((ix, _)) = this.user_message_mut(&id) {
1910 // Collect all terminals from entries that will be removed
1911 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
1912 .iter()
1913 .flat_map(|entry| entry.terminals())
1914 .filter_map(|terminal| terminal.read(cx).id().clone().into())
1915 .collect();
1916
1917 let range = ix..this.entries.len();
1918 this.entries.truncate(ix);
1919 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1920
1921 // Kill and remove the terminals
1922 for terminal_id in terminals_to_remove {
1923 if let Some(terminal) = this.terminals.remove(&terminal_id) {
1924 terminal.update(cx, |terminal, cx| {
1925 terminal.kill(cx);
1926 });
1927 }
1928 }
1929 }
1930 this.action_log().update(cx, |action_log, cx| {
1931 action_log.reject_all_edits(Some(telemetry), cx)
1932 })
1933 })?
1934 .await;
1935 Ok(())
1936 })
1937 }
1938
1939 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1940 let git_store = self.project.read(cx).git_store().clone();
1941
1942 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1943 if let Some(checkpoint) = message.checkpoint.as_ref() {
1944 checkpoint.git_checkpoint.clone()
1945 } else {
1946 return Task::ready(Ok(()));
1947 }
1948 } else {
1949 return Task::ready(Ok(()));
1950 };
1951
1952 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1953 cx.spawn(async move |this, cx| {
1954 let new_checkpoint = new_checkpoint
1955 .await
1956 .context("failed to get new checkpoint")
1957 .log_err();
1958 if let Some(new_checkpoint) = new_checkpoint {
1959 let equal = git_store
1960 .update(cx, |git, cx| {
1961 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1962 })?
1963 .await
1964 .unwrap_or(true);
1965 this.update(cx, |this, cx| {
1966 let (ix, message) = this.last_user_message().context("no user message")?;
1967 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1968 checkpoint.show = !equal;
1969 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1970 anyhow::Ok(())
1971 })??;
1972 }
1973
1974 Ok(())
1975 })
1976 }
1977
1978 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1979 self.entries
1980 .iter_mut()
1981 .enumerate()
1982 .rev()
1983 .find_map(|(ix, entry)| {
1984 if let AgentThreadEntry::UserMessage(message) = entry {
1985 Some((ix, message))
1986 } else {
1987 None
1988 }
1989 })
1990 }
1991
1992 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1993 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1994 if let AgentThreadEntry::UserMessage(message) = entry {
1995 if message.id.as_ref() == Some(id) {
1996 Some((ix, message))
1997 } else {
1998 None
1999 }
2000 } else {
2001 None
2002 }
2003 })
2004 }
2005
    /// Reads text from the project file at absolute `path` on behalf of the agent.
    ///
    /// - `line`: 1-based first line to read. `None` or `0` both start at the first line.
    /// - `limit`: maximum number of lines. `None` means no limit.
    /// - `reuse_shared_snapshot`: when true, a snapshot cached in `shared_buffers` for
    ///   this buffer is used if present.
    ///
    /// When no cached snapshot is used, the buffer is reported to the action log as
    /// read, and its current snapshot is cached in `shared_buffers`. The project's
    /// agent location is moved to the start of the requested range.
    ///
    /// # Errors
    ///
    /// - `resource_not_found` if `path` is not inside the project.
    /// - `internal_error` if the project could not be accessed.
    /// - `invalid_params` if the start line is past the end of the file.
    /// - Errors from opening the buffer or updating entities are also propagated.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                // Cache what the agent saw so later edits can be diffed against it.
                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            let start = snapshot.anchor_before(start_position);
            // Saturating add keeps an unlimited `limit` (u32::MAX) from overflowing.
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2085
    /// Replaces the contents of the project file at absolute `path` with `content`
    /// on behalf of the agent, then saves it.
    ///
    /// Steps:
    /// 1. Diff `content` against the snapshot cached in `shared_buffers`, falling back
    ///    to the buffer's current snapshot. The diff runs on the background executor.
    /// 2. Move the project's agent location to the end of the last edit, or to the
    ///    start of the buffer when there are no edits.
    /// 3. Apply the edits, recording the buffer as read and then edited in the action
    ///    log.
    /// 4. If the buffer's `format_on_save` setting is not `Off`, format it with
    ///    `FormatTrigger::Save`. Formatting errors are logged, not returned, and the
    ///    formatting is recorded in the action log as another edit.
    /// 5. Save the buffer.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" when `path` is not inside the project. Also
    /// propagates errors from opening, updating, or saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read, so the diff is relative to what it saw.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged and the write still proceeds.
                format_task.await.log_err();

                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2182
    /// Spawns `command` with `args` in a new project terminal and registers it in
    /// `self.terminals` under a freshly generated UUID terminal id.
    ///
    /// The command is wrapped by `ShellBuilder` with stdin redirected to the null
    /// device. The shell used is the remote client's default when one is available,
    /// otherwise `get_default_system_shell_preferring_bash()`.
    ///
    /// The environment is assembled in this order, later entries overriding earlier
    /// ones:
    /// 1. The project's directory environment for `cwd`, when `cwd` is given.
    /// 2. `PAGER` set to an empty string.
    /// 3. Each of `extra_env`.
    ///
    /// `output_byte_limit` is forwarded to the ACP `Terminal` as a `usize`.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.environment().update(cx, |env, cx| {
                    env.directory_environment(dir.as_path().into(), cx)
                })
            }),
            None => Task::ready(None).shared(),
        };
        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Disables paging for `git` and hopefully other commands
            env.insert("PAGER".into(), "".into());
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();
        let is_windows = project.read(cx).path_style(cx).is_windows();

        let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                let shell = project
                    .update(cx, |project, cx| {
                        project
                            .remote_client()
                            .and_then(|r| r.read(cx).default_system_shell())
                    })?
                    .unwrap_or_else(|| get_default_system_shell_preferring_bash());
                let (task_command, task_args) =
                    ShellBuilder::new(&Shell::Program(shell), is_windows)
                        .redirect_stdin_to_dev_null()
                        .build(Some(command.clone()), &args);
                let terminal = project
                    .update(cx, |project, cx| {
                        project.create_terminal_task(
                            task::SpawnInTerminal {
                                command: Some(task_command),
                                args: task_args,
                                cwd: cwd.clone(),
                                env,
                                ..Default::default()
                            },
                            cx,
                        )
                    })?
                    .await?;

                cx.new(|cx| {
                    // NOTE(review): `as usize` truncates limits above usize::MAX on
                    // 32-bit targets — confirm that is acceptable.
                    Terminal::new(
                        terminal_id,
                        &format!("{} {}", command, args.join(" ")),
                        cwd,
                        output_byte_limit.map(|l| l as usize),
                        terminal,
                        language_registry,
                        cx,
                    )
                })
            }
        });

        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2267
2268 pub fn kill_terminal(
2269 &mut self,
2270 terminal_id: acp::TerminalId,
2271 cx: &mut Context<Self>,
2272 ) -> Result<()> {
2273 self.terminals
2274 .get(&terminal_id)
2275 .context("Terminal not found")?
2276 .update(cx, |terminal, cx| {
2277 terminal.kill(cx);
2278 });
2279
2280 Ok(())
2281 }
2282
2283 pub fn release_terminal(
2284 &mut self,
2285 terminal_id: acp::TerminalId,
2286 cx: &mut Context<Self>,
2287 ) -> Result<()> {
2288 self.terminals
2289 .remove(&terminal_id)
2290 .context("Terminal not found")?
2291 .update(cx, |terminal, cx| {
2292 terminal.kill(cx);
2293 });
2294
2295 Ok(())
2296 }
2297
2298 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2299 self.terminals
2300 .get(&terminal_id)
2301 .context("Terminal not found")
2302 .cloned()
2303 }
2304
2305 pub fn to_markdown(&self, cx: &App) -> String {
2306 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2307 }
2308
2309 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2310 cx.emit(AcpThreadEvent::LoadError(error));
2311 }
2312
2313 pub fn register_terminal_created(
2314 &mut self,
2315 terminal_id: acp::TerminalId,
2316 command_label: String,
2317 working_dir: Option<PathBuf>,
2318 output_byte_limit: Option<u64>,
2319 terminal: Entity<::terminal::Terminal>,
2320 cx: &mut Context<Self>,
2321 ) -> Entity<Terminal> {
2322 let language_registry = self.project.read(cx).languages().clone();
2323
2324 let entity = cx.new(|cx| {
2325 Terminal::new(
2326 terminal_id.clone(),
2327 &command_label,
2328 working_dir.clone(),
2329 output_byte_limit.map(|l| l as usize),
2330 terminal,
2331 language_registry,
2332 cx,
2333 )
2334 });
2335 self.terminals.insert(terminal_id.clone(), entity.clone());
2336 entity
2337 }
2338}
2339
2340fn markdown_for_raw_output(
2341 raw_output: &serde_json::Value,
2342 language_registry: &Arc<LanguageRegistry>,
2343 cx: &mut App,
2344) -> Option<Entity<Markdown>> {
2345 match raw_output {
2346 serde_json::Value::Null => None,
2347 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2348 Markdown::new(
2349 value.to_string().into(),
2350 Some(language_registry.clone()),
2351 None,
2352 cx,
2353 )
2354 })),
2355 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2356 Markdown::new(
2357 value.to_string().into(),
2358 Some(language_registry.clone()),
2359 None,
2360 cx,
2361 )
2362 })),
2363 serde_json::Value::String(value) => Some(cx.new(|cx| {
2364 Markdown::new(
2365 value.clone().into(),
2366 Some(language_registry.clone()),
2367 None,
2368 cx,
2369 )
2370 })),
2371 value => Some(cx.new(|cx| {
2372 Markdown::new(
2373 format!("```json\n{}\n```", value).into(),
2374 Some(language_registry.clone()),
2375 None,
2376 cx,
2377 )
2378 })),
2379 }
2380}
2381
2382#[cfg(test)]
2383mod tests {
2384 use super::*;
2385 use anyhow::anyhow;
2386 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2387 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2388 use indoc::indoc;
2389 use project::{FakeFs, Fs};
2390 use rand::{distr, prelude::*};
2391 use serde_json::json;
2392 use settings::SettingsStore;
2393 use smol::stream::StreamExt as _;
2394 use std::{
2395 any::Any,
2396 cell::RefCell,
2397 path::Path,
2398 rc::Rc,
2399 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2400 time::Duration,
2401 };
2402 use util::path;
2403
2404 fn init_test(cx: &mut TestAppContext) {
2405 env_logger::try_init().ok();
2406 cx.update(|cx| {
2407 let settings_store = SettingsStore::test(cx);
2408 cx.set_global(settings_store);
2409 });
2410 }
2411
    // Output that arrives for a terminal id before that terminal's `Created` event must not be
    // lost. This test sends `Output` first, then `Created` with a display-only terminal, and
    // expects the earlier bytes to show up in the rendered terminal content.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2473
    // Like the test above, but both `Output` and `Exit` arrive before `Created`. The buffered
    // output must still render once the terminal is created.
    // NOTE(review): the buffered `Exit` is sent, but its effect (exit status) is not asserted.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2546
    // Exercises `push_user_content_block`:
    // - a chunk with no id starts a new user message;
    // - a chunk with an id, while the last entry is still a user message, appends to that
    //   message and sets its id;
    // - a chunk that follows an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        // Entries are now: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2614
    // Two consecutive `AgentThoughtChunk` updates must be merged into a single `<thinking>`
    // section of one assistant message, as shown by the Markdown rendering.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2675
    // The agent reads `/tmp/foo`, then signals the test through `read_file_tx`, then writes an
    // extended version of the file. Meanwhile the test inserts "zero\n" at the start of the
    // open buffer. Both the buffer and the on-disk file must end up containing the user's
    // edit and the agent's edit.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Shared and take-once, because the handler closure is `Fn` but the sender is
        // consumed by `send`.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file before simulating the user's edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2752
    // Covers the line/limit arguments of `read_text_file` on a four-line file. Judging by the
    // expectations below, the line argument is 1-based and the limit is a count of lines.
    // A start line past the end of the file yields an "Invalid params" error.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
2828
    // Same argument combinations as `test_reading_from_line`, but on an empty file. Every
    // in-range read returns "", and a start line of 5 yields an "Invalid params" error.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    // Reading a path outside the project's worktree (`/foo`, while the worktree is `/tmp`)
    // must fail with the ACP `RESOURCE_NOT_FOUND` error code.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
    }
2934
    // An in-progress tool call becomes `Canceled` when the thread is canceled. A later
    // `ToolCallUpdate` for the same id can still move it to `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3024
    // An edit tool call that is reported directly as `Completed` (with a diff) must not be
    // counted by `has_pending_edit_tool_calls` once the turn has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3069
    // Checkpoint behavior across several turns.
    // - While `simulate_changes` is set, the fake agent creates a new file for each prompt.
    //   The affected user messages render as "## User (checkpoint)".
    // - Once no files change, the user message renders without the "(checkpoint)" suffix.
    // - Restoring the checkpoint of the second user message truncates the thread back to the
    //   first turn and removes the files created after it.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Create `/test/file-N` so the turn has a change to checkpoint.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The agent replies by echoing the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3247
    // The agent produces a completed tool call and then returns `StopReason::Refusal`. The
    // thread must emit `AcpThreadEvent::Refusal` but keep the user message and the tool call
    // (with its raw output) in the entries.
    #[gpui::test]
    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
        // NOTE(review): redundant with the module-level `AtomicUsize` import; harmless.
        use std::sync::atomic::AtomicUsize;
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None, cx).await;

        // Create a connection that simulates refusal after tool result
        let prompt_count = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let prompt_count = prompt_count.clone();
            move |_request, thread, mut cx| {
                let count = prompt_count.fetch_add(1, SeqCst);
                async move {
                    if count == 0 {
                        // First prompt: Generate a tool call with result
                        thread.update(&mut cx, |thread, cx| {
                            thread
                                .handle_session_update(
                                    acp::SessionUpdate::ToolCall(
                                        acp::ToolCall::new("tool1", "Test Tool")
                                            .kind(acp::ToolKind::Fetch)
                                            .status(acp::ToolCallStatus::Completed)
                                            .raw_input(serde_json::json!({"query": "test"}))
                                            .raw_output(serde_json::json!({"result": "inappropriate content"})),
                                    ),
                                    cx,
                                )
                                .unwrap();
                        })?;

                        // Now return refusal because of the tool result
                        Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
                    } else {
                        Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                    }
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Track if we see a Refusal event
        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
        let saw_refusal_event_captured = saw_refusal_event.clone();
        thread.update(cx, |_thread, cx| {
            cx.subscribe(
                &thread,
                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
                    if matches!(event, AcpThreadEvent::Refusal) {
                        *saw_refusal_event_captured.lock().unwrap() = true;
                    }
                },
            )
            .detach();
        });

        // Send a user message - this will trigger tool call and then refusal
        let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        cx.background_executor.spawn(send_task).detach();
        cx.run_until_parked();

        // Verify that:
        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
        // 2. The user message was NOT truncated
        assert!(
            *saw_refusal_event.lock().unwrap(),
            "Refusal event should be emitted for tool result refusals"
        );

        thread.read_with(cx, |thread, _| {
            let entries = thread.entries();
            assert!(entries.len() >= 2, "Should have user message and tool call");

            // Verify user message is still there
            assert!(
                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
                "User message should not be truncated"
            );

            // Verify tool call is there with result
            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
                assert!(
                    tool_call.raw_output.is_some(),
                    "Tool call should have output"
                );
            } else {
                panic!("Expected tool call at index 1");
            }
        });
    }
3344
    // When the agent refuses the very first user prompt, `AcpThreadEvent::Refusal` is
    // emitted and the refused message is removed, leaving the thread's Markdown empty.
    #[gpui::test]
    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None, cx).await;

        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |_request, _thread, _cx| {
                if refuse_next.load(SeqCst) {
                    async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
                        .boxed_local()
                } else {
                    async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
                        .boxed_local()
                }
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Track if we see a Refusal event
        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
        let saw_refusal_event_captured = saw_refusal_event.clone();
        thread.update(cx, |_thread, cx| {
            cx.subscribe(
                &thread,
                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
                    if matches!(event, AcpThreadEvent::Refusal) {
                        *saw_refusal_event_captured.lock().unwrap() = true;
                    }
                },
            )
            .detach();
        });

        // Send a message that will be refused
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();

        // Verify that a Refusal event WAS emitted for user prompt refusal
        assert!(
            *saw_refusal_event.lock().unwrap(),
            "Refusal event should be emitted for user prompt refusals"
        );

        // Verify the message was truncated (user prompt refusal)
        thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), "");
        });
    }
3403
    // A normal turn is kept. When the next user prompt is refused, that prompt is removed,
    // and the Markdown rendering goes back to showing only the first turn.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    // Otherwise echo the prompt text back in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
3485
3486 async fn run_until_first_tool_call(
3487 thread: &Entity<AcpThread>,
3488 cx: &mut TestAppContext,
3489 ) -> usize {
3490 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3491
3492 let subscription = cx.update(|cx| {
3493 cx.subscribe(thread, move |thread, _, cx| {
3494 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3495 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3496 return tx.try_send(ix).unwrap();
3497 }
3498 }
3499 })
3500 });
3501
3502 select! {
3503 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3504 panic!("Timeout waiting for tool call")
3505 }
3506 ix = rx.next().fuse() => {
3507 drop(subscription);
3508 ix.unwrap()
3509 }
3510 }
3511 }
3512
3513 #[derive(Clone, Default)]
3514 struct FakeAgentConnection {
3515 auth_methods: Vec<acp::AuthMethod>,
3516 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3517 on_user_message: Option<
3518 Rc<
3519 dyn Fn(
3520 acp::PromptRequest,
3521 WeakEntity<AcpThread>,
3522 AsyncApp,
3523 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3524 + 'static,
3525 >,
3526 >,
3527 }
3528
3529 impl FakeAgentConnection {
3530 fn new() -> Self {
3531 Self {
3532 auth_methods: Vec::new(),
3533 on_user_message: None,
3534 sessions: Arc::default(),
3535 }
3536 }
3537
3538 #[expect(unused)]
3539 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3540 self.auth_methods = auth_methods;
3541 self
3542 }
3543
3544 fn on_user_message(
3545 mut self,
3546 handler: impl Fn(
3547 acp::PromptRequest,
3548 WeakEntity<AcpThread>,
3549 AsyncApp,
3550 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3551 + 'static,
3552 ) -> Self {
3553 self.on_user_message.replace(Rc::new(handler));
3554 self
3555 }
3556 }
3557
3558 impl AgentConnection for FakeAgentConnection {
3559 fn telemetry_id(&self) -> &'static str {
3560 "fake"
3561 }
3562
3563 fn auth_methods(&self) -> &[acp::AuthMethod] {
3564 &self.auth_methods
3565 }
3566
3567 fn new_thread(
3568 self: Rc<Self>,
3569 project: Entity<Project>,
3570 _cwd: &Path,
3571 cx: &mut App,
3572 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3573 let session_id = acp::SessionId::new(
3574 rand::rng()
3575 .sample_iter(&distr::Alphanumeric)
3576 .take(7)
3577 .map(char::from)
3578 .collect::<String>(),
3579 );
3580 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3581 let thread = cx.new(|cx| {
3582 AcpThread::new(
3583 "Test",
3584 self.clone(),
3585 project,
3586 action_log,
3587 session_id.clone(),
3588 watch::Receiver::constant(
3589 acp::PromptCapabilities::new()
3590 .image(true)
3591 .audio(true)
3592 .embedded_context(true),
3593 ),
3594 cx,
3595 )
3596 });
3597 self.sessions.lock().insert(session_id, thread.downgrade());
3598 Task::ready(Ok(thread))
3599 }
3600
3601 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3602 if self.auth_methods().iter().any(|m| m.id == method) {
3603 Task::ready(Ok(()))
3604 } else {
3605 Task::ready(Err(anyhow!("Invalid Auth Method")))
3606 }
3607 }
3608
3609 fn prompt(
3610 &self,
3611 _id: Option<UserMessageId>,
3612 params: acp::PromptRequest,
3613 cx: &mut App,
3614 ) -> Task<gpui::Result<acp::PromptResponse>> {
3615 let sessions = self.sessions.lock();
3616 let thread = sessions.get(¶ms.session_id).unwrap();
3617 if let Some(handler) = &self.on_user_message {
3618 let handler = handler.clone();
3619 let thread = thread.clone();
3620 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3621 } else {
3622 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3623 }
3624 }
3625
3626 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3627 let sessions = self.sessions.lock();
3628 let thread = sessions.get(session_id).unwrap().clone();
3629
3630 cx.spawn(async move |cx| {
3631 thread
3632 .update(cx, |thread, cx| thread.cancel(cx))
3633 .unwrap()
3634 .await
3635 })
3636 .detach();
3637 }
3638
3639 fn truncate(
3640 &self,
3641 session_id: &acp::SessionId,
3642 _cx: &App,
3643 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3644 Some(Rc::new(FakeAgentSessionEditor {
3645 _session_id: session_id.clone(),
3646 }))
3647 }
3648
3649 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3650 self
3651 }
3652 }
3653
3654 struct FakeAgentSessionEditor {
3655 _session_id: acp::SessionId,
3656 }
3657
3658 impl AgentSessionTruncate for FakeAgentSessionEditor {
3659 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3660 Task::ready(Ok(()))
3661 }
3662 }
3663
3664 #[gpui::test]
3665 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3666 init_test(cx);
3667
3668 let fs = FakeFs::new(cx.executor());
3669 let project = Project::test(fs, [], cx).await;
3670 let connection = Rc::new(FakeAgentConnection::new());
3671 let thread = cx
3672 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3673 .await
3674 .unwrap();
3675
3676 // Try to update a tool call that doesn't exist
3677 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3678 thread.update(cx, |thread, cx| {
3679 let result = thread.handle_session_update(
3680 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3681 nonexistent_id.clone(),
3682 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3683 )),
3684 cx,
3685 );
3686
3687 // The update should succeed (not return an error)
3688 assert!(result.is_ok());
3689
3690 // There should now be exactly one entry in the thread
3691 assert_eq!(thread.entries.len(), 1);
3692
3693 // The entry should be a failed tool call
3694 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3695 assert_eq!(tool_call.id, nonexistent_id);
3696 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3697 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3698
3699 // Check that the content contains the error message
3700 assert_eq!(tool_call.content.len(), 1);
3701 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3702 match content_block {
3703 ContentBlock::Markdown { markdown } => {
3704 let markdown_text = markdown.read(cx).source();
3705 assert!(markdown_text.contains("Tool call not found"));
3706 }
3707 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3708 ContentBlock::ResourceLink { .. } => {
3709 panic!("Expected markdown content, got resource link")
3710 }
3711 }
3712 } else {
3713 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3714 }
3715 } else {
3716 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3717 }
3718 });
3719 }
3720
3721 /// Tests that restoring a checkpoint properly cleans up terminals that were
3722 /// created after that checkpoint, and cancels any in-progress generation.
3723 ///
3724 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3725 /// that were started after that checkpoint should be terminated, and any in-progress
3726 /// AI generation should be canceled.
3727 #[gpui::test]
3728 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3729 init_test(cx);
3730
3731 let fs = FakeFs::new(cx.executor());
3732 let project = Project::test(fs, [], cx).await;
3733 let connection = Rc::new(FakeAgentConnection::new());
3734 let thread = cx
3735 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3736 .await
3737 .unwrap();
3738
3739 // Send first user message to create a checkpoint
3740 cx.update(|cx| {
3741 thread.update(cx, |thread, cx| {
3742 thread.send(vec!["first message".into()], cx)
3743 })
3744 })
3745 .await
3746 .unwrap();
3747
3748 // Send second message (creates another checkpoint) - we'll restore to this one
3749 cx.update(|cx| {
3750 thread.update(cx, |thread, cx| {
3751 thread.send(vec!["second message".into()], cx)
3752 })
3753 })
3754 .await
3755 .unwrap();
3756
3757 // Create 2 terminals BEFORE the checkpoint that have completed running
3758 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3759 let mock_terminal_1 = cx.new(|cx| {
3760 let builder = ::terminal::TerminalBuilder::new_display_only(
3761 ::terminal::terminal_settings::CursorShape::default(),
3762 ::terminal::terminal_settings::AlternateScroll::On,
3763 None,
3764 0,
3765 )
3766 .unwrap();
3767 builder.subscribe(cx)
3768 });
3769
3770 thread.update(cx, |thread, cx| {
3771 thread.on_terminal_provider_event(
3772 TerminalProviderEvent::Created {
3773 terminal_id: terminal_id_1.clone(),
3774 label: "echo 'first'".to_string(),
3775 cwd: Some(PathBuf::from("/test")),
3776 output_byte_limit: None,
3777 terminal: mock_terminal_1.clone(),
3778 },
3779 cx,
3780 );
3781 });
3782
3783 thread.update(cx, |thread, cx| {
3784 thread.on_terminal_provider_event(
3785 TerminalProviderEvent::Output {
3786 terminal_id: terminal_id_1.clone(),
3787 data: b"first\n".to_vec(),
3788 },
3789 cx,
3790 );
3791 });
3792
3793 thread.update(cx, |thread, cx| {
3794 thread.on_terminal_provider_event(
3795 TerminalProviderEvent::Exit {
3796 terminal_id: terminal_id_1.clone(),
3797 status: acp::TerminalExitStatus::new().exit_code(0),
3798 },
3799 cx,
3800 );
3801 });
3802
3803 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3804 let mock_terminal_2 = cx.new(|cx| {
3805 let builder = ::terminal::TerminalBuilder::new_display_only(
3806 ::terminal::terminal_settings::CursorShape::default(),
3807 ::terminal::terminal_settings::AlternateScroll::On,
3808 None,
3809 0,
3810 )
3811 .unwrap();
3812 builder.subscribe(cx)
3813 });
3814
3815 thread.update(cx, |thread, cx| {
3816 thread.on_terminal_provider_event(
3817 TerminalProviderEvent::Created {
3818 terminal_id: terminal_id_2.clone(),
3819 label: "echo 'second'".to_string(),
3820 cwd: Some(PathBuf::from("/test")),
3821 output_byte_limit: None,
3822 terminal: mock_terminal_2.clone(),
3823 },
3824 cx,
3825 );
3826 });
3827
3828 thread.update(cx, |thread, cx| {
3829 thread.on_terminal_provider_event(
3830 TerminalProviderEvent::Output {
3831 terminal_id: terminal_id_2.clone(),
3832 data: b"second\n".to_vec(),
3833 },
3834 cx,
3835 );
3836 });
3837
3838 thread.update(cx, |thread, cx| {
3839 thread.on_terminal_provider_event(
3840 TerminalProviderEvent::Exit {
3841 terminal_id: terminal_id_2.clone(),
3842 status: acp::TerminalExitStatus::new().exit_code(0),
3843 },
3844 cx,
3845 );
3846 });
3847
3848 // Get the second message ID to restore to
3849 let second_message_id = thread.read_with(cx, |thread, _| {
3850 // At this point we have:
3851 // - Index 0: First user message (with checkpoint)
3852 // - Index 1: Second user message (with checkpoint)
3853 // No assistant responses because FakeAgentConnection just returns EndTurn
3854 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3855 panic!("expected user message at index 1");
3856 };
3857 message.id.clone().unwrap()
3858 });
3859
3860 // Create a terminal AFTER the checkpoint we'll restore to.
3861 // This simulates the AI agent starting a long-running terminal command.
3862 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3863 let mock_terminal = cx.new(|cx| {
3864 let builder = ::terminal::TerminalBuilder::new_display_only(
3865 ::terminal::terminal_settings::CursorShape::default(),
3866 ::terminal::terminal_settings::AlternateScroll::On,
3867 None,
3868 0,
3869 )
3870 .unwrap();
3871 builder.subscribe(cx)
3872 });
3873
3874 // Register the terminal as created
3875 thread.update(cx, |thread, cx| {
3876 thread.on_terminal_provider_event(
3877 TerminalProviderEvent::Created {
3878 terminal_id: terminal_id.clone(),
3879 label: "sleep 1000".to_string(),
3880 cwd: Some(PathBuf::from("/test")),
3881 output_byte_limit: None,
3882 terminal: mock_terminal.clone(),
3883 },
3884 cx,
3885 );
3886 });
3887
3888 // Simulate the terminal producing output (still running)
3889 thread.update(cx, |thread, cx| {
3890 thread.on_terminal_provider_event(
3891 TerminalProviderEvent::Output {
3892 terminal_id: terminal_id.clone(),
3893 data: b"terminal is running...\n".to_vec(),
3894 },
3895 cx,
3896 );
3897 });
3898
3899 // Create a tool call entry that references this terminal
3900 // This represents the agent requesting a terminal command
3901 thread.update(cx, |thread, cx| {
3902 thread
3903 .handle_session_update(
3904 acp::SessionUpdate::ToolCall(
3905 acp::ToolCall::new("terminal-tool-1", "Running command")
3906 .kind(acp::ToolKind::Execute)
3907 .status(acp::ToolCallStatus::InProgress)
3908 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
3909 terminal_id.clone(),
3910 ))])
3911 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
3912 ),
3913 cx,
3914 )
3915 .unwrap();
3916 });
3917
3918 // Verify terminal exists and is in the thread
3919 let terminal_exists_before =
3920 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
3921 assert!(
3922 terminal_exists_before,
3923 "Terminal should exist before checkpoint restore"
3924 );
3925
3926 // Verify the terminal's underlying task is still running (not completed)
3927 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
3928 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
3929 terminal_entity.read_with(cx, |term, _cx| {
3930 term.output().is_none() // output is None means it's still running
3931 })
3932 });
3933 assert!(
3934 terminal_running_before,
3935 "Terminal should be running before checkpoint restore"
3936 );
3937
3938 // Verify we have the expected entries before restore
3939 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
3940 assert!(
3941 entry_count_before > 1,
3942 "Should have multiple entries before restore"
3943 );
3944
3945 // Restore the checkpoint to the second message.
3946 // This should:
3947 // 1. Cancel any in-progress generation (via the cancel() call)
3948 // 2. Remove the terminal that was created after that point
3949 thread
3950 .update(cx, |thread, cx| {
3951 thread.restore_checkpoint(second_message_id, cx)
3952 })
3953 .await
3954 .unwrap();
3955
3956 // Verify that no send_task is in progress after restore
3957 // (cancel() clears the send_task)
3958 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
3959 assert!(
3960 !has_send_task_after,
3961 "Should not have a send_task after restore (cancel should have cleared it)"
3962 );
3963
3964 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
3965 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
3966 assert_eq!(
3967 entry_count, 1,
3968 "Should have 1 entry after restore (only the first user message)"
3969 );
3970
3971 // Verify the 2 completed terminals from before the checkpoint still exist
3972 let terminal_1_exists = thread.read_with(cx, |thread, _| {
3973 thread.terminals.contains_key(&terminal_id_1)
3974 });
3975 assert!(
3976 terminal_1_exists,
3977 "Terminal 1 (from before checkpoint) should still exist"
3978 );
3979
3980 let terminal_2_exists = thread.read_with(cx, |thread, _| {
3981 thread.terminals.contains_key(&terminal_id_2)
3982 });
3983 assert!(
3984 terminal_2_exists,
3985 "Terminal 2 (from before checkpoint) should still exist"
3986 );
3987
3988 // Verify they're still in completed state
3989 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
3990 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
3991 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
3992 });
3993 assert!(terminal_1_completed, "Terminal 1 should still be completed");
3994
3995 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
3996 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
3997 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
3998 });
3999 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4000
4001 // Verify the running terminal (created after checkpoint) was removed
4002 let terminal_3_exists =
4003 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4004 assert!(
4005 !terminal_3_exists,
4006 "Terminal 3 (created after checkpoint) should have been removed"
4007 );
4008
4009 // Verify total count is 2 (the two from before the checkpoint)
4010 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4011 assert_eq!(
4012 terminal_count, 2,
4013 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4014 );
4015 }
4016}