1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15use task::{Shell, ShellBuilder};
16pub use terminal::*;
17
18use action_log::{ActionLog, ActionLogTelemetry};
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
38use uuid::Uuid;
39
40#[derive(Debug)]
41pub struct UserMessage {
42 pub id: Option<UserMessageId>,
43 pub content: ContentBlock,
44 pub chunks: Vec<acp::ContentBlock>,
45 pub checkpoint: Option<Checkpoint>,
46}
47
48#[derive(Debug)]
49pub struct Checkpoint {
50 git_checkpoint: GitStoreCheckpoint,
51 pub show: bool,
52}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(
98 chunk: &str,
99 language_registry: &Arc<LanguageRegistry>,
100 path_style: PathStyle,
101 cx: &mut App,
102 ) -> Self {
103 Self::Message {
104 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
105 }
106 }
107
108 fn to_markdown(&self, cx: &App) -> String {
109 match self {
110 Self::Message { block } => block.to_markdown(cx).to_string(),
111 Self::Thought { block } => {
112 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
113 }
114 }
115 }
116}
117
118#[derive(Debug)]
119pub enum AgentThreadEntry {
120 UserMessage(UserMessage),
121 AssistantMessage(AssistantMessage),
122 ToolCall(ToolCall),
123}
124
125impl AgentThreadEntry {
126 pub fn to_markdown(&self, cx: &App) -> String {
127 match self {
128 Self::UserMessage(message) => message.to_markdown(cx),
129 Self::AssistantMessage(message) => message.to_markdown(cx),
130 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
131 }
132 }
133
134 pub fn user_message(&self) -> Option<&UserMessage> {
135 if let AgentThreadEntry::UserMessage(message) = self {
136 Some(message)
137 } else {
138 None
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
151 if let AgentThreadEntry::ToolCall(call) = self {
152 itertools::Either::Left(call.terminals())
153 } else {
154 itertools::Either::Right(std::iter::empty())
155 }
156 }
157
158 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
159 if let AgentThreadEntry::ToolCall(ToolCall {
160 locations,
161 resolved_locations,
162 ..
163 }) = self
164 {
165 Some((
166 locations.get(ix)?.clone(),
167 resolved_locations.get(ix)?.clone()?,
168 ))
169 } else {
170 None
171 }
172 }
173}
174
175#[derive(Debug)]
176pub struct ToolCall {
177 pub id: acp::ToolCallId,
178 pub label: Entity<Markdown>,
179 pub kind: acp::ToolKind,
180 pub content: Vec<ToolCallContent>,
181 pub status: ToolCallStatus,
182 pub locations: Vec<acp::ToolCallLocation>,
183 pub resolved_locations: Vec<Option<AgentLocation>>,
184 pub raw_input: Option<serde_json::Value>,
185 pub raw_output: Option<serde_json::Value>,
186}
187
188impl ToolCall {
189 fn from_acp(
190 tool_call: acp::ToolCall,
191 status: ToolCallStatus,
192 language_registry: Arc<LanguageRegistry>,
193 path_style: PathStyle,
194 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
195 cx: &mut App,
196 ) -> Result<Self> {
197 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
198 first_line.to_owned() + "…"
199 } else {
200 tool_call.title
201 };
202 let mut content = Vec::with_capacity(tool_call.content.len());
203 for item in tool_call.content {
204 if let Some(item) = ToolCallContent::from_acp(
205 item,
206 language_registry.clone(),
207 path_style,
208 terminals,
209 cx,
210 )? {
211 content.push(item);
212 }
213 }
214
215 let result = Self {
216 id: tool_call.tool_call_id,
217 label: cx
218 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
219 kind: tool_call.kind,
220 content,
221 locations: tool_call.locations,
222 resolved_locations: Vec::default(),
223 status,
224 raw_input: tool_call.raw_input,
225 raw_output: tool_call.raw_output,
226 };
227 Ok(result)
228 }
229
230 fn update_fields(
231 &mut self,
232 fields: acp::ToolCallUpdateFields,
233 language_registry: Arc<LanguageRegistry>,
234 path_style: PathStyle,
235 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
236 cx: &mut App,
237 ) -> Result<()> {
238 let acp::ToolCallUpdateFields {
239 kind,
240 status,
241 title,
242 content,
243 locations,
244 raw_input,
245 raw_output,
246 ..
247 } = fields;
248
249 if let Some(kind) = kind {
250 self.kind = kind;
251 }
252
253 if let Some(status) = status {
254 self.status = status.into();
255 }
256
257 if let Some(title) = title {
258 self.label.update(cx, |label, cx| {
259 if let Some((first_line, _)) = title.split_once("\n") {
260 label.replace(first_line.to_owned() + "…", cx)
261 } else {
262 label.replace(title, cx);
263 }
264 });
265 }
266
267 if let Some(content) = content {
268 let mut new_content_len = content.len();
269 let mut content = content.into_iter();
270
271 // Reuse existing content if we can
272 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
273 let valid_content =
274 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
275 if !valid_content {
276 new_content_len -= 1;
277 }
278 }
279 for new in content {
280 if let Some(new) = ToolCallContent::from_acp(
281 new,
282 language_registry.clone(),
283 path_style,
284 terminals,
285 cx,
286 )? {
287 self.content.push(new);
288 } else {
289 new_content_len -= 1;
290 }
291 }
292 self.content.truncate(new_content_len);
293 }
294
295 if let Some(locations) = locations {
296 self.locations = locations;
297 }
298
299 if let Some(raw_input) = raw_input {
300 self.raw_input = Some(raw_input);
301 }
302
303 if let Some(raw_output) = raw_output {
304 if self.content.is_empty()
305 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
306 {
307 self.content
308 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
309 markdown,
310 }));
311 }
312 self.raw_output = Some(raw_output);
313 }
314 Ok(())
315 }
316
317 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
318 self.content.iter().filter_map(|content| match content {
319 ToolCallContent::Diff(diff) => Some(diff),
320 ToolCallContent::ContentBlock(_) => None,
321 ToolCallContent::Terminal(_) => None,
322 })
323 }
324
325 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
326 self.content.iter().filter_map(|content| match content {
327 ToolCallContent::Terminal(terminal) => Some(terminal),
328 ToolCallContent::ContentBlock(_) => None,
329 ToolCallContent::Diff(_) => None,
330 })
331 }
332
333 fn to_markdown(&self, cx: &App) -> String {
334 let mut markdown = format!(
335 "**Tool Call: {}**\nStatus: {}\n\n",
336 self.label.read(cx).source(),
337 self.status
338 );
339 for content in &self.content {
340 markdown.push_str(content.to_markdown(cx).as_str());
341 markdown.push_str("\n\n");
342 }
343 markdown
344 }
345
346 async fn resolve_location(
347 location: acp::ToolCallLocation,
348 project: WeakEntity<Project>,
349 cx: &mut AsyncApp,
350 ) -> Option<ResolvedLocation> {
351 let buffer = project
352 .update(cx, |project, cx| {
353 project
354 .project_path_for_absolute_path(&location.path, cx)
355 .map(|path| project.open_buffer(path, cx))
356 })
357 .ok()??;
358 let buffer = buffer.await.log_err()?;
359 let position = buffer
360 .update(cx, |buffer, _| {
361 let snapshot = buffer.snapshot();
362 if let Some(row) = location.line {
363 let column = snapshot.indent_size_for_line(row).len;
364 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
365 snapshot.anchor_before(point)
366 } else {
367 Anchor::min_for_buffer(snapshot.remote_id())
368 }
369 })
370 .ok()?;
371
372 Some(ResolvedLocation { buffer, position })
373 }
374
375 fn resolve_locations(
376 &self,
377 project: Entity<Project>,
378 cx: &mut App,
379 ) -> Task<Vec<Option<ResolvedLocation>>> {
380 let locations = self.locations.clone();
381 project.update(cx, |_, cx| {
382 cx.spawn(async move |project, cx| {
383 let mut new_locations = Vec::new();
384 for location in locations {
385 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
386 }
387 new_locations
388 })
389 })
390 }
391}
392
393// Separate so we can hold a strong reference to the buffer
394// for saving on the thread
395#[derive(Clone, Debug, PartialEq, Eq)]
396struct ResolvedLocation {
397 buffer: Entity<Buffer>,
398 position: Anchor,
399}
400
401impl From<&ResolvedLocation> for AgentLocation {
402 fn from(value: &ResolvedLocation) -> Self {
403 Self {
404 buffer: value.buffer.downgrade(),
405 position: value.position,
406 }
407 }
408}
409
410#[derive(Debug)]
411pub enum ToolCallStatus {
412 /// The tool call hasn't started running yet, but we start showing it to
413 /// the user.
414 Pending,
415 /// The tool call is waiting for confirmation from the user.
416 WaitingForConfirmation {
417 options: Vec<acp::PermissionOption>,
418 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
419 },
420 /// The tool call is currently running.
421 InProgress,
422 /// The tool call completed successfully.
423 Completed,
424 /// The tool call failed.
425 Failed,
426 /// The user rejected the tool call.
427 Rejected,
428 /// The user canceled generation so the tool call was canceled.
429 Canceled,
430}
431
432impl From<acp::ToolCallStatus> for ToolCallStatus {
433 fn from(status: acp::ToolCallStatus) -> Self {
434 match status {
435 acp::ToolCallStatus::Pending => Self::Pending,
436 acp::ToolCallStatus::InProgress => Self::InProgress,
437 acp::ToolCallStatus::Completed => Self::Completed,
438 acp::ToolCallStatus::Failed => Self::Failed,
439 _ => Self::Pending,
440 }
441 }
442}
443
444impl Display for ToolCallStatus {
445 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
446 write!(
447 f,
448 "{}",
449 match self {
450 ToolCallStatus::Pending => "Pending",
451 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
452 ToolCallStatus::InProgress => "In Progress",
453 ToolCallStatus::Completed => "Completed",
454 ToolCallStatus::Failed => "Failed",
455 ToolCallStatus::Rejected => "Rejected",
456 ToolCallStatus::Canceled => "Canceled",
457 }
458 )
459 }
460}
461
462#[derive(Debug, PartialEq, Clone)]
463pub enum ContentBlock {
464 Empty,
465 Markdown { markdown: Entity<Markdown> },
466 ResourceLink { resource_link: acp::ResourceLink },
467}
468
469impl ContentBlock {
470 pub fn new(
471 block: acp::ContentBlock,
472 language_registry: &Arc<LanguageRegistry>,
473 path_style: PathStyle,
474 cx: &mut App,
475 ) -> Self {
476 let mut this = Self::Empty;
477 this.append(block, language_registry, path_style, cx);
478 this
479 }
480
481 pub fn new_combined(
482 blocks: impl IntoIterator<Item = acp::ContentBlock>,
483 language_registry: Arc<LanguageRegistry>,
484 path_style: PathStyle,
485 cx: &mut App,
486 ) -> Self {
487 let mut this = Self::Empty;
488 for block in blocks {
489 this.append(block, &language_registry, path_style, cx);
490 }
491 this
492 }
493
494 pub fn append(
495 &mut self,
496 block: acp::ContentBlock,
497 language_registry: &Arc<LanguageRegistry>,
498 path_style: PathStyle,
499 cx: &mut App,
500 ) {
501 if matches!(self, ContentBlock::Empty)
502 && let acp::ContentBlock::ResourceLink(resource_link) = block
503 {
504 *self = ContentBlock::ResourceLink { resource_link };
505 return;
506 }
507
508 let new_content = self.block_string_contents(block, path_style);
509
510 match self {
511 ContentBlock::Empty => {
512 *self = Self::create_markdown_block(new_content, language_registry, cx);
513 }
514 ContentBlock::Markdown { markdown } => {
515 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
516 }
517 ContentBlock::ResourceLink { resource_link } => {
518 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
519 let combined = format!("{}\n{}", existing_content, new_content);
520
521 *self = Self::create_markdown_block(combined, language_registry, cx);
522 }
523 }
524 }
525
526 fn create_markdown_block(
527 content: String,
528 language_registry: &Arc<LanguageRegistry>,
529 cx: &mut App,
530 ) -> ContentBlock {
531 ContentBlock::Markdown {
532 markdown: cx
533 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
534 }
535 }
536
537 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
538 match block {
539 acp::ContentBlock::Text(text_content) => text_content.text,
540 acp::ContentBlock::ResourceLink(resource_link) => {
541 Self::resource_link_md(&resource_link.uri, path_style)
542 }
543 acp::ContentBlock::Resource(acp::EmbeddedResource {
544 resource:
545 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
546 uri,
547 ..
548 }),
549 ..
550 }) => Self::resource_link_md(&uri, path_style),
551 acp::ContentBlock::Image(image) => Self::image_md(&image),
552 _ => String::new(),
553 }
554 }
555
556 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
557 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
558 uri.as_link().to_string()
559 } else {
560 uri.to_string()
561 }
562 }
563
564 fn image_md(_image: &acp::ImageContent) -> String {
565 "`Image`".into()
566 }
567
568 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
569 match self {
570 ContentBlock::Empty => "",
571 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
572 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
573 }
574 }
575
576 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
577 match self {
578 ContentBlock::Empty => None,
579 ContentBlock::Markdown { markdown } => Some(markdown),
580 ContentBlock::ResourceLink { .. } => None,
581 }
582 }
583
584 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
585 match self {
586 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
587 _ => None,
588 }
589 }
590}
591
592#[derive(Debug)]
593pub enum ToolCallContent {
594 ContentBlock(ContentBlock),
595 Diff(Entity<Diff>),
596 Terminal(Entity<Terminal>),
597}
598
599impl ToolCallContent {
600 pub fn from_acp(
601 content: acp::ToolCallContent,
602 language_registry: Arc<LanguageRegistry>,
603 path_style: PathStyle,
604 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
605 cx: &mut App,
606 ) -> Result<Option<Self>> {
607 match content {
608 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
609 Ok(Some(Self::ContentBlock(ContentBlock::new(
610 content,
611 &language_registry,
612 path_style,
613 cx,
614 ))))
615 }
616 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
617 Diff::finalized(
618 diff.path.to_string_lossy().into_owned(),
619 diff.old_text,
620 diff.new_text,
621 language_registry,
622 cx,
623 )
624 })))),
625 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
626 .get(&terminal_id)
627 .cloned()
628 .map(|terminal| Some(Self::Terminal(terminal)))
629 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
630 _ => Ok(None),
631 }
632 }
633
634 pub fn update_from_acp(
635 &mut self,
636 new: acp::ToolCallContent,
637 language_registry: Arc<LanguageRegistry>,
638 path_style: PathStyle,
639 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
640 cx: &mut App,
641 ) -> Result<bool> {
642 let needs_update = match (&self, &new) {
643 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
644 old_diff.read(cx).needs_update(
645 new_diff.old_text.as_deref().unwrap_or(""),
646 &new_diff.new_text,
647 cx,
648 )
649 }
650 _ => true,
651 };
652
653 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
654 if needs_update {
655 *self = update;
656 }
657 Ok(true)
658 } else {
659 Ok(false)
660 }
661 }
662
663 pub fn to_markdown(&self, cx: &App) -> String {
664 match self {
665 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
666 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
667 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
668 }
669 }
670}
671
672#[derive(Debug, PartialEq)]
673pub enum ToolCallUpdate {
674 UpdateFields(acp::ToolCallUpdate),
675 UpdateDiff(ToolCallUpdateDiff),
676 UpdateTerminal(ToolCallUpdateTerminal),
677}
678
679impl ToolCallUpdate {
680 fn id(&self) -> &acp::ToolCallId {
681 match self {
682 Self::UpdateFields(update) => &update.tool_call_id,
683 Self::UpdateDiff(diff) => &diff.id,
684 Self::UpdateTerminal(terminal) => &terminal.id,
685 }
686 }
687}
688
689impl From<acp::ToolCallUpdate> for ToolCallUpdate {
690 fn from(update: acp::ToolCallUpdate) -> Self {
691 Self::UpdateFields(update)
692 }
693}
694
695impl From<ToolCallUpdateDiff> for ToolCallUpdate {
696 fn from(diff: ToolCallUpdateDiff) -> Self {
697 Self::UpdateDiff(diff)
698 }
699}
700
701#[derive(Debug, PartialEq)]
702pub struct ToolCallUpdateDiff {
703 pub id: acp::ToolCallId,
704 pub diff: Entity<Diff>,
705}
706
707impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
708 fn from(terminal: ToolCallUpdateTerminal) -> Self {
709 Self::UpdateTerminal(terminal)
710 }
711}
712
713#[derive(Debug, PartialEq)]
714pub struct ToolCallUpdateTerminal {
715 pub id: acp::ToolCallId,
716 pub terminal: Entity<Terminal>,
717}
718
719#[derive(Debug, Default)]
720pub struct Plan {
721 pub entries: Vec<PlanEntry>,
722}
723
724#[derive(Debug)]
725pub struct PlanStats<'a> {
726 pub in_progress_entry: Option<&'a PlanEntry>,
727 pub pending: u32,
728 pub completed: u32,
729}
730
731impl Plan {
732 pub fn is_empty(&self) -> bool {
733 self.entries.is_empty()
734 }
735
736 pub fn stats(&self) -> PlanStats<'_> {
737 let mut stats = PlanStats {
738 in_progress_entry: None,
739 pending: 0,
740 completed: 0,
741 };
742
743 for entry in &self.entries {
744 match &entry.status {
745 acp::PlanEntryStatus::Pending => {
746 stats.pending += 1;
747 }
748 acp::PlanEntryStatus::InProgress => {
749 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
750 }
751 acp::PlanEntryStatus::Completed => {
752 stats.completed += 1;
753 }
754 _ => {}
755 }
756 }
757
758 stats
759 }
760}
761
762#[derive(Debug)]
763pub struct PlanEntry {
764 pub content: Entity<Markdown>,
765 pub priority: acp::PlanEntryPriority,
766 pub status: acp::PlanEntryStatus,
767}
768
769impl PlanEntry {
770 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
771 Self {
772 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
773 priority: entry.priority,
774 status: entry.status,
775 }
776 }
777}
778
779#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
780pub struct TokenUsage {
781 pub max_tokens: u64,
782 pub used_tokens: u64,
783}
784
785impl TokenUsage {
786 pub fn ratio(&self) -> TokenUsageRatio {
787 #[cfg(debug_assertions)]
788 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
789 .unwrap_or("0.8".to_string())
790 .parse()
791 .unwrap();
792 #[cfg(not(debug_assertions))]
793 let warning_threshold: f32 = 0.8;
794
795 // When the maximum is unknown because there is no selected model,
796 // avoid showing the token limit warning.
797 if self.max_tokens == 0 {
798 TokenUsageRatio::Normal
799 } else if self.used_tokens >= self.max_tokens {
800 TokenUsageRatio::Exceeded
801 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
802 TokenUsageRatio::Warning
803 } else {
804 TokenUsageRatio::Normal
805 }
806 }
807}
808
/// How close a thread is to its token limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// The limit has been reached or exceeded.
    Exceeded,
}
815
816#[derive(Debug, Clone)]
817pub struct RetryStatus {
818 pub last_error: SharedString,
819 pub attempt: usize,
820 pub max_attempts: usize,
821 pub started_at: Instant,
822 pub duration: Duration,
823}
824
825pub struct AcpThread {
826 title: SharedString,
827 entries: Vec<AgentThreadEntry>,
828 plan: Plan,
829 project: Entity<Project>,
830 action_log: Entity<ActionLog>,
831 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
832 send_task: Option<Task<()>>,
833 connection: Rc<dyn AgentConnection>,
834 session_id: acp::SessionId,
835 token_usage: Option<TokenUsage>,
836 prompt_capabilities: acp::PromptCapabilities,
837 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
838 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
839 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
840 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
841}
842
843impl From<&AcpThread> for ActionLogTelemetry {
844 fn from(value: &AcpThread) -> Self {
845 Self {
846 agent_telemetry_id: value.connection().telemetry_id(),
847 session_id: value.session_id.0.clone(),
848 }
849 }
850}
851
852#[derive(Debug)]
853pub enum AcpThreadEvent {
854 NewEntry,
855 TitleUpdated,
856 TokenUsageUpdated,
857 EntryUpdated(usize),
858 EntriesRemoved(Range<usize>),
859 ToolAuthorizationRequired,
860 Retry(RetryStatus),
861 Stopped,
862 Error,
863 LoadError(LoadError),
864 PromptCapabilitiesUpdated,
865 Refusal,
866 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
867 ModeUpdated(acp::SessionModeId),
868}
869
870impl EventEmitter<AcpThreadEvent> for AcpThread {}
871
872#[derive(Debug, Clone)]
873pub enum TerminalProviderEvent {
874 Created {
875 terminal_id: acp::TerminalId,
876 label: String,
877 cwd: Option<PathBuf>,
878 output_byte_limit: Option<u64>,
879 terminal: Entity<::terminal::Terminal>,
880 },
881 Output {
882 terminal_id: acp::TerminalId,
883 data: Vec<u8>,
884 },
885 TitleChanged {
886 terminal_id: acp::TerminalId,
887 title: String,
888 },
889 Exit {
890 terminal_id: acp::TerminalId,
891 status: acp::TerminalExitStatus,
892 },
893}
894
895#[derive(Debug, Clone)]
896pub enum TerminalProviderCommand {
897 WriteInput {
898 terminal_id: acp::TerminalId,
899 bytes: Vec<u8>,
900 },
901 Resize {
902 terminal_id: acp::TerminalId,
903 cols: u16,
904 rows: u16,
905 },
906 Close {
907 terminal_id: acp::TerminalId,
908 },
909}
910
911impl AcpThread {
912 pub fn on_terminal_provider_event(
913 &mut self,
914 event: TerminalProviderEvent,
915 cx: &mut Context<Self>,
916 ) {
917 match event {
918 TerminalProviderEvent::Created {
919 terminal_id,
920 label,
921 cwd,
922 output_byte_limit,
923 terminal,
924 } => {
925 let entity = self.register_terminal_created(
926 terminal_id.clone(),
927 label,
928 cwd,
929 output_byte_limit,
930 terminal,
931 cx,
932 );
933
934 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
935 for data in chunks.drain(..) {
936 entity.update(cx, |term, cx| {
937 term.inner().update(cx, |inner, cx| {
938 inner.write_output(&data, cx);
939 })
940 });
941 }
942 }
943
944 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
945 entity.update(cx, |_term, cx| {
946 cx.notify();
947 });
948 }
949
950 cx.notify();
951 }
952 TerminalProviderEvent::Output { terminal_id, data } => {
953 if let Some(entity) = self.terminals.get(&terminal_id) {
954 entity.update(cx, |term, cx| {
955 term.inner().update(cx, |inner, cx| {
956 inner.write_output(&data, cx);
957 })
958 });
959 } else {
960 self.pending_terminal_output
961 .entry(terminal_id)
962 .or_default()
963 .push(data);
964 }
965 }
966 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
967 if let Some(entity) = self.terminals.get(&terminal_id) {
968 entity.update(cx, |term, cx| {
969 term.inner().update(cx, |inner, cx| {
970 inner.breadcrumb_text = title;
971 cx.emit(::terminal::Event::BreadcrumbsChanged);
972 })
973 });
974 }
975 }
976 TerminalProviderEvent::Exit {
977 terminal_id,
978 status,
979 } => {
980 if let Some(entity) = self.terminals.get(&terminal_id) {
981 entity.update(cx, |_term, cx| {
982 cx.notify();
983 });
984 } else {
985 self.pending_terminal_exit.insert(terminal_id, status);
986 }
987 }
988 }
989 }
990}
991
/// Whether a thread is currently generating.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in flight.
    Generating,
}
997
998#[derive(Debug, Clone)]
999pub enum LoadError {
1000 Unsupported {
1001 command: SharedString,
1002 current_version: SharedString,
1003 minimum_version: SharedString,
1004 },
1005 FailedToInstall(SharedString),
1006 Exited {
1007 status: ExitStatus,
1008 },
1009 Other(SharedString),
1010}
1011
1012impl Display for LoadError {
1013 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1014 match self {
1015 LoadError::Unsupported {
1016 command: path,
1017 current_version,
1018 minimum_version,
1019 } => {
1020 write!(
1021 f,
1022 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1023 )
1024 }
1025 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1026 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1027 LoadError::Other(msg) => write!(f, "{msg}"),
1028 }
1029 }
1030}
1031
1032impl Error for LoadError {}
1033
1034impl AcpThread {
1035 pub fn new(
1036 title: impl Into<SharedString>,
1037 connection: Rc<dyn AgentConnection>,
1038 project: Entity<Project>,
1039 action_log: Entity<ActionLog>,
1040 session_id: acp::SessionId,
1041 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1042 cx: &mut Context<Self>,
1043 ) -> Self {
1044 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1045 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1046 loop {
1047 let caps = prompt_capabilities_rx.recv().await?;
1048 this.update(cx, |this, cx| {
1049 this.prompt_capabilities = caps;
1050 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1051 })?;
1052 }
1053 });
1054
1055 Self {
1056 action_log,
1057 shared_buffers: Default::default(),
1058 entries: Default::default(),
1059 plan: Default::default(),
1060 title: title.into(),
1061 project,
1062 send_task: None,
1063 connection,
1064 session_id,
1065 token_usage: None,
1066 prompt_capabilities,
1067 _observe_prompt_capabilities: task,
1068 terminals: HashMap::default(),
1069 pending_terminal_output: HashMap::default(),
1070 pending_terminal_exit: HashMap::default(),
1071 }
1072 }
1073
1074 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1075 self.prompt_capabilities.clone()
1076 }
1077
1078 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1079 &self.connection
1080 }
1081
1082 pub fn action_log(&self) -> &Entity<ActionLog> {
1083 &self.action_log
1084 }
1085
1086 pub fn project(&self) -> &Entity<Project> {
1087 &self.project
1088 }
1089
1090 pub fn title(&self) -> SharedString {
1091 self.title.clone()
1092 }
1093
1094 pub fn entries(&self) -> &[AgentThreadEntry] {
1095 &self.entries
1096 }
1097
1098 pub fn session_id(&self) -> &acp::SessionId {
1099 &self.session_id
1100 }
1101
1102 pub fn status(&self) -> ThreadStatus {
1103 if self.send_task.is_some() {
1104 ThreadStatus::Generating
1105 } else {
1106 ThreadStatus::Idle
1107 }
1108 }
1109
1110 pub fn token_usage(&self) -> Option<&TokenUsage> {
1111 self.token_usage.as_ref()
1112 }
1113
1114 pub fn has_pending_edit_tool_calls(&self) -> bool {
1115 for entry in self.entries.iter().rev() {
1116 match entry {
1117 AgentThreadEntry::UserMessage(_) => return false,
1118 AgentThreadEntry::ToolCall(
1119 call @ ToolCall {
1120 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1121 ..
1122 },
1123 ) if call.diffs().next().is_some() => {
1124 return true;
1125 }
1126 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1127 }
1128 }
1129
1130 false
1131 }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for entry in self.entries.iter().rev() {
1135 match entry {
1136 AgentThreadEntry::UserMessage(..) => return false,
1137 AgentThreadEntry::AssistantMessage(..) => continue,
1138 AgentThreadEntry::ToolCall(..) => return true,
1139 }
1140 }
1141
1142 false
1143 }
1144
1145 pub fn handle_session_update(
1146 &mut self,
1147 update: acp::SessionUpdate,
1148 cx: &mut Context<Self>,
1149 ) -> Result<(), acp::Error> {
1150 match update {
1151 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1152 self.push_user_content_block(None, content, cx);
1153 }
1154 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1155 self.push_assistant_content_block(content, false, cx);
1156 }
1157 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1158 self.push_assistant_content_block(content, true, cx);
1159 }
1160 acp::SessionUpdate::ToolCall(tool_call) => {
1161 self.upsert_tool_call(tool_call, cx)?;
1162 }
1163 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1164 self.update_tool_call(tool_call_update, cx)?;
1165 }
1166 acp::SessionUpdate::Plan(plan) => {
1167 self.update_plan(plan, cx);
1168 }
1169 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1170 available_commands,
1171 ..
1172 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1173 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1174 current_mode_id,
1175 ..
1176 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1177 _ => {}
1178 }
1179 Ok(())
1180 }
1181
1182 pub fn push_user_content_block(
1183 &mut self,
1184 message_id: Option<UserMessageId>,
1185 chunk: acp::ContentBlock,
1186 cx: &mut Context<Self>,
1187 ) {
1188 let language_registry = self.project.read(cx).languages().clone();
1189 let path_style = self.project.read(cx).path_style(cx);
1190 let entries_len = self.entries.len();
1191
1192 if let Some(last_entry) = self.entries.last_mut()
1193 && let AgentThreadEntry::UserMessage(UserMessage {
1194 id,
1195 content,
1196 chunks,
1197 ..
1198 }) = last_entry
1199 {
1200 *id = message_id.or(id.take());
1201 content.append(chunk.clone(), &language_registry, path_style, cx);
1202 chunks.push(chunk);
1203 let idx = entries_len - 1;
1204 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1205 } else {
1206 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1207 self.push_entry(
1208 AgentThreadEntry::UserMessage(UserMessage {
1209 id: message_id,
1210 content,
1211 chunks: vec![chunk],
1212 checkpoint: None,
1213 }),
1214 cx,
1215 );
1216 }
1217 }
1218
1219 pub fn push_assistant_content_block(
1220 &mut self,
1221 chunk: acp::ContentBlock,
1222 is_thought: bool,
1223 cx: &mut Context<Self>,
1224 ) {
1225 let language_registry = self.project.read(cx).languages().clone();
1226 let path_style = self.project.read(cx).path_style(cx);
1227 let entries_len = self.entries.len();
1228 if let Some(last_entry) = self.entries.last_mut()
1229 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1230 {
1231 let idx = entries_len - 1;
1232 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1233 match (chunks.last_mut(), is_thought) {
1234 (Some(AssistantMessageChunk::Message { block }), false)
1235 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1236 block.append(chunk, &language_registry, path_style, cx)
1237 }
1238 _ => {
1239 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1240 if is_thought {
1241 chunks.push(AssistantMessageChunk::Thought { block })
1242 } else {
1243 chunks.push(AssistantMessageChunk::Message { block })
1244 }
1245 }
1246 }
1247 } else {
1248 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1249 let chunk = if is_thought {
1250 AssistantMessageChunk::Thought { block }
1251 } else {
1252 AssistantMessageChunk::Message { block }
1253 };
1254
1255 self.push_entry(
1256 AgentThreadEntry::AssistantMessage(AssistantMessage {
1257 chunks: vec![chunk],
1258 }),
1259 cx,
1260 );
1261 }
1262 }
1263
1264 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1265 self.entries.push(entry);
1266 cx.emit(AcpThreadEvent::NewEntry);
1267 }
1268
1269 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1270 self.connection.set_title(&self.session_id, cx).is_some()
1271 }
1272
1273 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1274 if title != self.title {
1275 self.title = title.clone();
1276 cx.emit(AcpThreadEvent::TitleUpdated);
1277 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1278 return set_title.run(title, cx);
1279 }
1280 }
1281 Task::ready(Ok(()))
1282 }
1283
1284 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1285 self.token_usage = usage;
1286 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1287 }
1288
1289 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1290 cx.emit(AcpThreadEvent::Retry(status));
1291 }
1292
1293 pub fn update_tool_call(
1294 &mut self,
1295 update: impl Into<ToolCallUpdate>,
1296 cx: &mut Context<Self>,
1297 ) -> Result<()> {
1298 let update = update.into();
1299 let languages = self.project.read(cx).languages().clone();
1300 let path_style = self.project.read(cx).path_style(cx);
1301
1302 let ix = match self.index_for_tool_call(update.id()) {
1303 Some(ix) => ix,
1304 None => {
1305 // Tool call not found - create a failed tool call entry
1306 let failed_tool_call = ToolCall {
1307 id: update.id().clone(),
1308 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1309 kind: acp::ToolKind::Fetch,
1310 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1311 "Tool call not found".into(),
1312 &languages,
1313 path_style,
1314 cx,
1315 ))],
1316 status: ToolCallStatus::Failed,
1317 locations: Vec::new(),
1318 resolved_locations: Vec::new(),
1319 raw_input: None,
1320 raw_output: None,
1321 };
1322 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1323 return Ok(());
1324 }
1325 };
1326 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1327 unreachable!()
1328 };
1329
1330 match update {
1331 ToolCallUpdate::UpdateFields(update) => {
1332 let location_updated = update.fields.locations.is_some();
1333 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1334 if location_updated {
1335 self.resolve_locations(update.tool_call_id, cx);
1336 }
1337 }
1338 ToolCallUpdate::UpdateDiff(update) => {
1339 call.content.clear();
1340 call.content.push(ToolCallContent::Diff(update.diff));
1341 }
1342 ToolCallUpdate::UpdateTerminal(update) => {
1343 call.content.clear();
1344 call.content
1345 .push(ToolCallContent::Terminal(update.terminal));
1346 }
1347 }
1348
1349 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1350
1351 Ok(())
1352 }
1353
1354 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1355 pub fn upsert_tool_call(
1356 &mut self,
1357 tool_call: acp::ToolCall,
1358 cx: &mut Context<Self>,
1359 ) -> Result<(), acp::Error> {
1360 let status = tool_call.status.into();
1361 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1362 }
1363
1364 /// Fails if id does not match an existing entry.
1365 pub fn upsert_tool_call_inner(
1366 &mut self,
1367 update: acp::ToolCallUpdate,
1368 status: ToolCallStatus,
1369 cx: &mut Context<Self>,
1370 ) -> Result<(), acp::Error> {
1371 let language_registry = self.project.read(cx).languages().clone();
1372 let path_style = self.project.read(cx).path_style(cx);
1373 let id = update.tool_call_id.clone();
1374
1375 let agent_telemetry_id = self.connection().telemetry_id();
1376 let session = self.session_id();
1377 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1378 let status = if matches!(status, ToolCallStatus::Completed) {
1379 "completed"
1380 } else {
1381 "failed"
1382 };
1383 telemetry::event!(
1384 "Agent Tool Call Completed",
1385 agent_telemetry_id,
1386 session,
1387 status
1388 );
1389 }
1390
1391 if let Some(ix) = self.index_for_tool_call(&id) {
1392 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1393 unreachable!()
1394 };
1395
1396 call.update_fields(
1397 update.fields,
1398 language_registry,
1399 path_style,
1400 &self.terminals,
1401 cx,
1402 )?;
1403 call.status = status;
1404
1405 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1406 } else {
1407 let call = ToolCall::from_acp(
1408 update.try_into()?,
1409 status,
1410 language_registry,
1411 self.project.read(cx).path_style(cx),
1412 &self.terminals,
1413 cx,
1414 )?;
1415 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1416 };
1417
1418 self.resolve_locations(id, cx);
1419 Ok(())
1420 }
1421
1422 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1423 self.entries
1424 .iter()
1425 .enumerate()
1426 .rev()
1427 .find_map(|(index, entry)| {
1428 if let AgentThreadEntry::ToolCall(tool_call) = entry
1429 && &tool_call.id == id
1430 {
1431 Some(index)
1432 } else {
1433 None
1434 }
1435 })
1436 }
1437
1438 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1439 // The tool call we are looking for is typically the last one, or very close to the end.
1440 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1441 self.entries
1442 .iter_mut()
1443 .enumerate()
1444 .rev()
1445 .find_map(|(index, tool_call)| {
1446 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1447 && &tool_call.id == id
1448 {
1449 Some((index, tool_call))
1450 } else {
1451 None
1452 }
1453 })
1454 }
1455
1456 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1457 self.entries
1458 .iter()
1459 .enumerate()
1460 .rev()
1461 .find_map(|(index, tool_call)| {
1462 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1463 && &tool_call.id == id
1464 {
1465 Some((index, tool_call))
1466 } else {
1467 None
1468 }
1469 })
1470 }
1471
    /// Resolves the locations of the tool call with the given `id` in the background.
    ///
    /// Does nothing if no tool call with that id exists. Once resolution finishes:
    /// - the snapshot of every buffer a location resolved to is cached in `shared_buffers`;
    /// - the project's agent location moves to the last location, if it resolved, unless
    ///   that would only move the position backwards within the same row;
    /// - the tool call's `resolved_locations` are replaced, emitting `EntryUpdated` only
    ///   when they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache the current snapshot of every buffer a location resolved to.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call may no longer exist by the time resolution finishes.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify observers when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1528
1529 pub fn request_tool_call_authorization(
1530 &mut self,
1531 tool_call: acp::ToolCallUpdate,
1532 options: Vec<acp::PermissionOption>,
1533 respect_always_allow_setting: bool,
1534 cx: &mut Context<Self>,
1535 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1536 let (tx, rx) = oneshot::channel();
1537
1538 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1539 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1540 // some tools would (incorrectly) continue to auto-accept.
1541 if let Some(allow_once_option) = options.iter().find_map(|option| {
1542 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1543 Some(option.option_id.clone())
1544 } else {
1545 None
1546 }
1547 }) {
1548 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1549 return Ok(async {
1550 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1551 allow_once_option,
1552 ))
1553 }
1554 .boxed());
1555 }
1556 }
1557
1558 let status = ToolCallStatus::WaitingForConfirmation {
1559 options,
1560 respond_tx: tx,
1561 };
1562
1563 self.upsert_tool_call_inner(tool_call, status, cx)?;
1564 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1565
1566 let fut = async {
1567 match rx.await {
1568 Ok(option) => acp::RequestPermissionOutcome::Selected(
1569 acp::SelectedPermissionOutcome::new(option),
1570 ),
1571 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1572 }
1573 }
1574 .boxed();
1575
1576 Ok(fut)
1577 }
1578
1579 pub fn authorize_tool_call(
1580 &mut self,
1581 id: acp::ToolCallId,
1582 option_id: acp::PermissionOptionId,
1583 option_kind: acp::PermissionOptionKind,
1584 cx: &mut Context<Self>,
1585 ) {
1586 let Some((ix, call)) = self.tool_call_mut(&id) else {
1587 return;
1588 };
1589
1590 let new_status = match option_kind {
1591 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1592 ToolCallStatus::Rejected
1593 }
1594 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1595 ToolCallStatus::InProgress
1596 }
1597 _ => ToolCallStatus::InProgress,
1598 };
1599
1600 let curr_status = mem::replace(&mut call.status, new_status);
1601
1602 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1603 respond_tx.send(option_id).log_err();
1604 } else if cfg!(debug_assertions) {
1605 panic!("tried to authorize an already authorized tool call");
1606 }
1607
1608 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1609 }
1610
1611 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1612 let mut first_tool_call = None;
1613
1614 for entry in self.entries.iter().rev() {
1615 match &entry {
1616 AgentThreadEntry::ToolCall(call) => {
1617 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1618 first_tool_call = Some(call);
1619 } else {
1620 continue;
1621 }
1622 }
1623 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1624 // Reached the beginning of the turn.
1625 // If we had pending permission requests in the previous turn, they have been cancelled.
1626 break;
1627 }
1628 }
1629 }
1630
1631 first_tool_call
1632 }
1633
1634 pub fn plan(&self) -> &Plan {
1635 &self.plan
1636 }
1637
1638 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1639 let new_entries_len = request.entries.len();
1640 let mut new_entries = request.entries.into_iter();
1641
1642 // Reuse existing markdown to prevent flickering
1643 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1644 let PlanEntry {
1645 content,
1646 priority,
1647 status,
1648 } = old;
1649 content.update(cx, |old, cx| {
1650 old.replace(new.content, cx);
1651 });
1652 *priority = new.priority;
1653 *status = new.status;
1654 }
1655 for new in new_entries {
1656 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1657 }
1658 self.plan.entries.truncate(new_entries_len);
1659
1660 cx.notify();
1661 }
1662
1663 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1664 self.plan
1665 .entries
1666 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1667 cx.notify();
1668 }
1669
1670 #[cfg(any(test, feature = "test-support"))]
1671 pub fn send_raw(
1672 &mut self,
1673 message: &str,
1674 cx: &mut Context<Self>,
1675 ) -> BoxFuture<'static, Result<()>> {
1676 self.send(vec![message.into()], cx)
1677 }
1678
    /// Sends a user prompt made of the `message` blocks to the agent, starting a new turn.
    ///
    /// When the turn starts, the user message is appended to the thread. It gets a fresh
    /// `UserMessageId` only when the connection supports truncation. A git checkpoint is
    /// then taken and attached to that message with `show: false` before the prompt is
    /// forwarded to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        // Displayable content for the user message entry.
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Ids are only needed (and only assigned) when the connection can truncate, which is
        // how `rewind` targets a specific message.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the working tree before the agent acts. Failures are logged and leave
            // the message without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden for now; `update_last_checkpoint` decides whether to show it.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1730
1731 pub fn can_resume(&self, cx: &App) -> bool {
1732 self.connection.resume(&self.session_id, cx).is_some()
1733 }
1734
1735 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1736 self.run_turn(cx, async move |this, cx| {
1737 this.update(cx, |this, cx| {
1738 this.connection
1739 .resume(&this.session_id, cx)
1740 .map(|resume| resume.run(cx))
1741 })?
1742 .context("resuming a session is not supported")?
1743 .await
1744 })
1745 }
1746
    /// Runs one agent turn produced by `f`. Returns a future that resolves when the turn ends.
    ///
    /// Before `f` runs, completed plan entries are cleared and any in-flight turn is canceled
    /// and awaited.
    ///
    /// When the turn ends:
    /// - the last user message's checkpoint visibility is refreshed;
    /// - the agent location is cleared;
    /// - `Error` is emitted (and the error returned) if `f` failed;
    /// - otherwise `Stopped` is emitted, preceded by `Refusal` when the agent refused.
    ///
    /// A refusal with no completed tool output since the last user message removes that
    /// message and everything after it.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn runs inside `send_task` so that `cancel` can take and await it. It only
        // starts after the previous turn (if any) has been canceled and has finished.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the sender was dropped without sending a response.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                // The turn is over, so clear the agent's location in the project.
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1842
1843 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1844 let Some(send_task) = self.send_task.take() else {
1845 return Task::ready(());
1846 };
1847
1848 for entry in self.entries.iter_mut() {
1849 if let AgentThreadEntry::ToolCall(call) = entry {
1850 let cancel = matches!(
1851 call.status,
1852 ToolCallStatus::Pending
1853 | ToolCallStatus::WaitingForConfirmation { .. }
1854 | ToolCallStatus::InProgress
1855 );
1856
1857 if cancel {
1858 call.status = ToolCallStatus::Canceled;
1859 }
1860 }
1861 }
1862
1863 self.connection.cancel(&self.session_id, cx);
1864
1865 // Wait for the send task to complete
1866 cx.foreground_executor().spawn(send_task)
1867 }
1868
1869 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1870 pub fn restore_checkpoint(
1871 &mut self,
1872 id: UserMessageId,
1873 cx: &mut Context<Self>,
1874 ) -> Task<Result<()>> {
1875 let Some((_, message)) = self.user_message_mut(&id) else {
1876 return Task::ready(Err(anyhow!("message not found")));
1877 };
1878
1879 let checkpoint = message
1880 .checkpoint
1881 .as_ref()
1882 .map(|c| c.git_checkpoint.clone());
1883
1884 // Cancel any in-progress generation before restoring
1885 let cancel_task = self.cancel(cx);
1886 let rewind = self.rewind(id.clone(), cx);
1887 let git_store = self.project.read(cx).git_store().clone();
1888
1889 cx.spawn(async move |_, cx| {
1890 cancel_task.await;
1891 rewind.await?;
1892 if let Some(checkpoint) = checkpoint {
1893 git_store
1894 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1895 .await?;
1896 }
1897
1898 Ok(())
1899 })
1900 }
1901
1902 /// Rewinds this thread to before the entry at `index`, removing it and all
1903 /// subsequent entries while rejecting any action_log changes made from that point.
1904 /// Unlike `restore_checkpoint`, this method does not restore from git.
1905 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1906 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1907 return Task::ready(Err(anyhow!("not supported")));
1908 };
1909
1910 let telemetry = ActionLogTelemetry::from(&*self);
1911 cx.spawn(async move |this, cx| {
1912 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1913 this.update(cx, |this, cx| {
1914 if let Some((ix, _)) = this.user_message_mut(&id) {
1915 // Collect all terminals from entries that will be removed
1916 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
1917 .iter()
1918 .flat_map(|entry| entry.terminals())
1919 .filter_map(|terminal| terminal.read(cx).id().clone().into())
1920 .collect();
1921
1922 let range = ix..this.entries.len();
1923 this.entries.truncate(ix);
1924 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1925
1926 // Kill and remove the terminals
1927 for terminal_id in terminals_to_remove {
1928 if let Some(terminal) = this.terminals.remove(&terminal_id) {
1929 terminal.update(cx, |terminal, cx| {
1930 terminal.kill(cx);
1931 });
1932 }
1933 }
1934 }
1935 this.action_log().update(cx, |action_log, cx| {
1936 action_log.reject_all_edits(Some(telemetry), cx)
1937 })
1938 })?
1939 .await;
1940 Ok(())
1941 })
1942 }
1943
1944 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1945 let git_store = self.project.read(cx).git_store().clone();
1946
1947 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1948 if let Some(checkpoint) = message.checkpoint.as_ref() {
1949 checkpoint.git_checkpoint.clone()
1950 } else {
1951 return Task::ready(Ok(()));
1952 }
1953 } else {
1954 return Task::ready(Ok(()));
1955 };
1956
1957 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1958 cx.spawn(async move |this, cx| {
1959 let new_checkpoint = new_checkpoint
1960 .await
1961 .context("failed to get new checkpoint")
1962 .log_err();
1963 if let Some(new_checkpoint) = new_checkpoint {
1964 let equal = git_store
1965 .update(cx, |git, cx| {
1966 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1967 })?
1968 .await
1969 .unwrap_or(true);
1970 this.update(cx, |this, cx| {
1971 let (ix, message) = this.last_user_message().context("no user message")?;
1972 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1973 checkpoint.show = !equal;
1974 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1975 anyhow::Ok(())
1976 })??;
1977 }
1978
1979 Ok(())
1980 })
1981 }
1982
1983 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1984 self.entries
1985 .iter_mut()
1986 .enumerate()
1987 .rev()
1988 .find_map(|(ix, entry)| {
1989 if let AgentThreadEntry::UserMessage(message) = entry {
1990 Some((ix, message))
1991 } else {
1992 None
1993 }
1994 })
1995 }
1996
1997 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1998 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1999 if let AgentThreadEntry::UserMessage(message) = entry {
2000 if message.id.as_ref() == Some(id) {
2001 Some((ix, message))
2002 } else {
2003 None
2004 }
2005 } else {
2006 None
2007 }
2008 })
2009 }
2010
    /// Reads up to `limit` lines of the file at the absolute `path`, starting at 1-based `line`.
    ///
    /// Defaults: a missing `line` starts at the top of the file, and a missing `limit` uses
    /// `u32::MAX`.
    ///
    /// Snapshot source:
    /// - with `reuse_shared_snapshot`, a snapshot cached in `shared_buffers` is used when
    ///   present;
    /// - otherwise the read is recorded in the action log and a fresh snapshot is taken and
    ///   cached.
    ///
    /// The agent location is moved to the start of the read range.
    ///
    /// # Errors
    /// - `resource_not_found` if `path` does not map to a project path;
    /// - `internal_error` if the project can no longer be updated;
    /// - `invalid_params` if `line` lies past the end of the file;
    /// - errors from opening the buffer are propagated.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                // No cached snapshot: record the read, then snapshot and cache the buffer.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            let start = snapshot.anchor_before(start_position);
            // saturating_add keeps the default `u32::MAX` limit from overflowing.
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2090
    /// Replaces the contents of the file at the absolute `path` with `content` and saves it.
    ///
    /// Steps:
    /// 1. Diff `content` against the buffer's cached shared snapshot (or its current snapshot
    ///    when none is cached) using `text_diff`, on a background thread.
    /// 2. Move the agent location to the end of the last edit (or the buffer start when
    ///    there are no edits).
    /// 3. Apply the edits, recording them in the action log.
    /// 4. If the buffer's language settings enable format-on-save, format it before saving.
    ///
    /// # Errors
    /// Fails with "invalid path" if `path` is not in the project. Propagates errors from
    /// opening and saving the buffer; formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the snapshot the agent last saw, if one was cached.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Mark the buffer as read by the agent before applying its edits.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Formatting changed the buffer again; record that edit as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2187
    /// Spawns `command` with `args` in a new terminal and registers it with this thread.
    ///
    /// Shell: the remote client's default system shell when there is one, otherwise a
    /// default system shell that prefers bash. Stdin is redirected to `/dev/null`.
    ///
    /// Environment: the project's directory environment for `cwd` when `cwd` is given (empty
    /// otherwise), with `PAGER` set to an empty string and `extra_env` applied on top.
    ///
    /// `output_byte_limit` is forwarded to `Terminal::new` (presumably an output size cap;
    /// confirm in `Terminal`). On success the terminal is stored under a fresh UUID-based id.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.environment().update(cx, |env, cx| {
                    env.directory_environment(dir.as_path().into(), cx)
                })
            }),
            None => Task::ready(None).shared(),
        };
        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Disables paging for `git` and hopefully other commands
            env.insert("PAGER".into(), "".into());
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();
        let is_windows = project.read(cx).path_style(cx).is_windows();

        let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                let shell = project
                    .update(cx, |project, cx| {
                        project
                            .remote_client()
                            .and_then(|r| r.read(cx).default_system_shell())
                    })?
                    .unwrap_or_else(|| get_default_system_shell_preferring_bash());
                // Wrap the command in a shell invocation with stdin redirected to /dev/null.
                let (task_command, task_args) =
                    ShellBuilder::new(&Shell::Program(shell), is_windows)
                        .redirect_stdin_to_dev_null()
                        .build(Some(command.clone()), &args);
                let terminal = project
                    .update(cx, |project, cx| {
                        project.create_terminal_task(
                            task::SpawnInTerminal {
                                command: Some(task_command),
                                args: task_args,
                                cwd: cwd.clone(),
                                env,
                                ..Default::default()
                            },
                            cx,
                        )
                    })?
                    .await?;

                cx.new(|cx| {
                    Terminal::new(
                        terminal_id,
                        &format!("{} {}", command, args.join(" ")),
                        cwd,
                        output_byte_limit.map(|l| l as usize),
                        terminal,
                        language_registry,
                        cx,
                    )
                })
            }
        });

        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2272
2273 pub fn kill_terminal(
2274 &mut self,
2275 terminal_id: acp::TerminalId,
2276 cx: &mut Context<Self>,
2277 ) -> Result<()> {
2278 self.terminals
2279 .get(&terminal_id)
2280 .context("Terminal not found")?
2281 .update(cx, |terminal, cx| {
2282 terminal.kill(cx);
2283 });
2284
2285 Ok(())
2286 }
2287
2288 pub fn release_terminal(
2289 &mut self,
2290 terminal_id: acp::TerminalId,
2291 cx: &mut Context<Self>,
2292 ) -> Result<()> {
2293 self.terminals
2294 .remove(&terminal_id)
2295 .context("Terminal not found")?
2296 .update(cx, |terminal, cx| {
2297 terminal.kill(cx);
2298 });
2299
2300 Ok(())
2301 }
2302
2303 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2304 self.terminals
2305 .get(&terminal_id)
2306 .context("Terminal not found")
2307 .cloned()
2308 }
2309
2310 pub fn to_markdown(&self, cx: &App) -> String {
2311 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2312 }
2313
2314 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2315 cx.emit(AcpThreadEvent::LoadError(error));
2316 }
2317
2318 pub fn register_terminal_created(
2319 &mut self,
2320 terminal_id: acp::TerminalId,
2321 command_label: String,
2322 working_dir: Option<PathBuf>,
2323 output_byte_limit: Option<u64>,
2324 terminal: Entity<::terminal::Terminal>,
2325 cx: &mut Context<Self>,
2326 ) -> Entity<Terminal> {
2327 let language_registry = self.project.read(cx).languages().clone();
2328
2329 let entity = cx.new(|cx| {
2330 Terminal::new(
2331 terminal_id.clone(),
2332 &command_label,
2333 working_dir.clone(),
2334 output_byte_limit.map(|l| l as usize),
2335 terminal,
2336 language_registry,
2337 cx,
2338 )
2339 });
2340 self.terminals.insert(terminal_id.clone(), entity.clone());
2341 entity
2342 }
2343}
2344
2345fn markdown_for_raw_output(
2346 raw_output: &serde_json::Value,
2347 language_registry: &Arc<LanguageRegistry>,
2348 cx: &mut App,
2349) -> Option<Entity<Markdown>> {
2350 match raw_output {
2351 serde_json::Value::Null => None,
2352 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2353 Markdown::new(
2354 value.to_string().into(),
2355 Some(language_registry.clone()),
2356 None,
2357 cx,
2358 )
2359 })),
2360 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2361 Markdown::new(
2362 value.to_string().into(),
2363 Some(language_registry.clone()),
2364 None,
2365 cx,
2366 )
2367 })),
2368 serde_json::Value::String(value) => Some(cx.new(|cx| {
2369 Markdown::new(
2370 value.clone().into(),
2371 Some(language_registry.clone()),
2372 None,
2373 cx,
2374 )
2375 })),
2376 value => Some(cx.new(|cx| {
2377 Markdown::new(
2378 format!("```json\n{}\n```", value).into(),
2379 Some(language_registry.clone()),
2380 None,
2381 cx,
2382 )
2383 })),
2384 }
2385}
2386
2387#[cfg(test)]
2388mod tests {
2389 use super::*;
2390 use anyhow::anyhow;
2391 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2392 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2393 use indoc::indoc;
2394 use project::{FakeFs, Fs};
2395 use rand::{distr, prelude::*};
2396 use serde_json::json;
2397 use settings::SettingsStore;
2398 use smol::stream::StreamExt as _;
2399 use std::{
2400 any::Any,
2401 cell::RefCell,
2402 path::Path,
2403 rc::Rc,
2404 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2405 time::Duration,
2406 };
2407 use util::path;
2408
2409 fn init_test(cx: &mut TestAppContext) {
2410 env_logger::try_init().ok();
2411 cx.update(|cx| {
2412 let settings_store = SettingsStore::test(cx);
2413 cx.set_global(settings_store);
2414 });
2415 }
2416
    // Terminal output that arrives before the matching `Created` event must be held
    // by the thread and written into the terminal once `Created` registers it.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            // Read the text currently held by the wrapped low-level terminal.
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2478
    // Output and Exit events that both arrive before `Created` must not lose the
    // output: once the terminal is created, the buffered output has to be rendered.
    // NOTE(review): only the output is asserted; the buffered Exit status itself is
    // not checked here.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2551
    // `push_user_content_block` appends to the trailing user message (adopting the
    // id it is given), and starts a new user message once an assistant entry follows.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            // Still a single entry: the chunk was merged and the id was set on it.
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            // user, assistant, user
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2619
    // Two consecutive `AgentThoughtChunk` updates are rendered as a single
    // `<thinking>` block within one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Fake agent: answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2680
    // After the agent has read a file, the user prepends a line to the open buffer.
    // The agent's subsequent `write_text_file` must keep that user edit, both in the
    // buffer and in the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Write a version that only appends lines relative to what was read.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2757
2758 #[gpui::test]
2759 async fn test_reading_from_line(cx: &mut TestAppContext) {
2760 init_test(cx);
2761
2762 let fs = FakeFs::new(cx.executor());
2763 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2764 .await;
2765 let project = Project::test(fs.clone(), [], cx).await;
2766 project
2767 .update(cx, |project, cx| {
2768 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2769 })
2770 .await
2771 .unwrap();
2772
2773 let connection = Rc::new(FakeAgentConnection::new());
2774
2775 let thread = cx
2776 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2777 .await
2778 .unwrap();
2779
2780 // Whole file
2781 let content = thread
2782 .update(cx, |thread, cx| {
2783 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2784 })
2785 .await
2786 .unwrap();
2787
2788 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2789
2790 // Only start line
2791 let content = thread
2792 .update(cx, |thread, cx| {
2793 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2794 })
2795 .await
2796 .unwrap();
2797
2798 assert_eq!(content, "three\nfour\n");
2799
2800 // Only limit
2801 let content = thread
2802 .update(cx, |thread, cx| {
2803 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2804 })
2805 .await
2806 .unwrap();
2807
2808 assert_eq!(content, "one\ntwo\n");
2809
2810 // Range
2811 let content = thread
2812 .update(cx, |thread, cx| {
2813 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2814 })
2815 .await
2816 .unwrap();
2817
2818 assert_eq!(content, "two\nthree\n");
2819
2820 // Invalid
2821 let err = thread
2822 .update(cx, |thread, cx| {
2823 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2824 })
2825 .await
2826 .unwrap_err();
2827
2828 assert_eq!(
2829 err.to_string(),
2830 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2831 );
2832 }
2833
2834 #[gpui::test]
2835 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2836 init_test(cx);
2837
2838 let fs = FakeFs::new(cx.executor());
2839 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2840 let project = Project::test(fs.clone(), [], cx).await;
2841 project
2842 .update(cx, |project, cx| {
2843 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2844 })
2845 .await
2846 .unwrap();
2847
2848 let connection = Rc::new(FakeAgentConnection::new());
2849
2850 let thread = cx
2851 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2852 .await
2853 .unwrap();
2854
2855 // Whole file
2856 let content = thread
2857 .update(cx, |thread, cx| {
2858 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2859 })
2860 .await
2861 .unwrap();
2862
2863 assert_eq!(content, "");
2864
2865 // Only start line
2866 let content = thread
2867 .update(cx, |thread, cx| {
2868 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2869 })
2870 .await
2871 .unwrap();
2872
2873 assert_eq!(content, "");
2874
2875 // Only limit
2876 let content = thread
2877 .update(cx, |thread, cx| {
2878 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2879 })
2880 .await
2881 .unwrap();
2882
2883 assert_eq!(content, "");
2884
2885 // Range
2886 let content = thread
2887 .update(cx, |thread, cx| {
2888 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2889 })
2890 .await
2891 .unwrap();
2892
2893 assert_eq!(content, "");
2894
2895 // Invalid
2896 let err = thread
2897 .update(cx, |thread, cx| {
2898 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2899 })
2900 .await
2901 .unwrap_err();
2902
2903 assert_eq!(
2904 err.to_string(),
2905 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2906 );
2907 }
    // Reading a path that does not exist (and lies outside the `/tmp` worktree)
    // fails with `ErrorCode::ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
2939
    // A tool call left InProgress is marked Canceled by `cancel`. A later
    // `ToolCallUpdate` from the agent reporting Completed is still applied, so the
    // final status is Completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // Fake agent: reports one in-progress Fetch tool call, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[0] is the user message, entries[1] the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3029
    // An Edit tool call that arrives already Completed (with a diff) must not leave
    // the thread reporting pending edit tool calls once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3074
    // User messages are marked "(checkpoint)" only when the agent's turn changed files.
    // Restoring a checkpoint truncates the thread back to that message and restores
    // the files on disk. Runs for multiple iterations (`iterations = 10`).
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // When `simulate_changes` is set, each prompt creates a new empty file
        // `/test/file-N`, with N taken from `next_filename`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt text back in upper case as the assistant reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        // file-1 was created during the rewound turn and is gone again.
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3252
3253 #[gpui::test]
3254 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3255 use std::sync::atomic::AtomicUsize;
3256 init_test(cx);
3257
3258 let fs = FakeFs::new(cx.executor());
3259 let project = Project::test(fs, None, cx).await;
3260
3261 // Create a connection that simulates refusal after tool result
3262 let prompt_count = Arc::new(AtomicUsize::new(0));
3263 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3264 let prompt_count = prompt_count.clone();
3265 move |_request, thread, mut cx| {
3266 let count = prompt_count.fetch_add(1, SeqCst);
3267 async move {
3268 if count == 0 {
3269 // First prompt: Generate a tool call with result
3270 thread.update(&mut cx, |thread, cx| {
3271 thread
3272 .handle_session_update(
3273 acp::SessionUpdate::ToolCall(
3274 acp::ToolCall::new("tool1", "Test Tool")
3275 .kind(acp::ToolKind::Fetch)
3276 .status(acp::ToolCallStatus::Completed)
3277 .raw_input(serde_json::json!({"query": "test"}))
3278 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3279 ),
3280 cx,
3281 )
3282 .unwrap();
3283 })?;
3284
3285 // Now return refusal because of the tool result
3286 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3287 } else {
3288 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3289 }
3290 }
3291 .boxed_local()
3292 }
3293 }));
3294
3295 let thread = cx
3296 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3297 .await
3298 .unwrap();
3299
3300 // Track if we see a Refusal event
3301 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3302 let saw_refusal_event_captured = saw_refusal_event.clone();
3303 thread.update(cx, |_thread, cx| {
3304 cx.subscribe(
3305 &thread,
3306 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3307 if matches!(event, AcpThreadEvent::Refusal) {
3308 *saw_refusal_event_captured.lock().unwrap() = true;
3309 }
3310 },
3311 )
3312 .detach();
3313 });
3314
3315 // Send a user message - this will trigger tool call and then refusal
3316 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3317 cx.background_executor.spawn(send_task).detach();
3318 cx.run_until_parked();
3319
3320 // Verify that:
3321 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3322 // 2. The user message was NOT truncated
3323 assert!(
3324 *saw_refusal_event.lock().unwrap(),
3325 "Refusal event should be emitted for tool result refusals"
3326 );
3327
3328 thread.read_with(cx, |thread, _| {
3329 let entries = thread.entries();
3330 assert!(entries.len() >= 2, "Should have user message and tool call");
3331
3332 // Verify user message is still there
3333 assert!(
3334 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3335 "User message should not be truncated"
3336 );
3337
3338 // Verify tool call is there with result
3339 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3340 assert!(
3341 tool_call.raw_output.is_some(),
3342 "Tool call should have output"
3343 );
3344 } else {
3345 panic!("Expected tool call at index 1");
3346 }
3347 });
3348 }
3349
    // Refusing the user's prompt directly (no tool call involved) emits
    // `AcpThreadEvent::Refusal` and truncates the refused message, leaving the
    // thread's Markdown empty.
    #[gpui::test]
    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None, cx).await;

        // While `refuse_next` is set, every prompt is answered with StopReason::Refusal.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |_request, _thread, _cx| {
                if refuse_next.load(SeqCst) {
                    async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
                        .boxed_local()
                } else {
                    async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
                        .boxed_local()
                }
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Track if we see a Refusal event
        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
        let saw_refusal_event_captured = saw_refusal_event.clone();
        thread.update(cx, |_thread, cx| {
            cx.subscribe(
                &thread,
                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
                    if matches!(event, AcpThreadEvent::Refusal) {
                        *saw_refusal_event_captured.lock().unwrap() = true;
                    }
                },
            )
            .detach();
        });

        // Send a message that will be refused
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();

        // Verify that a Refusal event WAS emitted for user prompt refusal
        assert!(
            *saw_refusal_event.lock().unwrap(),
            "Refusal event should be emitted for user prompt refusals"
        );

        // Verify the message was truncated (user prompt refusal)
        thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), "");
        });
    }
3408
    // A successful exchange is kept. A later prompt that the agent refuses is
    // truncated from the thread, leaving only the first exchange.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // Fake agent: refuses while `refuse_next` is set; otherwise echoes the prompt
        // in upper case.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User

                hello

                ## Assistant

                HELLO

                "}
            );
        });
    }
3490
3491 async fn run_until_first_tool_call(
3492 thread: &Entity<AcpThread>,
3493 cx: &mut TestAppContext,
3494 ) -> usize {
3495 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3496
3497 let subscription = cx.update(|cx| {
3498 cx.subscribe(thread, move |thread, _, cx| {
3499 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3500 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3501 return tx.try_send(ix).unwrap();
3502 }
3503 }
3504 })
3505 });
3506
3507 select! {
3508 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3509 panic!("Timeout waiting for tool call")
3510 }
3511 ix = rx.next().fuse() => {
3512 drop(subscription);
3513 ix.unwrap()
3514 }
3515 }
3516 }
3517
3518 #[derive(Clone, Default)]
3519 struct FakeAgentConnection {
3520 auth_methods: Vec<acp::AuthMethod>,
3521 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3522 on_user_message: Option<
3523 Rc<
3524 dyn Fn(
3525 acp::PromptRequest,
3526 WeakEntity<AcpThread>,
3527 AsyncApp,
3528 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3529 + 'static,
3530 >,
3531 >,
3532 }
3533
3534 impl FakeAgentConnection {
3535 fn new() -> Self {
3536 Self {
3537 auth_methods: Vec::new(),
3538 on_user_message: None,
3539 sessions: Arc::default(),
3540 }
3541 }
3542
3543 #[expect(unused)]
3544 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3545 self.auth_methods = auth_methods;
3546 self
3547 }
3548
3549 fn on_user_message(
3550 mut self,
3551 handler: impl Fn(
3552 acp::PromptRequest,
3553 WeakEntity<AcpThread>,
3554 AsyncApp,
3555 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3556 + 'static,
3557 ) -> Self {
3558 self.on_user_message.replace(Rc::new(handler));
3559 self
3560 }
3561 }
3562
3563 impl AgentConnection for FakeAgentConnection {
3564 fn telemetry_id(&self) -> SharedString {
3565 "fake".into()
3566 }
3567
3568 fn auth_methods(&self) -> &[acp::AuthMethod] {
3569 &self.auth_methods
3570 }
3571
3572 fn new_thread(
3573 self: Rc<Self>,
3574 project: Entity<Project>,
3575 _cwd: &Path,
3576 cx: &mut App,
3577 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3578 let session_id = acp::SessionId::new(
3579 rand::rng()
3580 .sample_iter(&distr::Alphanumeric)
3581 .take(7)
3582 .map(char::from)
3583 .collect::<String>(),
3584 );
3585 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3586 let thread = cx.new(|cx| {
3587 AcpThread::new(
3588 "Test",
3589 self.clone(),
3590 project,
3591 action_log,
3592 session_id.clone(),
3593 watch::Receiver::constant(
3594 acp::PromptCapabilities::new()
3595 .image(true)
3596 .audio(true)
3597 .embedded_context(true),
3598 ),
3599 cx,
3600 )
3601 });
3602 self.sessions.lock().insert(session_id, thread.downgrade());
3603 Task::ready(Ok(thread))
3604 }
3605
3606 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3607 if self.auth_methods().iter().any(|m| m.id == method) {
3608 Task::ready(Ok(()))
3609 } else {
3610 Task::ready(Err(anyhow!("Invalid Auth Method")))
3611 }
3612 }
3613
3614 fn prompt(
3615 &self,
3616 _id: Option<UserMessageId>,
3617 params: acp::PromptRequest,
3618 cx: &mut App,
3619 ) -> Task<gpui::Result<acp::PromptResponse>> {
3620 let sessions = self.sessions.lock();
3621 let thread = sessions.get(¶ms.session_id).unwrap();
3622 if let Some(handler) = &self.on_user_message {
3623 let handler = handler.clone();
3624 let thread = thread.clone();
3625 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3626 } else {
3627 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3628 }
3629 }
3630
3631 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3632 let sessions = self.sessions.lock();
3633 let thread = sessions.get(session_id).unwrap().clone();
3634
3635 cx.spawn(async move |cx| {
3636 thread
3637 .update(cx, |thread, cx| thread.cancel(cx))
3638 .unwrap()
3639 .await
3640 })
3641 .detach();
3642 }
3643
3644 fn truncate(
3645 &self,
3646 session_id: &acp::SessionId,
3647 _cx: &App,
3648 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3649 Some(Rc::new(FakeAgentSessionEditor {
3650 _session_id: session_id.clone(),
3651 }))
3652 }
3653
3654 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3655 self
3656 }
3657 }
3658
3659 struct FakeAgentSessionEditor {
3660 _session_id: acp::SessionId,
3661 }
3662
3663 impl AgentSessionTruncate for FakeAgentSessionEditor {
3664 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3665 Task::ready(Ok(()))
3666 }
3667 }
3668
3669 #[gpui::test]
3670 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3671 init_test(cx);
3672
3673 let fs = FakeFs::new(cx.executor());
3674 let project = Project::test(fs, [], cx).await;
3675 let connection = Rc::new(FakeAgentConnection::new());
3676 let thread = cx
3677 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3678 .await
3679 .unwrap();
3680
3681 // Try to update a tool call that doesn't exist
3682 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3683 thread.update(cx, |thread, cx| {
3684 let result = thread.handle_session_update(
3685 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3686 nonexistent_id.clone(),
3687 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3688 )),
3689 cx,
3690 );
3691
3692 // The update should succeed (not return an error)
3693 assert!(result.is_ok());
3694
3695 // There should now be exactly one entry in the thread
3696 assert_eq!(thread.entries.len(), 1);
3697
3698 // The entry should be a failed tool call
3699 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3700 assert_eq!(tool_call.id, nonexistent_id);
3701 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3702 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3703
3704 // Check that the content contains the error message
3705 assert_eq!(tool_call.content.len(), 1);
3706 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3707 match content_block {
3708 ContentBlock::Markdown { markdown } => {
3709 let markdown_text = markdown.read(cx).source();
3710 assert!(markdown_text.contains("Tool call not found"));
3711 }
3712 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3713 ContentBlock::ResourceLink { .. } => {
3714 panic!("Expected markdown content, got resource link")
3715 }
3716 }
3717 } else {
3718 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3719 }
3720 } else {
3721 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3722 }
3723 });
3724 }
3725
3726 /// Tests that restoring a checkpoint properly cleans up terminals that were
3727 /// created after that checkpoint, and cancels any in-progress generation.
3728 ///
3729 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3730 /// that were started after that checkpoint should be terminated, and any in-progress
3731 /// AI generation should be canceled.
3732 #[gpui::test]
3733 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3734 init_test(cx);
3735
3736 let fs = FakeFs::new(cx.executor());
3737 let project = Project::test(fs, [], cx).await;
3738 let connection = Rc::new(FakeAgentConnection::new());
3739 let thread = cx
3740 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3741 .await
3742 .unwrap();
3743
3744 // Send first user message to create a checkpoint
3745 cx.update(|cx| {
3746 thread.update(cx, |thread, cx| {
3747 thread.send(vec!["first message".into()], cx)
3748 })
3749 })
3750 .await
3751 .unwrap();
3752
3753 // Send second message (creates another checkpoint) - we'll restore to this one
3754 cx.update(|cx| {
3755 thread.update(cx, |thread, cx| {
3756 thread.send(vec!["second message".into()], cx)
3757 })
3758 })
3759 .await
3760 .unwrap();
3761
3762 // Create 2 terminals BEFORE the checkpoint that have completed running
3763 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3764 let mock_terminal_1 = cx.new(|cx| {
3765 let builder = ::terminal::TerminalBuilder::new_display_only(
3766 ::terminal::terminal_settings::CursorShape::default(),
3767 ::terminal::terminal_settings::AlternateScroll::On,
3768 None,
3769 0,
3770 )
3771 .unwrap();
3772 builder.subscribe(cx)
3773 });
3774
3775 thread.update(cx, |thread, cx| {
3776 thread.on_terminal_provider_event(
3777 TerminalProviderEvent::Created {
3778 terminal_id: terminal_id_1.clone(),
3779 label: "echo 'first'".to_string(),
3780 cwd: Some(PathBuf::from("/test")),
3781 output_byte_limit: None,
3782 terminal: mock_terminal_1.clone(),
3783 },
3784 cx,
3785 );
3786 });
3787
3788 thread.update(cx, |thread, cx| {
3789 thread.on_terminal_provider_event(
3790 TerminalProviderEvent::Output {
3791 terminal_id: terminal_id_1.clone(),
3792 data: b"first\n".to_vec(),
3793 },
3794 cx,
3795 );
3796 });
3797
3798 thread.update(cx, |thread, cx| {
3799 thread.on_terminal_provider_event(
3800 TerminalProviderEvent::Exit {
3801 terminal_id: terminal_id_1.clone(),
3802 status: acp::TerminalExitStatus::new().exit_code(0),
3803 },
3804 cx,
3805 );
3806 });
3807
3808 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3809 let mock_terminal_2 = cx.new(|cx| {
3810 let builder = ::terminal::TerminalBuilder::new_display_only(
3811 ::terminal::terminal_settings::CursorShape::default(),
3812 ::terminal::terminal_settings::AlternateScroll::On,
3813 None,
3814 0,
3815 )
3816 .unwrap();
3817 builder.subscribe(cx)
3818 });
3819
3820 thread.update(cx, |thread, cx| {
3821 thread.on_terminal_provider_event(
3822 TerminalProviderEvent::Created {
3823 terminal_id: terminal_id_2.clone(),
3824 label: "echo 'second'".to_string(),
3825 cwd: Some(PathBuf::from("/test")),
3826 output_byte_limit: None,
3827 terminal: mock_terminal_2.clone(),
3828 },
3829 cx,
3830 );
3831 });
3832
3833 thread.update(cx, |thread, cx| {
3834 thread.on_terminal_provider_event(
3835 TerminalProviderEvent::Output {
3836 terminal_id: terminal_id_2.clone(),
3837 data: b"second\n".to_vec(),
3838 },
3839 cx,
3840 );
3841 });
3842
3843 thread.update(cx, |thread, cx| {
3844 thread.on_terminal_provider_event(
3845 TerminalProviderEvent::Exit {
3846 terminal_id: terminal_id_2.clone(),
3847 status: acp::TerminalExitStatus::new().exit_code(0),
3848 },
3849 cx,
3850 );
3851 });
3852
3853 // Get the second message ID to restore to
3854 let second_message_id = thread.read_with(cx, |thread, _| {
3855 // At this point we have:
3856 // - Index 0: First user message (with checkpoint)
3857 // - Index 1: Second user message (with checkpoint)
3858 // No assistant responses because FakeAgentConnection just returns EndTurn
3859 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3860 panic!("expected user message at index 1");
3861 };
3862 message.id.clone().unwrap()
3863 });
3864
3865 // Create a terminal AFTER the checkpoint we'll restore to.
3866 // This simulates the AI agent starting a long-running terminal command.
3867 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3868 let mock_terminal = cx.new(|cx| {
3869 let builder = ::terminal::TerminalBuilder::new_display_only(
3870 ::terminal::terminal_settings::CursorShape::default(),
3871 ::terminal::terminal_settings::AlternateScroll::On,
3872 None,
3873 0,
3874 )
3875 .unwrap();
3876 builder.subscribe(cx)
3877 });
3878
3879 // Register the terminal as created
3880 thread.update(cx, |thread, cx| {
3881 thread.on_terminal_provider_event(
3882 TerminalProviderEvent::Created {
3883 terminal_id: terminal_id.clone(),
3884 label: "sleep 1000".to_string(),
3885 cwd: Some(PathBuf::from("/test")),
3886 output_byte_limit: None,
3887 terminal: mock_terminal.clone(),
3888 },
3889 cx,
3890 );
3891 });
3892
3893 // Simulate the terminal producing output (still running)
3894 thread.update(cx, |thread, cx| {
3895 thread.on_terminal_provider_event(
3896 TerminalProviderEvent::Output {
3897 terminal_id: terminal_id.clone(),
3898 data: b"terminal is running...\n".to_vec(),
3899 },
3900 cx,
3901 );
3902 });
3903
3904 // Create a tool call entry that references this terminal
3905 // This represents the agent requesting a terminal command
3906 thread.update(cx, |thread, cx| {
3907 thread
3908 .handle_session_update(
3909 acp::SessionUpdate::ToolCall(
3910 acp::ToolCall::new("terminal-tool-1", "Running command")
3911 .kind(acp::ToolKind::Execute)
3912 .status(acp::ToolCallStatus::InProgress)
3913 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
3914 terminal_id.clone(),
3915 ))])
3916 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
3917 ),
3918 cx,
3919 )
3920 .unwrap();
3921 });
3922
3923 // Verify terminal exists and is in the thread
3924 let terminal_exists_before =
3925 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
3926 assert!(
3927 terminal_exists_before,
3928 "Terminal should exist before checkpoint restore"
3929 );
3930
3931 // Verify the terminal's underlying task is still running (not completed)
3932 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
3933 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
3934 terminal_entity.read_with(cx, |term, _cx| {
3935 term.output().is_none() // output is None means it's still running
3936 })
3937 });
3938 assert!(
3939 terminal_running_before,
3940 "Terminal should be running before checkpoint restore"
3941 );
3942
3943 // Verify we have the expected entries before restore
3944 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
3945 assert!(
3946 entry_count_before > 1,
3947 "Should have multiple entries before restore"
3948 );
3949
3950 // Restore the checkpoint to the second message.
3951 // This should:
3952 // 1. Cancel any in-progress generation (via the cancel() call)
3953 // 2. Remove the terminal that was created after that point
3954 thread
3955 .update(cx, |thread, cx| {
3956 thread.restore_checkpoint(second_message_id, cx)
3957 })
3958 .await
3959 .unwrap();
3960
3961 // Verify that no send_task is in progress after restore
3962 // (cancel() clears the send_task)
3963 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
3964 assert!(
3965 !has_send_task_after,
3966 "Should not have a send_task after restore (cancel should have cleared it)"
3967 );
3968
3969 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
3970 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
3971 assert_eq!(
3972 entry_count, 1,
3973 "Should have 1 entry after restore (only the first user message)"
3974 );
3975
3976 // Verify the 2 completed terminals from before the checkpoint still exist
3977 let terminal_1_exists = thread.read_with(cx, |thread, _| {
3978 thread.terminals.contains_key(&terminal_id_1)
3979 });
3980 assert!(
3981 terminal_1_exists,
3982 "Terminal 1 (from before checkpoint) should still exist"
3983 );
3984
3985 let terminal_2_exists = thread.read_with(cx, |thread, _| {
3986 thread.terminals.contains_key(&terminal_id_2)
3987 });
3988 assert!(
3989 terminal_2_exists,
3990 "Terminal 2 (from before checkpoint) should still exist"
3991 );
3992
3993 // Verify they're still in completed state
3994 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
3995 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
3996 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
3997 });
3998 assert!(terminal_1_completed, "Terminal 1 should still be completed");
3999
4000 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4001 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4002 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4003 });
4004 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4005
4006 // Verify the running terminal (created after checkpoint) was removed
4007 let terminal_3_exists =
4008 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4009 assert!(
4010 !terminal_3_exists,
4011 "Terminal 3 (created after checkpoint) should have been removed"
4012 );
4013
4014 // Verify total count is 2 (the two from before the checkpoint)
4015 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4016 assert_eq!(
4017 terminal_count, 2,
4018 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4019 );
4020 }
4021}