1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15use task::{Shell, ShellBuilder};
16pub use terminal::*;
17
18use action_log::{ActionLog, ActionLogTelemetry};
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
38use uuid::Uuid;
39
40#[derive(Debug)]
41pub struct UserMessage {
42 pub id: Option<UserMessageId>,
43 pub content: ContentBlock,
44 pub chunks: Vec<acp::ContentBlock>,
45 pub checkpoint: Option<Checkpoint>,
46}
47
48#[derive(Debug)]
49pub struct Checkpoint {
50 git_checkpoint: GitStoreCheckpoint,
51 pub show: bool,
52}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(
98 chunk: &str,
99 language_registry: &Arc<LanguageRegistry>,
100 path_style: PathStyle,
101 cx: &mut App,
102 ) -> Self {
103 Self::Message {
104 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
105 }
106 }
107
108 fn to_markdown(&self, cx: &App) -> String {
109 match self {
110 Self::Message { block } => block.to_markdown(cx).to_string(),
111 Self::Thought { block } => {
112 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
113 }
114 }
115 }
116}
117
118#[derive(Debug)]
119pub enum AgentThreadEntry {
120 UserMessage(UserMessage),
121 AssistantMessage(AssistantMessage),
122 ToolCall(ToolCall),
123}
124
125impl AgentThreadEntry {
126 pub fn to_markdown(&self, cx: &App) -> String {
127 match self {
128 Self::UserMessage(message) => message.to_markdown(cx),
129 Self::AssistantMessage(message) => message.to_markdown(cx),
130 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
131 }
132 }
133
134 pub fn user_message(&self) -> Option<&UserMessage> {
135 if let AgentThreadEntry::UserMessage(message) = self {
136 Some(message)
137 } else {
138 None
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
151 if let AgentThreadEntry::ToolCall(call) = self {
152 itertools::Either::Left(call.terminals())
153 } else {
154 itertools::Either::Right(std::iter::empty())
155 }
156 }
157
158 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
159 if let AgentThreadEntry::ToolCall(ToolCall {
160 locations,
161 resolved_locations,
162 ..
163 }) = self
164 {
165 Some((
166 locations.get(ix)?.clone(),
167 resolved_locations.get(ix)?.clone()?,
168 ))
169 } else {
170 None
171 }
172 }
173}
174
175#[derive(Debug)]
176pub struct ToolCall {
177 pub id: acp::ToolCallId,
178 pub label: Entity<Markdown>,
179 pub kind: acp::ToolKind,
180 pub content: Vec<ToolCallContent>,
181 pub status: ToolCallStatus,
182 pub locations: Vec<acp::ToolCallLocation>,
183 pub resolved_locations: Vec<Option<AgentLocation>>,
184 pub raw_input: Option<serde_json::Value>,
185 pub raw_output: Option<serde_json::Value>,
186}
187
188impl ToolCall {
189 fn from_acp(
190 tool_call: acp::ToolCall,
191 status: ToolCallStatus,
192 language_registry: Arc<LanguageRegistry>,
193 path_style: PathStyle,
194 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
195 cx: &mut App,
196 ) -> Result<Self> {
197 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
198 first_line.to_owned() + "…"
199 } else {
200 tool_call.title
201 };
202 let mut content = Vec::with_capacity(tool_call.content.len());
203 for item in tool_call.content {
204 content.push(ToolCallContent::from_acp(
205 item,
206 language_registry.clone(),
207 path_style,
208 terminals,
209 cx,
210 )?);
211 }
212
213 let result = Self {
214 id: tool_call.id,
215 label: cx
216 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
217 kind: tool_call.kind,
218 content,
219 locations: tool_call.locations,
220 resolved_locations: Vec::default(),
221 status,
222 raw_input: tool_call.raw_input,
223 raw_output: tool_call.raw_output,
224 };
225 Ok(result)
226 }
227
228 fn update_fields(
229 &mut self,
230 fields: acp::ToolCallUpdateFields,
231 language_registry: Arc<LanguageRegistry>,
232 path_style: PathStyle,
233 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
234 cx: &mut App,
235 ) -> Result<()> {
236 let acp::ToolCallUpdateFields {
237 kind,
238 status,
239 title,
240 content,
241 locations,
242 raw_input,
243 raw_output,
244 } = fields;
245
246 if let Some(kind) = kind {
247 self.kind = kind;
248 }
249
250 if let Some(status) = status {
251 self.status = status.into();
252 }
253
254 if let Some(title) = title {
255 self.label.update(cx, |label, cx| {
256 if let Some((first_line, _)) = title.split_once("\n") {
257 label.replace(first_line.to_owned() + "…", cx)
258 } else {
259 label.replace(title, cx);
260 }
261 });
262 }
263
264 if let Some(content) = content {
265 let new_content_len = content.len();
266 let mut content = content.into_iter();
267
268 // Reuse existing content if we can
269 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
270 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
271 }
272 for new in content {
273 self.content.push(ToolCallContent::from_acp(
274 new,
275 language_registry.clone(),
276 path_style,
277 terminals,
278 cx,
279 )?)
280 }
281 self.content.truncate(new_content_len);
282 }
283
284 if let Some(locations) = locations {
285 self.locations = locations;
286 }
287
288 if let Some(raw_input) = raw_input {
289 self.raw_input = Some(raw_input);
290 }
291
292 if let Some(raw_output) = raw_output {
293 if self.content.is_empty()
294 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
295 {
296 self.content
297 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
298 markdown,
299 }));
300 }
301 self.raw_output = Some(raw_output);
302 }
303 Ok(())
304 }
305
306 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
307 self.content.iter().filter_map(|content| match content {
308 ToolCallContent::Diff(diff) => Some(diff),
309 ToolCallContent::ContentBlock(_) => None,
310 ToolCallContent::Terminal(_) => None,
311 })
312 }
313
314 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
315 self.content.iter().filter_map(|content| match content {
316 ToolCallContent::Terminal(terminal) => Some(terminal),
317 ToolCallContent::ContentBlock(_) => None,
318 ToolCallContent::Diff(_) => None,
319 })
320 }
321
322 fn to_markdown(&self, cx: &App) -> String {
323 let mut markdown = format!(
324 "**Tool Call: {}**\nStatus: {}\n\n",
325 self.label.read(cx).source(),
326 self.status
327 );
328 for content in &self.content {
329 markdown.push_str(content.to_markdown(cx).as_str());
330 markdown.push_str("\n\n");
331 }
332 markdown
333 }
334
335 async fn resolve_location(
336 location: acp::ToolCallLocation,
337 project: WeakEntity<Project>,
338 cx: &mut AsyncApp,
339 ) -> Option<ResolvedLocation> {
340 let buffer = project
341 .update(cx, |project, cx| {
342 project
343 .project_path_for_absolute_path(&location.path, cx)
344 .map(|path| project.open_buffer(path, cx))
345 })
346 .ok()??;
347 let buffer = buffer.await.log_err()?;
348 let position = buffer
349 .update(cx, |buffer, _| {
350 if let Some(row) = location.line {
351 let snapshot = buffer.snapshot();
352 let column = snapshot.indent_size_for_line(row).len;
353 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
354 snapshot.anchor_before(point)
355 } else {
356 Anchor::MIN
357 }
358 })
359 .ok()?;
360
361 Some(ResolvedLocation { buffer, position })
362 }
363
364 fn resolve_locations(
365 &self,
366 project: Entity<Project>,
367 cx: &mut App,
368 ) -> Task<Vec<Option<ResolvedLocation>>> {
369 let locations = self.locations.clone();
370 project.update(cx, |_, cx| {
371 cx.spawn(async move |project, cx| {
372 let mut new_locations = Vec::new();
373 for location in locations {
374 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
375 }
376 new_locations
377 })
378 })
379 }
380}
381
382// Separate so we can hold a strong reference to the buffer
383// for saving on the thread
384#[derive(Clone, Debug, PartialEq, Eq)]
385struct ResolvedLocation {
386 buffer: Entity<Buffer>,
387 position: Anchor,
388}
389
390impl From<&ResolvedLocation> for AgentLocation {
391 fn from(value: &ResolvedLocation) -> Self {
392 Self {
393 buffer: value.buffer.downgrade(),
394 position: value.position,
395 }
396 }
397}
398
399#[derive(Debug)]
400pub enum ToolCallStatus {
401 /// The tool call hasn't started running yet, but we start showing it to
402 /// the user.
403 Pending,
404 /// The tool call is waiting for confirmation from the user.
405 WaitingForConfirmation {
406 options: Vec<acp::PermissionOption>,
407 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
408 },
409 /// The tool call is currently running.
410 InProgress,
411 /// The tool call completed successfully.
412 Completed,
413 /// The tool call failed.
414 Failed,
415 /// The user rejected the tool call.
416 Rejected,
417 /// The user canceled generation so the tool call was canceled.
418 Canceled,
419}
420
421impl From<acp::ToolCallStatus> for ToolCallStatus {
422 fn from(status: acp::ToolCallStatus) -> Self {
423 match status {
424 acp::ToolCallStatus::Pending => Self::Pending,
425 acp::ToolCallStatus::InProgress => Self::InProgress,
426 acp::ToolCallStatus::Completed => Self::Completed,
427 acp::ToolCallStatus::Failed => Self::Failed,
428 }
429 }
430}
431
432impl Display for ToolCallStatus {
433 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
434 write!(
435 f,
436 "{}",
437 match self {
438 ToolCallStatus::Pending => "Pending",
439 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
440 ToolCallStatus::InProgress => "In Progress",
441 ToolCallStatus::Completed => "Completed",
442 ToolCallStatus::Failed => "Failed",
443 ToolCallStatus::Rejected => "Rejected",
444 ToolCallStatus::Canceled => "Canceled",
445 }
446 )
447 }
448}
449
450#[derive(Debug, PartialEq, Clone)]
451pub enum ContentBlock {
452 Empty,
453 Markdown { markdown: Entity<Markdown> },
454 ResourceLink { resource_link: acp::ResourceLink },
455}
456
457impl ContentBlock {
458 pub fn new(
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 path_style: PathStyle,
462 cx: &mut App,
463 ) -> Self {
464 let mut this = Self::Empty;
465 this.append(block, language_registry, path_style, cx);
466 this
467 }
468
469 pub fn new_combined(
470 blocks: impl IntoIterator<Item = acp::ContentBlock>,
471 language_registry: Arc<LanguageRegistry>,
472 path_style: PathStyle,
473 cx: &mut App,
474 ) -> Self {
475 let mut this = Self::Empty;
476 for block in blocks {
477 this.append(block, &language_registry, path_style, cx);
478 }
479 this
480 }
481
482 pub fn append(
483 &mut self,
484 block: acp::ContentBlock,
485 language_registry: &Arc<LanguageRegistry>,
486 path_style: PathStyle,
487 cx: &mut App,
488 ) {
489 if matches!(self, ContentBlock::Empty)
490 && let acp::ContentBlock::ResourceLink(resource_link) = block
491 {
492 *self = ContentBlock::ResourceLink { resource_link };
493 return;
494 }
495
496 let new_content = self.block_string_contents(block, path_style);
497
498 match self {
499 ContentBlock::Empty => {
500 *self = Self::create_markdown_block(new_content, language_registry, cx);
501 }
502 ContentBlock::Markdown { markdown } => {
503 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
504 }
505 ContentBlock::ResourceLink { resource_link } => {
506 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
507 let combined = format!("{}\n{}", existing_content, new_content);
508
509 *self = Self::create_markdown_block(combined, language_registry, cx);
510 }
511 }
512 }
513
514 fn create_markdown_block(
515 content: String,
516 language_registry: &Arc<LanguageRegistry>,
517 cx: &mut App,
518 ) -> ContentBlock {
519 ContentBlock::Markdown {
520 markdown: cx
521 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
522 }
523 }
524
525 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
526 match block {
527 acp::ContentBlock::Text(text_content) => text_content.text,
528 acp::ContentBlock::ResourceLink(resource_link) => {
529 Self::resource_link_md(&resource_link.uri, path_style)
530 }
531 acp::ContentBlock::Resource(acp::EmbeddedResource {
532 resource:
533 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
534 uri,
535 ..
536 }),
537 ..
538 }) => Self::resource_link_md(&uri, path_style),
539 acp::ContentBlock::Image(image) => Self::image_md(&image),
540 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
541 }
542 }
543
544 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
545 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
546 uri.as_link().to_string()
547 } else {
548 uri.to_string()
549 }
550 }
551
552 fn image_md(_image: &acp::ImageContent) -> String {
553 "`Image`".into()
554 }
555
556 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
557 match self {
558 ContentBlock::Empty => "",
559 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
560 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
561 }
562 }
563
564 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
565 match self {
566 ContentBlock::Empty => None,
567 ContentBlock::Markdown { markdown } => Some(markdown),
568 ContentBlock::ResourceLink { .. } => None,
569 }
570 }
571
572 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
573 match self {
574 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
575 _ => None,
576 }
577 }
578}
579
580#[derive(Debug)]
581pub enum ToolCallContent {
582 ContentBlock(ContentBlock),
583 Diff(Entity<Diff>),
584 Terminal(Entity<Terminal>),
585}
586
587impl ToolCallContent {
588 pub fn from_acp(
589 content: acp::ToolCallContent,
590 language_registry: Arc<LanguageRegistry>,
591 path_style: PathStyle,
592 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
593 cx: &mut App,
594 ) -> Result<Self> {
595 match content {
596 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
597 content,
598 &language_registry,
599 path_style,
600 cx,
601 ))),
602 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
603 Diff::finalized(
604 diff.path.to_string_lossy().into_owned(),
605 diff.old_text,
606 diff.new_text,
607 language_registry,
608 cx,
609 )
610 }))),
611 acp::ToolCallContent::Terminal { terminal_id } => terminals
612 .get(&terminal_id)
613 .cloned()
614 .map(Self::Terminal)
615 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
616 }
617 }
618
619 pub fn update_from_acp(
620 &mut self,
621 new: acp::ToolCallContent,
622 language_registry: Arc<LanguageRegistry>,
623 path_style: PathStyle,
624 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
625 cx: &mut App,
626 ) -> Result<()> {
627 let needs_update = match (&self, &new) {
628 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
629 old_diff.read(cx).needs_update(
630 new_diff.old_text.as_deref().unwrap_or(""),
631 &new_diff.new_text,
632 cx,
633 )
634 }
635 _ => true,
636 };
637
638 if needs_update {
639 *self = Self::from_acp(new, language_registry, path_style, terminals, cx)?;
640 }
641 Ok(())
642 }
643
644 pub fn to_markdown(&self, cx: &App) -> String {
645 match self {
646 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
647 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
648 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
649 }
650 }
651}
652
653#[derive(Debug, PartialEq)]
654pub enum ToolCallUpdate {
655 UpdateFields(acp::ToolCallUpdate),
656 UpdateDiff(ToolCallUpdateDiff),
657 UpdateTerminal(ToolCallUpdateTerminal),
658}
659
660impl ToolCallUpdate {
661 fn id(&self) -> &acp::ToolCallId {
662 match self {
663 Self::UpdateFields(update) => &update.id,
664 Self::UpdateDiff(diff) => &diff.id,
665 Self::UpdateTerminal(terminal) => &terminal.id,
666 }
667 }
668}
669
670impl From<acp::ToolCallUpdate> for ToolCallUpdate {
671 fn from(update: acp::ToolCallUpdate) -> Self {
672 Self::UpdateFields(update)
673 }
674}
675
676impl From<ToolCallUpdateDiff> for ToolCallUpdate {
677 fn from(diff: ToolCallUpdateDiff) -> Self {
678 Self::UpdateDiff(diff)
679 }
680}
681
682#[derive(Debug, PartialEq)]
683pub struct ToolCallUpdateDiff {
684 pub id: acp::ToolCallId,
685 pub diff: Entity<Diff>,
686}
687
688impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
689 fn from(terminal: ToolCallUpdateTerminal) -> Self {
690 Self::UpdateTerminal(terminal)
691 }
692}
693
694#[derive(Debug, PartialEq)]
695pub struct ToolCallUpdateTerminal {
696 pub id: acp::ToolCallId,
697 pub terminal: Entity<Terminal>,
698}
699
700#[derive(Debug, Default)]
701pub struct Plan {
702 pub entries: Vec<PlanEntry>,
703}
704
705#[derive(Debug)]
706pub struct PlanStats<'a> {
707 pub in_progress_entry: Option<&'a PlanEntry>,
708 pub pending: u32,
709 pub completed: u32,
710}
711
712impl Plan {
713 pub fn is_empty(&self) -> bool {
714 self.entries.is_empty()
715 }
716
717 pub fn stats(&self) -> PlanStats<'_> {
718 let mut stats = PlanStats {
719 in_progress_entry: None,
720 pending: 0,
721 completed: 0,
722 };
723
724 for entry in &self.entries {
725 match &entry.status {
726 acp::PlanEntryStatus::Pending => {
727 stats.pending += 1;
728 }
729 acp::PlanEntryStatus::InProgress => {
730 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
731 }
732 acp::PlanEntryStatus::Completed => {
733 stats.completed += 1;
734 }
735 }
736 }
737
738 stats
739 }
740}
741
742#[derive(Debug)]
743pub struct PlanEntry {
744 pub content: Entity<Markdown>,
745 pub priority: acp::PlanEntryPriority,
746 pub status: acp::PlanEntryStatus,
747}
748
749impl PlanEntry {
750 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
751 Self {
752 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
753 priority: entry.priority,
754 status: entry.status,
755 }
756 }
757}
758
759#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
760pub struct TokenUsage {
761 pub max_tokens: u64,
762 pub used_tokens: u64,
763}
764
765impl TokenUsage {
766 pub fn ratio(&self) -> TokenUsageRatio {
767 #[cfg(debug_assertions)]
768 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
769 .unwrap_or("0.8".to_string())
770 .parse()
771 .unwrap();
772 #[cfg(not(debug_assertions))]
773 let warning_threshold: f32 = 0.8;
774
775 // When the maximum is unknown because there is no selected model,
776 // avoid showing the token limit warning.
777 if self.max_tokens == 0 {
778 TokenUsageRatio::Normal
779 } else if self.used_tokens >= self.max_tokens {
780 TokenUsageRatio::Exceeded
781 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
782 TokenUsageRatio::Warning
783 } else {
784 TokenUsageRatio::Normal
785 }
786 }
787}
788
/// How much of the token limit a thread has consumed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
795
796#[derive(Debug, Clone)]
797pub struct RetryStatus {
798 pub last_error: SharedString,
799 pub attempt: usize,
800 pub max_attempts: usize,
801 pub started_at: Instant,
802 pub duration: Duration,
803}
804
805pub struct AcpThread {
806 title: SharedString,
807 entries: Vec<AgentThreadEntry>,
808 plan: Plan,
809 project: Entity<Project>,
810 action_log: Entity<ActionLog>,
811 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
812 send_task: Option<Task<()>>,
813 connection: Rc<dyn AgentConnection>,
814 session_id: acp::SessionId,
815 token_usage: Option<TokenUsage>,
816 prompt_capabilities: acp::PromptCapabilities,
817 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
818 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
819 pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
820 pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
821}
822
823impl From<&AcpThread> for ActionLogTelemetry {
824 fn from(value: &AcpThread) -> Self {
825 Self {
826 agent_telemetry_id: value.connection().telemetry_id(),
827 session_id: value.session_id.0.clone(),
828 }
829 }
830}
831
832#[derive(Debug)]
833pub enum AcpThreadEvent {
834 NewEntry,
835 TitleUpdated,
836 TokenUsageUpdated,
837 EntryUpdated(usize),
838 EntriesRemoved(Range<usize>),
839 ToolAuthorizationRequired,
840 Retry(RetryStatus),
841 Stopped,
842 Error,
843 LoadError(LoadError),
844 PromptCapabilitiesUpdated,
845 Refusal,
846 AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
847 ModeUpdated(acp::SessionModeId),
848}
849
850impl EventEmitter<AcpThreadEvent> for AcpThread {}
851
852#[derive(Debug, Clone)]
853pub enum TerminalProviderEvent {
854 Created {
855 terminal_id: acp::TerminalId,
856 label: String,
857 cwd: Option<PathBuf>,
858 output_byte_limit: Option<u64>,
859 terminal: Entity<::terminal::Terminal>,
860 },
861 Output {
862 terminal_id: acp::TerminalId,
863 data: Vec<u8>,
864 },
865 TitleChanged {
866 terminal_id: acp::TerminalId,
867 title: String,
868 },
869 Exit {
870 terminal_id: acp::TerminalId,
871 status: acp::TerminalExitStatus,
872 },
873}
874
875#[derive(Debug, Clone)]
876pub enum TerminalProviderCommand {
877 WriteInput {
878 terminal_id: acp::TerminalId,
879 bytes: Vec<u8>,
880 },
881 Resize {
882 terminal_id: acp::TerminalId,
883 cols: u16,
884 rows: u16,
885 },
886 Close {
887 terminal_id: acp::TerminalId,
888 },
889}
890
891impl AcpThread {
892 pub fn on_terminal_provider_event(
893 &mut self,
894 event: TerminalProviderEvent,
895 cx: &mut Context<Self>,
896 ) {
897 match event {
898 TerminalProviderEvent::Created {
899 terminal_id,
900 label,
901 cwd,
902 output_byte_limit,
903 terminal,
904 } => {
905 let entity = self.register_terminal_created(
906 terminal_id.clone(),
907 label,
908 cwd,
909 output_byte_limit,
910 terminal,
911 cx,
912 );
913
914 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
915 for data in chunks.drain(..) {
916 entity.update(cx, |term, cx| {
917 term.inner().update(cx, |inner, cx| {
918 inner.write_output(&data, cx);
919 })
920 });
921 }
922 }
923
924 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
925 entity.update(cx, |_term, cx| {
926 cx.notify();
927 });
928 }
929
930 cx.notify();
931 }
932 TerminalProviderEvent::Output { terminal_id, data } => {
933 if let Some(entity) = self.terminals.get(&terminal_id) {
934 entity.update(cx, |term, cx| {
935 term.inner().update(cx, |inner, cx| {
936 inner.write_output(&data, cx);
937 })
938 });
939 } else {
940 self.pending_terminal_output
941 .entry(terminal_id)
942 .or_default()
943 .push(data);
944 }
945 }
946 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
947 if let Some(entity) = self.terminals.get(&terminal_id) {
948 entity.update(cx, |term, cx| {
949 term.inner().update(cx, |inner, cx| {
950 inner.breadcrumb_text = title;
951 cx.emit(::terminal::Event::BreadcrumbsChanged);
952 })
953 });
954 }
955 }
956 TerminalProviderEvent::Exit {
957 terminal_id,
958 status,
959 } => {
960 if let Some(entity) = self.terminals.get(&terminal_id) {
961 entity.update(cx, |_term, cx| {
962 cx.notify();
963 });
964 } else {
965 self.pending_terminal_exit.insert(terminal_id, status);
966 }
967 }
968 }
969 }
970}
971
/// Whether a thread is currently generating a response.
#[derive(Debug, Eq, PartialEq)]
pub enum ThreadStatus {
    /// No turn is in flight.
    Idle,
    /// A turn is in flight.
    Generating,
}
977
978#[derive(Debug, Clone)]
979pub enum LoadError {
980 Unsupported {
981 command: SharedString,
982 current_version: SharedString,
983 minimum_version: SharedString,
984 },
985 FailedToInstall(SharedString),
986 Exited {
987 status: ExitStatus,
988 },
989 Other(SharedString),
990}
991
992impl Display for LoadError {
993 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
994 match self {
995 LoadError::Unsupported {
996 command: path,
997 current_version,
998 minimum_version,
999 } => {
1000 write!(
1001 f,
1002 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1003 )
1004 }
1005 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1006 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1007 LoadError::Other(msg) => write!(f, "{msg}"),
1008 }
1009 }
1010}
1011
1012impl Error for LoadError {}
1013
1014impl AcpThread {
1015 pub fn new(
1016 title: impl Into<SharedString>,
1017 connection: Rc<dyn AgentConnection>,
1018 project: Entity<Project>,
1019 action_log: Entity<ActionLog>,
1020 session_id: acp::SessionId,
1021 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1022 cx: &mut Context<Self>,
1023 ) -> Self {
1024 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1025 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1026 loop {
1027 let caps = prompt_capabilities_rx.recv().await?;
1028 this.update(cx, |this, cx| {
1029 this.prompt_capabilities = caps;
1030 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1031 })?;
1032 }
1033 });
1034
1035 Self {
1036 action_log,
1037 shared_buffers: Default::default(),
1038 entries: Default::default(),
1039 plan: Default::default(),
1040 title: title.into(),
1041 project,
1042 send_task: None,
1043 connection,
1044 session_id,
1045 token_usage: None,
1046 prompt_capabilities,
1047 _observe_prompt_capabilities: task,
1048 terminals: HashMap::default(),
1049 pending_terminal_output: HashMap::default(),
1050 pending_terminal_exit: HashMap::default(),
1051 }
1052 }
1053
1054 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1055 self.prompt_capabilities.clone()
1056 }
1057
1058 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1059 &self.connection
1060 }
1061
1062 pub fn action_log(&self) -> &Entity<ActionLog> {
1063 &self.action_log
1064 }
1065
1066 pub fn project(&self) -> &Entity<Project> {
1067 &self.project
1068 }
1069
1070 pub fn title(&self) -> SharedString {
1071 self.title.clone()
1072 }
1073
1074 pub fn entries(&self) -> &[AgentThreadEntry] {
1075 &self.entries
1076 }
1077
1078 pub fn session_id(&self) -> &acp::SessionId {
1079 &self.session_id
1080 }
1081
1082 pub fn status(&self) -> ThreadStatus {
1083 if self.send_task.is_some() {
1084 ThreadStatus::Generating
1085 } else {
1086 ThreadStatus::Idle
1087 }
1088 }
1089
1090 pub fn token_usage(&self) -> Option<&TokenUsage> {
1091 self.token_usage.as_ref()
1092 }
1093
1094 pub fn has_pending_edit_tool_calls(&self) -> bool {
1095 for entry in self.entries.iter().rev() {
1096 match entry {
1097 AgentThreadEntry::UserMessage(_) => return false,
1098 AgentThreadEntry::ToolCall(
1099 call @ ToolCall {
1100 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1101 ..
1102 },
1103 ) if call.diffs().next().is_some() => {
1104 return true;
1105 }
1106 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1107 }
1108 }
1109
1110 false
1111 }
1112
1113 pub fn used_tools_since_last_user_message(&self) -> bool {
1114 for entry in self.entries.iter().rev() {
1115 match entry {
1116 AgentThreadEntry::UserMessage(..) => return false,
1117 AgentThreadEntry::AssistantMessage(..) => continue,
1118 AgentThreadEntry::ToolCall(..) => return true,
1119 }
1120 }
1121
1122 false
1123 }
1124
1125 pub fn handle_session_update(
1126 &mut self,
1127 update: acp::SessionUpdate,
1128 cx: &mut Context<Self>,
1129 ) -> Result<(), acp::Error> {
1130 match update {
1131 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1132 self.push_user_content_block(None, content, cx);
1133 }
1134 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1135 self.push_assistant_content_block(content, false, cx);
1136 }
1137 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1138 self.push_assistant_content_block(content, true, cx);
1139 }
1140 acp::SessionUpdate::ToolCall(tool_call) => {
1141 self.upsert_tool_call(tool_call, cx)?;
1142 }
1143 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1144 self.update_tool_call(tool_call_update, cx)?;
1145 }
1146 acp::SessionUpdate::Plan(plan) => {
1147 self.update_plan(plan, cx);
1148 }
1149 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1150 available_commands,
1151 ..
1152 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1153 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1154 current_mode_id,
1155 ..
1156 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1157 }
1158 Ok(())
1159 }
1160
1161 pub fn push_user_content_block(
1162 &mut self,
1163 message_id: Option<UserMessageId>,
1164 chunk: acp::ContentBlock,
1165 cx: &mut Context<Self>,
1166 ) {
1167 let language_registry = self.project.read(cx).languages().clone();
1168 let path_style = self.project.read(cx).path_style(cx);
1169 let entries_len = self.entries.len();
1170
1171 if let Some(last_entry) = self.entries.last_mut()
1172 && let AgentThreadEntry::UserMessage(UserMessage {
1173 id,
1174 content,
1175 chunks,
1176 ..
1177 }) = last_entry
1178 {
1179 *id = message_id.or(id.take());
1180 content.append(chunk.clone(), &language_registry, path_style, cx);
1181 chunks.push(chunk);
1182 let idx = entries_len - 1;
1183 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1184 } else {
1185 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1186 self.push_entry(
1187 AgentThreadEntry::UserMessage(UserMessage {
1188 id: message_id,
1189 content,
1190 chunks: vec![chunk],
1191 checkpoint: None,
1192 }),
1193 cx,
1194 );
1195 }
1196 }
1197
1198 pub fn push_assistant_content_block(
1199 &mut self,
1200 chunk: acp::ContentBlock,
1201 is_thought: bool,
1202 cx: &mut Context<Self>,
1203 ) {
1204 let language_registry = self.project.read(cx).languages().clone();
1205 let path_style = self.project.read(cx).path_style(cx);
1206 let entries_len = self.entries.len();
1207 if let Some(last_entry) = self.entries.last_mut()
1208 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1209 {
1210 let idx = entries_len - 1;
1211 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1212 match (chunks.last_mut(), is_thought) {
1213 (Some(AssistantMessageChunk::Message { block }), false)
1214 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1215 block.append(chunk, &language_registry, path_style, cx)
1216 }
1217 _ => {
1218 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1219 if is_thought {
1220 chunks.push(AssistantMessageChunk::Thought { block })
1221 } else {
1222 chunks.push(AssistantMessageChunk::Message { block })
1223 }
1224 }
1225 }
1226 } else {
1227 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1228 let chunk = if is_thought {
1229 AssistantMessageChunk::Thought { block }
1230 } else {
1231 AssistantMessageChunk::Message { block }
1232 };
1233
1234 self.push_entry(
1235 AgentThreadEntry::AssistantMessage(AssistantMessage {
1236 chunks: vec![chunk],
1237 }),
1238 cx,
1239 );
1240 }
1241 }
1242
1243 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1244 self.entries.push(entry);
1245 cx.emit(AcpThreadEvent::NewEntry);
1246 }
1247
1248 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1249 self.connection.set_title(&self.session_id, cx).is_some()
1250 }
1251
1252 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1253 if title != self.title {
1254 self.title = title.clone();
1255 cx.emit(AcpThreadEvent::TitleUpdated);
1256 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1257 return set_title.run(title, cx);
1258 }
1259 }
1260 Task::ready(Ok(()))
1261 }
1262
1263 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1264 self.token_usage = usage;
1265 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1266 }
1267
1268 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1269 cx.emit(AcpThreadEvent::Retry(status));
1270 }
1271
1272 pub fn update_tool_call(
1273 &mut self,
1274 update: impl Into<ToolCallUpdate>,
1275 cx: &mut Context<Self>,
1276 ) -> Result<()> {
1277 let update = update.into();
1278 let languages = self.project.read(cx).languages().clone();
1279 let path_style = self.project.read(cx).path_style(cx);
1280
1281 let ix = match self.index_for_tool_call(update.id()) {
1282 Some(ix) => ix,
1283 None => {
1284 // Tool call not found - create a failed tool call entry
1285 let failed_tool_call = ToolCall {
1286 id: update.id().clone(),
1287 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1288 kind: acp::ToolKind::Fetch,
1289 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1290 acp::ContentBlock::Text(acp::TextContent {
1291 text: "Tool call not found".to_string(),
1292 annotations: None,
1293 meta: None,
1294 }),
1295 &languages,
1296 path_style,
1297 cx,
1298 ))],
1299 status: ToolCallStatus::Failed,
1300 locations: Vec::new(),
1301 resolved_locations: Vec::new(),
1302 raw_input: None,
1303 raw_output: None,
1304 };
1305 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1306 return Ok(());
1307 }
1308 };
1309 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1310 unreachable!()
1311 };
1312
1313 match update {
1314 ToolCallUpdate::UpdateFields(update) => {
1315 let location_updated = update.fields.locations.is_some();
1316 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1317 if location_updated {
1318 self.resolve_locations(update.id, cx);
1319 }
1320 }
1321 ToolCallUpdate::UpdateDiff(update) => {
1322 call.content.clear();
1323 call.content.push(ToolCallContent::Diff(update.diff));
1324 }
1325 ToolCallUpdate::UpdateTerminal(update) => {
1326 call.content.clear();
1327 call.content
1328 .push(ToolCallContent::Terminal(update.terminal));
1329 }
1330 }
1331
1332 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1333
1334 Ok(())
1335 }
1336
1337 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1338 pub fn upsert_tool_call(
1339 &mut self,
1340 tool_call: acp::ToolCall,
1341 cx: &mut Context<Self>,
1342 ) -> Result<(), acp::Error> {
1343 let status = tool_call.status.into();
1344 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1345 }
1346
1347 /// Fails if id does not match an existing entry.
1348 pub fn upsert_tool_call_inner(
1349 &mut self,
1350 update: acp::ToolCallUpdate,
1351 status: ToolCallStatus,
1352 cx: &mut Context<Self>,
1353 ) -> Result<(), acp::Error> {
1354 let language_registry = self.project.read(cx).languages().clone();
1355 let path_style = self.project.read(cx).path_style(cx);
1356 let id = update.id.clone();
1357
1358 let agent = self.connection().telemetry_id();
1359 let session = self.session_id();
1360 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1361 let status = if matches!(status, ToolCallStatus::Completed) {
1362 "completed"
1363 } else {
1364 "failed"
1365 };
1366 telemetry::event!("Agent Tool Call Completed", agent, session, status);
1367 }
1368
1369 if let Some(ix) = self.index_for_tool_call(&id) {
1370 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1371 unreachable!()
1372 };
1373
1374 call.update_fields(
1375 update.fields,
1376 language_registry,
1377 path_style,
1378 &self.terminals,
1379 cx,
1380 )?;
1381 call.status = status;
1382
1383 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1384 } else {
1385 let call = ToolCall::from_acp(
1386 update.try_into()?,
1387 status,
1388 language_registry,
1389 self.project.read(cx).path_style(cx),
1390 &self.terminals,
1391 cx,
1392 )?;
1393 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1394 };
1395
1396 self.resolve_locations(id, cx);
1397 Ok(())
1398 }
1399
1400 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1401 self.entries
1402 .iter()
1403 .enumerate()
1404 .rev()
1405 .find_map(|(index, entry)| {
1406 if let AgentThreadEntry::ToolCall(tool_call) = entry
1407 && &tool_call.id == id
1408 {
1409 Some(index)
1410 } else {
1411 None
1412 }
1413 })
1414 }
1415
1416 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1417 // The tool call we are looking for is typically the last one, or very close to the end.
1418 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1419 self.entries
1420 .iter_mut()
1421 .enumerate()
1422 .rev()
1423 .find_map(|(index, tool_call)| {
1424 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1425 && &tool_call.id == id
1426 {
1427 Some((index, tool_call))
1428 } else {
1429 None
1430 }
1431 })
1432 }
1433
1434 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1435 self.entries
1436 .iter()
1437 .enumerate()
1438 .rev()
1439 .find_map(|(index, tool_call)| {
1440 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1441 && &tool_call.id == id
1442 {
1443 Some((index, tool_call))
1444 } else {
1445 None
1446 }
1447 })
1448 }
1449
    /// Resolves the locations of tool call `id` in a detached background task.
    ///
    /// When resolution finishes:
    /// - a snapshot of each resolved buffer is cached in `shared_buffers`;
    /// - the project's agent location moves to the last resolved location,
    ///   except in the same-row case noted below;
    /// - the tool call's `resolved_locations` are replaced, emitting
    ///   [`AcpThreadEvent::EntryUpdated`] only if they changed.
    ///
    /// Does nothing if no tool call with `id` exists.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Cache a snapshot of every buffer that resolved successfully.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Look the call up again: it may no longer exist (e.g. the thread
                // was rewound) while resolution was in flight.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify when something visible actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1506
1507 pub fn request_tool_call_authorization(
1508 &mut self,
1509 tool_call: acp::ToolCallUpdate,
1510 options: Vec<acp::PermissionOption>,
1511 respect_always_allow_setting: bool,
1512 cx: &mut Context<Self>,
1513 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1514 let (tx, rx) = oneshot::channel();
1515
1516 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1517 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1518 // some tools would (incorrectly) continue to auto-accept.
1519 if let Some(allow_once_option) = options.iter().find_map(|option| {
1520 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1521 Some(option.id.clone())
1522 } else {
1523 None
1524 }
1525 }) {
1526 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1527 return Ok(async {
1528 acp::RequestPermissionOutcome::Selected {
1529 option_id: allow_once_option,
1530 }
1531 }
1532 .boxed());
1533 }
1534 }
1535
1536 let status = ToolCallStatus::WaitingForConfirmation {
1537 options,
1538 respond_tx: tx,
1539 };
1540
1541 self.upsert_tool_call_inner(tool_call, status, cx)?;
1542 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1543
1544 let fut = async {
1545 match rx.await {
1546 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1547 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1548 }
1549 }
1550 .boxed();
1551
1552 Ok(fut)
1553 }
1554
1555 pub fn authorize_tool_call(
1556 &mut self,
1557 id: acp::ToolCallId,
1558 option_id: acp::PermissionOptionId,
1559 option_kind: acp::PermissionOptionKind,
1560 cx: &mut Context<Self>,
1561 ) {
1562 let Some((ix, call)) = self.tool_call_mut(&id) else {
1563 return;
1564 };
1565
1566 let new_status = match option_kind {
1567 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1568 ToolCallStatus::Rejected
1569 }
1570 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1571 ToolCallStatus::InProgress
1572 }
1573 };
1574
1575 let curr_status = mem::replace(&mut call.status, new_status);
1576
1577 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1578 respond_tx.send(option_id).log_err();
1579 } else if cfg!(debug_assertions) {
1580 panic!("tried to authorize an already authorized tool call");
1581 }
1582
1583 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1584 }
1585
1586 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1587 let mut first_tool_call = None;
1588
1589 for entry in self.entries.iter().rev() {
1590 match &entry {
1591 AgentThreadEntry::ToolCall(call) => {
1592 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1593 first_tool_call = Some(call);
1594 } else {
1595 continue;
1596 }
1597 }
1598 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1599 // Reached the beginning of the turn.
1600 // If we had pending permission requests in the previous turn, they have been cancelled.
1601 break;
1602 }
1603 }
1604 }
1605
1606 first_tool_call
1607 }
1608
1609 pub fn plan(&self) -> &Plan {
1610 &self.plan
1611 }
1612
    /// Replaces the plan's entries with those in `request` and notifies observers.
    ///
    /// Entries are matched by position: existing entries are updated in place,
    /// extra incoming entries are appended, and surplus old entries are dropped.
    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
        let new_entries_len = request.entries.len();
        let mut new_entries = request.entries.into_iter();

        // Reuse existing markdown to prevent flickering
        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
            let PlanEntry {
                content,
                priority,
                status,
            } = old;
            content.update(cx, |old, cx| {
                old.replace(new.content, cx);
            });
            *priority = new.priority;
            *status = new.status;
        }
        // Whatever `zip` did not consume is new: append it.
        for new in new_entries {
            self.plan.entries.push(PlanEntry::from_acp(new, cx))
        }
        // Drop old entries beyond the new length (no-op when the plan grew).
        self.plan.entries.truncate(new_entries_len);

        cx.notify();
    }
1637
1638 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1639 self.plan
1640 .entries
1641 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1642 cx.notify();
1643 }
1644
1645 #[cfg(any(test, feature = "test-support"))]
1646 pub fn send_raw(
1647 &mut self,
1648 message: &str,
1649 cx: &mut Context<Self>,
1650 ) -> BoxFuture<'static, Result<()>> {
1651 self.send(
1652 vec![acp::ContentBlock::Text(acp::TextContent {
1653 text: message.to_string(),
1654 annotations: None,
1655 meta: None,
1656 })],
1657 cx,
1658 )
1659 }
1660
    /// Sends `message` as a new user prompt and runs the resulting turn (see
    /// `run_turn`).
    ///
    /// Once the turn starts, a user message entry is pushed. A git checkpoint is
    /// then taken and attached to that message with `show: false` before the
    /// prompt is forwarded to the connection. A failure to take the checkpoint is
    /// logged and leaves the message without one.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
            meta: None,
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are what `truncate` (used by `rewind`) operates on, so only
        // mint one when the connection supports truncation.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                // NOTE(review): assumes the message pushed above is still the last
                // user message at this point.
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1716
1717 pub fn can_resume(&self, cx: &App) -> bool {
1718 self.connection.resume(&self.session_id, cx).is_some()
1719 }
1720
1721 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1722 self.run_turn(cx, async move |this, cx| {
1723 this.update(cx, |this, cx| {
1724 this.connection
1725 .resume(&this.session_id, cx)
1726 .map(|resume| resume.run(cx))
1727 })?
1728 .context("resuming a session is not supported")?
1729 .await
1730 })
1731 }
1732
1733 fn run_turn(
1734 &mut self,
1735 cx: &mut Context<Self>,
1736 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1737 ) -> BoxFuture<'static, Result<()>> {
1738 self.clear_completed_plan_entries(cx);
1739
1740 let (tx, rx) = oneshot::channel();
1741 let cancel_task = self.cancel(cx);
1742
1743 self.send_task = Some(cx.spawn(async move |this, cx| {
1744 cancel_task.await;
1745 tx.send(f(this, cx).await).ok();
1746 }));
1747
1748 cx.spawn(async move |this, cx| {
1749 let response = rx.await;
1750
1751 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1752 .await?;
1753
1754 this.update(cx, |this, cx| {
1755 this.project
1756 .update(cx, |project, cx| project.set_agent_location(None, cx));
1757 match response {
1758 Ok(Err(e)) => {
1759 this.send_task.take();
1760 cx.emit(AcpThreadEvent::Error);
1761 Err(e)
1762 }
1763 result => {
1764 let canceled = matches!(
1765 result,
1766 Ok(Ok(acp::PromptResponse {
1767 stop_reason: acp::StopReason::Cancelled,
1768 meta: None,
1769 }))
1770 );
1771
1772 // We only take the task if the current prompt wasn't canceled.
1773 //
1774 // This prompt may have been canceled because another one was sent
1775 // while it was still generating. In these cases, dropping `send_task`
1776 // would cause the next generation to be canceled.
1777 if !canceled {
1778 this.send_task.take();
1779 }
1780
1781 // Handle refusal - distinguish between user prompt and tool call refusals
1782 if let Ok(Ok(acp::PromptResponse {
1783 stop_reason: acp::StopReason::Refusal,
1784 meta: _,
1785 })) = result
1786 {
1787 if let Some((user_msg_ix, _)) = this.last_user_message() {
1788 // Check if there's a completed tool call with results after the last user message
1789 // This indicates the refusal is in response to tool output, not the user's prompt
1790 let has_completed_tool_call_after_user_msg =
1791 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1792 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1793 // Check if the tool call has completed and has output
1794 matches!(tool_call.status, ToolCallStatus::Completed)
1795 && tool_call.raw_output.is_some()
1796 } else {
1797 false
1798 }
1799 });
1800
1801 if has_completed_tool_call_after_user_msg {
1802 // Refusal is due to tool output - don't truncate, just notify
1803 // The model refused based on what the tool returned
1804 cx.emit(AcpThreadEvent::Refusal);
1805 } else {
1806 // User prompt was refused - truncate back to before the user message
1807 let range = user_msg_ix..this.entries.len();
1808 if range.start < range.end {
1809 this.entries.truncate(user_msg_ix);
1810 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1811 }
1812 cx.emit(AcpThreadEvent::Refusal);
1813 }
1814 } else {
1815 // No user message found, treat as general refusal
1816 cx.emit(AcpThreadEvent::Refusal);
1817 }
1818 }
1819
1820 cx.emit(AcpThreadEvent::Stopped);
1821 Ok(())
1822 }
1823 }
1824 })?
1825 })
1826 .boxed()
1827 }
1828
1829 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1830 let Some(send_task) = self.send_task.take() else {
1831 return Task::ready(());
1832 };
1833
1834 for entry in self.entries.iter_mut() {
1835 if let AgentThreadEntry::ToolCall(call) = entry {
1836 let cancel = matches!(
1837 call.status,
1838 ToolCallStatus::Pending
1839 | ToolCallStatus::WaitingForConfirmation { .. }
1840 | ToolCallStatus::InProgress
1841 );
1842
1843 if cancel {
1844 call.status = ToolCallStatus::Canceled;
1845 }
1846 }
1847 }
1848
1849 self.connection.cancel(&self.session_id, cx);
1850
1851 // Wait for the send task to complete
1852 cx.foreground_executor().spawn(send_task)
1853 }
1854
1855 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1856 pub fn restore_checkpoint(
1857 &mut self,
1858 id: UserMessageId,
1859 cx: &mut Context<Self>,
1860 ) -> Task<Result<()>> {
1861 let Some((_, message)) = self.user_message_mut(&id) else {
1862 return Task::ready(Err(anyhow!("message not found")));
1863 };
1864
1865 let checkpoint = message
1866 .checkpoint
1867 .as_ref()
1868 .map(|c| c.git_checkpoint.clone());
1869
1870 // Cancel any in-progress generation before restoring
1871 let cancel_task = self.cancel(cx);
1872 let rewind = self.rewind(id.clone(), cx);
1873 let git_store = self.project.read(cx).git_store().clone();
1874
1875 cx.spawn(async move |_, cx| {
1876 cancel_task.await;
1877 rewind.await?;
1878 if let Some(checkpoint) = checkpoint {
1879 git_store
1880 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1881 .await?;
1882 }
1883
1884 Ok(())
1885 })
1886 }
1887
    /// Rewinds this thread to before the user message `id`.
    ///
    /// - The connection is asked to truncate the session at `id`.
    /// - That message and every later entry are removed locally.
    /// - Terminals owned by the removed entries are killed and unregistered.
    /// - All action-log changes are rejected. This happens even if `id` is no
    ///   longer found locally.
    ///
    /// Unlike `restore_checkpoint`, this method does not restore from git.
    ///
    /// # Errors
    /// - Fails with "not supported" when the connection cannot truncate.
    /// - Otherwise propagates truncation and entity-update errors.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
1929
1930 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1931 let git_store = self.project.read(cx).git_store().clone();
1932
1933 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1934 if let Some(checkpoint) = message.checkpoint.as_ref() {
1935 checkpoint.git_checkpoint.clone()
1936 } else {
1937 return Task::ready(Ok(()));
1938 }
1939 } else {
1940 return Task::ready(Ok(()));
1941 };
1942
1943 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1944 cx.spawn(async move |this, cx| {
1945 let new_checkpoint = new_checkpoint
1946 .await
1947 .context("failed to get new checkpoint")
1948 .log_err();
1949 if let Some(new_checkpoint) = new_checkpoint {
1950 let equal = git_store
1951 .update(cx, |git, cx| {
1952 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1953 })?
1954 .await
1955 .unwrap_or(true);
1956 this.update(cx, |this, cx| {
1957 let (ix, message) = this.last_user_message().context("no user message")?;
1958 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1959 checkpoint.show = !equal;
1960 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1961 anyhow::Ok(())
1962 })??;
1963 }
1964
1965 Ok(())
1966 })
1967 }
1968
1969 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1970 self.entries
1971 .iter_mut()
1972 .enumerate()
1973 .rev()
1974 .find_map(|(ix, entry)| {
1975 if let AgentThreadEntry::UserMessage(message) = entry {
1976 Some((ix, message))
1977 } else {
1978 None
1979 }
1980 })
1981 }
1982
1983 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1984 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1985 if let AgentThreadEntry::UserMessage(message) = entry {
1986 if message.id.as_ref() == Some(id) {
1987 Some((ix, message))
1988 } else {
1989 None
1990 }
1991 } else {
1992 None
1993 }
1994 })
1995 }
1996
    /// Reads up to `limit` lines of the file at absolute `path`, starting at the
    /// 1-based `line`.
    ///
    /// `line` defaults to the first line and `limit` to the rest of the file.
    ///
    /// Which snapshot is read:
    /// - If `reuse_shared_snapshot` is set and `shared_buffers` already holds a
    ///   snapshot for the buffer, that snapshot is used.
    /// - Otherwise the read is reported to the action log, and a fresh snapshot
    ///   is taken and cached.
    ///
    /// The project's agent location is moved to the start of the returned range.
    ///
    /// # Errors
    /// - `resource_not_found` if `path` cannot be mapped to a project path.
    /// - `invalid_params` if `line` lies beyond the end of the file.
    /// - Otherwise propagates failures from opening the buffer or updating entities.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                // Cache it so later reads/writes can reuse this exact snapshot.
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().with_data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            let start = snapshot.anchor_before(start_position);
            // `saturating_add` keeps the default `u32::MAX` limit from overflowing.
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2076
    /// Replaces the contents of the file at absolute `path` with `content` and
    /// saves it.
    ///
    /// Steps:
    /// - Diff `content` against the buffer snapshot, preferring the one cached in
    ///   `shared_buffers` (presumably so the diff is computed against what the
    ///   agent last read).
    /// - Move the agent location to the end of the last edit.
    /// - Record the read and the edit in the action log.
    /// - Format the buffer when the language settings enable format-on-save.
    ///   Formatting failures are logged and ignored.
    /// - Save the buffer.
    ///
    /// # Errors
    /// Fails with "invalid path" when `path` cannot be mapped to a project path,
    /// and propagates failures from opening, updating, or saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the minimal edit set off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Formatting may have changed the buffer again; record that too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2173
    /// Spawns `command` with `args` in a new terminal and registers it with this
    /// thread under a freshly generated [`acp::TerminalId`].
    ///
    /// Environment:
    /// - starts from the project's directory environment for `cwd`, or an empty
    ///   map when `cwd` is `None`;
    /// - `PAGER` is set to an empty string;
    /// - `extra_env` is applied on top.
    ///
    /// Shell: the remote host's default shell when there is one, otherwise the
    /// local default shell (preferring bash), with stdin redirected to
    /// `/dev/null`. `output_byte_limit` is forwarded to [`Terminal::new`].
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.environment().update(cx, |env, cx| {
                    env.directory_environment(dir.as_path().into(), cx)
                })
            }),
            None => Task::ready(None).shared(),
        };
        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Disables paging for `git` and hopefully other commands
            env.insert("PAGER".into(), "".into());
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();
        let is_windows = project.read(cx).path_style(cx).is_windows();

        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                // Prefer the remote host's default shell; fall back to the local one.
                let shell = project
                    .update(cx, |project, cx| {
                        project
                            .remote_client()
                            .and_then(|r| r.read(cx).default_system_shell())
                    })?
                    .unwrap_or_else(|| get_default_system_shell_preferring_bash());
                let (task_command, task_args) =
                    ShellBuilder::new(&Shell::Program(shell), is_windows)
                        .redirect_stdin_to_dev_null()
                        .build(Some(command.clone()), &args);
                let terminal = project
                    .update(cx, |project, cx| {
                        project.create_terminal_task(
                            task::SpawnInTerminal {
                                command: Some(task_command),
                                args: task_args,
                                cwd: cwd.clone(),
                                env,
                                ..Default::default()
                            },
                            cx,
                        )
                    })?
                    .await?;

                cx.new(|cx| {
                    Terminal::new(
                        terminal_id,
                        &format!("{} {}", command, args.join(" ")),
                        cwd,
                        output_byte_limit.map(|l| l as usize),
                        terminal,
                        language_registry,
                        cx,
                    )
                })
            }
        });

        // Register the terminal only once it was created successfully.
        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2258
2259 pub fn kill_terminal(
2260 &mut self,
2261 terminal_id: acp::TerminalId,
2262 cx: &mut Context<Self>,
2263 ) -> Result<()> {
2264 self.terminals
2265 .get(&terminal_id)
2266 .context("Terminal not found")?
2267 .update(cx, |terminal, cx| {
2268 terminal.kill(cx);
2269 });
2270
2271 Ok(())
2272 }
2273
2274 pub fn release_terminal(
2275 &mut self,
2276 terminal_id: acp::TerminalId,
2277 cx: &mut Context<Self>,
2278 ) -> Result<()> {
2279 self.terminals
2280 .remove(&terminal_id)
2281 .context("Terminal not found")?
2282 .update(cx, |terminal, cx| {
2283 terminal.kill(cx);
2284 });
2285
2286 Ok(())
2287 }
2288
2289 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2290 self.terminals
2291 .get(&terminal_id)
2292 .context("Terminal not found")
2293 .cloned()
2294 }
2295
2296 pub fn to_markdown(&self, cx: &App) -> String {
2297 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2298 }
2299
2300 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2301 cx.emit(AcpThreadEvent::LoadError(error));
2302 }
2303
2304 pub fn register_terminal_created(
2305 &mut self,
2306 terminal_id: acp::TerminalId,
2307 command_label: String,
2308 working_dir: Option<PathBuf>,
2309 output_byte_limit: Option<u64>,
2310 terminal: Entity<::terminal::Terminal>,
2311 cx: &mut Context<Self>,
2312 ) -> Entity<Terminal> {
2313 let language_registry = self.project.read(cx).languages().clone();
2314
2315 let entity = cx.new(|cx| {
2316 Terminal::new(
2317 terminal_id.clone(),
2318 &command_label,
2319 working_dir.clone(),
2320 output_byte_limit.map(|l| l as usize),
2321 terminal,
2322 language_registry,
2323 cx,
2324 )
2325 });
2326 self.terminals.insert(terminal_id.clone(), entity.clone());
2327 entity
2328 }
2329}
2330
2331fn markdown_for_raw_output(
2332 raw_output: &serde_json::Value,
2333 language_registry: &Arc<LanguageRegistry>,
2334 cx: &mut App,
2335) -> Option<Entity<Markdown>> {
2336 match raw_output {
2337 serde_json::Value::Null => None,
2338 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2339 Markdown::new(
2340 value.to_string().into(),
2341 Some(language_registry.clone()),
2342 None,
2343 cx,
2344 )
2345 })),
2346 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2347 Markdown::new(
2348 value.to_string().into(),
2349 Some(language_registry.clone()),
2350 None,
2351 cx,
2352 )
2353 })),
2354 serde_json::Value::String(value) => Some(cx.new(|cx| {
2355 Markdown::new(
2356 value.clone().into(),
2357 Some(language_registry.clone()),
2358 None,
2359 cx,
2360 )
2361 })),
2362 value => Some(cx.new(|cx| {
2363 Markdown::new(
2364 format!("```json\n{}\n```", value).into(),
2365 Some(language_registry.clone()),
2366 None,
2367 cx,
2368 )
2369 })),
2370 }
2371}
2372
2373#[cfg(test)]
2374mod tests {
2375 use super::*;
2376 use anyhow::anyhow;
2377 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2378 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2379 use indoc::indoc;
2380 use project::{FakeFs, Fs};
2381 use rand::{distr, prelude::*};
2382 use serde_json::json;
2383 use settings::SettingsStore;
2384 use smol::stream::StreamExt as _;
2385 use std::{
2386 any::Any,
2387 cell::RefCell,
2388 path::Path,
2389 rc::Rc,
2390 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2391 time::Duration,
2392 };
2393 use util::path;
2394
2395 fn init_test(cx: &mut TestAppContext) {
2396 env_logger::try_init().ok();
2397 cx.update(|cx| {
2398 let settings_store = SettingsStore::test(cx);
2399 cx.set_global(settings_store);
2400 });
2401 }
2402
2403 #[gpui::test]
2404 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2405 init_test(cx);
2406
2407 let fs = FakeFs::new(cx.executor());
2408 let project = Project::test(fs, [], cx).await;
2409 let connection = Rc::new(FakeAgentConnection::new());
2410 let thread = cx
2411 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2412 .await
2413 .unwrap();
2414
2415 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2416
2417 // Send Output BEFORE Created - should be buffered by acp_thread
2418 thread.update(cx, |thread, cx| {
2419 thread.on_terminal_provider_event(
2420 TerminalProviderEvent::Output {
2421 terminal_id: terminal_id.clone(),
2422 data: b"hello buffered".to_vec(),
2423 },
2424 cx,
2425 );
2426 });
2427
2428 // Create a display-only terminal and then send Created
2429 let lower = cx.new(|cx| {
2430 let builder = ::terminal::TerminalBuilder::new_display_only(
2431 ::terminal::terminal_settings::CursorShape::default(),
2432 ::terminal::terminal_settings::AlternateScroll::On,
2433 None,
2434 0,
2435 )
2436 .unwrap();
2437 builder.subscribe(cx)
2438 });
2439
2440 thread.update(cx, |thread, cx| {
2441 thread.on_terminal_provider_event(
2442 TerminalProviderEvent::Created {
2443 terminal_id: terminal_id.clone(),
2444 label: "Buffered Test".to_string(),
2445 cwd: None,
2446 output_byte_limit: None,
2447 terminal: lower.clone(),
2448 },
2449 cx,
2450 );
2451 });
2452
2453 // After Created, buffered Output should have been flushed into the renderer
2454 let content = thread.read_with(cx, |thread, cx| {
2455 let term = thread.terminal(terminal_id.clone()).unwrap();
2456 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2457 });
2458
2459 assert!(
2460 content.contains("hello buffered"),
2461 "expected buffered output to render, got: {content}"
2462 );
2463 }
2464
2465 #[gpui::test]
2466 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2467 init_test(cx);
2468
2469 let fs = FakeFs::new(cx.executor());
2470 let project = Project::test(fs, [], cx).await;
2471 let connection = Rc::new(FakeAgentConnection::new());
2472 let thread = cx
2473 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2474 .await
2475 .unwrap();
2476
2477 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2478
2479 // Send Output BEFORE Created
2480 thread.update(cx, |thread, cx| {
2481 thread.on_terminal_provider_event(
2482 TerminalProviderEvent::Output {
2483 terminal_id: terminal_id.clone(),
2484 data: b"pre-exit data".to_vec(),
2485 },
2486 cx,
2487 );
2488 });
2489
2490 // Send Exit BEFORE Created
2491 thread.update(cx, |thread, cx| {
2492 thread.on_terminal_provider_event(
2493 TerminalProviderEvent::Exit {
2494 terminal_id: terminal_id.clone(),
2495 status: acp::TerminalExitStatus {
2496 exit_code: Some(0),
2497 signal: None,
2498 meta: None,
2499 },
2500 },
2501 cx,
2502 );
2503 });
2504
2505 // Now create a display-only lower-level terminal and send Created
2506 let lower = cx.new(|cx| {
2507 let builder = ::terminal::TerminalBuilder::new_display_only(
2508 ::terminal::terminal_settings::CursorShape::default(),
2509 ::terminal::terminal_settings::AlternateScroll::On,
2510 None,
2511 0,
2512 )
2513 .unwrap();
2514 builder.subscribe(cx)
2515 });
2516
2517 thread.update(cx, |thread, cx| {
2518 thread.on_terminal_provider_event(
2519 TerminalProviderEvent::Created {
2520 terminal_id: terminal_id.clone(),
2521 label: "Buffered Exit Test".to_string(),
2522 cwd: None,
2523 output_byte_limit: None,
2524 terminal: lower.clone(),
2525 },
2526 cx,
2527 );
2528 });
2529
2530 // Output should be present after Created (flushed from buffer)
2531 let content = thread.read_with(cx, |thread, cx| {
2532 let term = thread.terminal(terminal_id.clone()).unwrap();
2533 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2534 });
2535
2536 assert!(
2537 content.contains("pre-exit data"),
2538 "expected pre-exit data to render, got: {content}"
2539 );
2540 }
2541
2542 #[gpui::test]
2543 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2544 init_test(cx);
2545
2546 let fs = FakeFs::new(cx.executor());
2547 let project = Project::test(fs, [], cx).await;
2548 let connection = Rc::new(FakeAgentConnection::new());
2549 let thread = cx
2550 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2551 .await
2552 .unwrap();
2553
2554 // Test creating a new user message
2555 thread.update(cx, |thread, cx| {
2556 thread.push_user_content_block(
2557 None,
2558 acp::ContentBlock::Text(acp::TextContent {
2559 annotations: None,
2560 text: "Hello, ".to_string(),
2561 meta: None,
2562 }),
2563 cx,
2564 );
2565 });
2566
2567 thread.update(cx, |thread, cx| {
2568 assert_eq!(thread.entries.len(), 1);
2569 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2570 assert_eq!(user_msg.id, None);
2571 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2572 } else {
2573 panic!("Expected UserMessage");
2574 }
2575 });
2576
2577 // Test appending to existing user message
2578 let message_1_id = UserMessageId::new();
2579 thread.update(cx, |thread, cx| {
2580 thread.push_user_content_block(
2581 Some(message_1_id.clone()),
2582 acp::ContentBlock::Text(acp::TextContent {
2583 annotations: None,
2584 text: "world!".to_string(),
2585 meta: None,
2586 }),
2587 cx,
2588 );
2589 });
2590
2591 thread.update(cx, |thread, cx| {
2592 assert_eq!(thread.entries.len(), 1);
2593 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2594 assert_eq!(user_msg.id, Some(message_1_id));
2595 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2596 } else {
2597 panic!("Expected UserMessage");
2598 }
2599 });
2600
2601 // Test creating new user message after assistant message
2602 thread.update(cx, |thread, cx| {
2603 thread.push_assistant_content_block(
2604 acp::ContentBlock::Text(acp::TextContent {
2605 annotations: None,
2606 text: "Assistant response".to_string(),
2607 meta: None,
2608 }),
2609 false,
2610 cx,
2611 );
2612 });
2613
2614 let message_2_id = UserMessageId::new();
2615 thread.update(cx, |thread, cx| {
2616 thread.push_user_content_block(
2617 Some(message_2_id.clone()),
2618 acp::ContentBlock::Text(acp::TextContent {
2619 annotations: None,
2620 text: "New user message".to_string(),
2621 meta: None,
2622 }),
2623 cx,
2624 );
2625 });
2626
2627 thread.update(cx, |thread, cx| {
2628 assert_eq!(thread.entries.len(), 3);
2629 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2630 assert_eq!(user_msg.id, Some(message_2_id));
2631 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2632 } else {
2633 panic!("Expected UserMessage at index 2");
2634 }
2635 });
2636 }
2637
2638 #[gpui::test]
2639 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2640 init_test(cx);
2641
2642 let fs = FakeFs::new(cx.executor());
2643 let project = Project::test(fs, [], cx).await;
2644 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2645 |_, thread, mut cx| {
2646 async move {
2647 thread.update(&mut cx, |thread, cx| {
2648 thread
2649 .handle_session_update(
2650 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk {
2651 content: "Thinking ".into(),
2652 meta: None,
2653 }),
2654 cx,
2655 )
2656 .unwrap();
2657 thread
2658 .handle_session_update(
2659 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk {
2660 content: "hard!".into(),
2661 meta: None,
2662 }),
2663 cx,
2664 )
2665 .unwrap();
2666 })?;
2667 Ok(acp::PromptResponse {
2668 stop_reason: acp::StopReason::EndTurn,
2669 meta: None,
2670 })
2671 }
2672 .boxed_local()
2673 },
2674 ));
2675
2676 let thread = cx
2677 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2678 .await
2679 .unwrap();
2680
2681 thread
2682 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2683 .await
2684 .unwrap();
2685
2686 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2687 assert_eq!(
2688 output,
2689 indoc! {r#"
2690 ## User
2691
2692 Hello from Zed!
2693
2694 ## Assistant
2695
2696 <thinking>
2697 Thinking hard!
2698 </thinking>
2699
2700 "#}
2701 );
2702 }
2703
    /// The agent reads `/tmp/foo` and then writes back a longer version.
    ///
    /// Between the read and the write, the test inserts a `zero` line at the
    /// top of an open buffer for the same file.
    ///
    /// The test asserts that both the buffer and the file on disk end up
    /// containing the user's line and the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // One-shot signal from the fake agent to the test body, sent right after
        // the agent's read completes. Wrapped in Rc<RefCell<Option<_>>> so the
        // (reusable) handler closure can take the sender once.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // The content written here does not contain the `zero` line
                    // that the test body inserts into the buffer.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file in a buffer so the user edit below can be applied to it.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then make the user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both edits survive, in the buffer and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2783
2784 #[gpui::test]
2785 async fn test_reading_from_line(cx: &mut TestAppContext) {
2786 init_test(cx);
2787
2788 let fs = FakeFs::new(cx.executor());
2789 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2790 .await;
2791 let project = Project::test(fs.clone(), [], cx).await;
2792 project
2793 .update(cx, |project, cx| {
2794 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2795 })
2796 .await
2797 .unwrap();
2798
2799 let connection = Rc::new(FakeAgentConnection::new());
2800
2801 let thread = cx
2802 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2803 .await
2804 .unwrap();
2805
2806 // Whole file
2807 let content = thread
2808 .update(cx, |thread, cx| {
2809 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2810 })
2811 .await
2812 .unwrap();
2813
2814 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2815
2816 // Only start line
2817 let content = thread
2818 .update(cx, |thread, cx| {
2819 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2820 })
2821 .await
2822 .unwrap();
2823
2824 assert_eq!(content, "three\nfour\n");
2825
2826 // Only limit
2827 let content = thread
2828 .update(cx, |thread, cx| {
2829 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2830 })
2831 .await
2832 .unwrap();
2833
2834 assert_eq!(content, "one\ntwo\n");
2835
2836 // Range
2837 let content = thread
2838 .update(cx, |thread, cx| {
2839 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2840 })
2841 .await
2842 .unwrap();
2843
2844 assert_eq!(content, "two\nthree\n");
2845
2846 // Invalid
2847 let err = thread
2848 .update(cx, |thread, cx| {
2849 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2850 })
2851 .await
2852 .unwrap_err();
2853
2854 assert_eq!(
2855 err.to_string(),
2856 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2857 );
2858 }
2859
2860 #[gpui::test]
2861 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2862 init_test(cx);
2863
2864 let fs = FakeFs::new(cx.executor());
2865 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2866 let project = Project::test(fs.clone(), [], cx).await;
2867 project
2868 .update(cx, |project, cx| {
2869 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2870 })
2871 .await
2872 .unwrap();
2873
2874 let connection = Rc::new(FakeAgentConnection::new());
2875
2876 let thread = cx
2877 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2878 .await
2879 .unwrap();
2880
2881 // Whole file
2882 let content = thread
2883 .update(cx, |thread, cx| {
2884 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2885 })
2886 .await
2887 .unwrap();
2888
2889 assert_eq!(content, "");
2890
2891 // Only start line
2892 let content = thread
2893 .update(cx, |thread, cx| {
2894 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2895 })
2896 .await
2897 .unwrap();
2898
2899 assert_eq!(content, "");
2900
2901 // Only limit
2902 let content = thread
2903 .update(cx, |thread, cx| {
2904 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2905 })
2906 .await
2907 .unwrap();
2908
2909 assert_eq!(content, "");
2910
2911 // Range
2912 let content = thread
2913 .update(cx, |thread, cx| {
2914 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2915 })
2916 .await
2917 .unwrap();
2918
2919 assert_eq!(content, "");
2920
2921 // Invalid
2922 let err = thread
2923 .update(cx, |thread, cx| {
2924 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2925 })
2926 .await
2927 .unwrap_err();
2928
2929 assert_eq!(
2930 err.to_string(),
2931 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2932 );
2933 }
2934 #[gpui::test]
2935 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2936 init_test(cx);
2937
2938 let fs = FakeFs::new(cx.executor());
2939 fs.insert_tree(path!("/tmp"), json!({})).await;
2940 let project = Project::test(fs.clone(), [], cx).await;
2941 project
2942 .update(cx, |project, cx| {
2943 project.find_or_create_worktree(path!("/tmp"), true, cx)
2944 })
2945 .await
2946 .unwrap();
2947
2948 let connection = Rc::new(FakeAgentConnection::new());
2949
2950 let thread = cx
2951 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2952 .await
2953 .unwrap();
2954
2955 // Out of project file
2956 let err = thread
2957 .update(cx, |thread, cx| {
2958 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2959 })
2960 .await
2961 .unwrap_err();
2962
2963 assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2964 }
2965
    /// A tool call set to `Canceled` by `cancel` still accepts a later
    /// `Completed` status update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // For each prompt, the fake agent reports one in-progress tool call and
        // then ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // The tool call is the second entry, after the user message.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Cancelling the turn marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion update from the agent for the same id.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                        meta: None,
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // The late update is applied on top of the canceled status.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3068
3069 #[gpui::test]
3070 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3071 init_test(cx);
3072 let fs = FakeFs::new(cx.background_executor.clone());
3073 fs.insert_tree(path!("/test"), json!({})).await;
3074 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3075
3076 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3077 move |_, thread, mut cx| {
3078 async move {
3079 thread
3080 .update(&mut cx, |thread, cx| {
3081 thread.handle_session_update(
3082 acp::SessionUpdate::ToolCall(acp::ToolCall {
3083 id: acp::ToolCallId("test".into()),
3084 title: "Label".into(),
3085 kind: acp::ToolKind::Edit,
3086 status: acp::ToolCallStatus::Completed,
3087 content: vec![acp::ToolCallContent::Diff {
3088 diff: acp::Diff {
3089 path: "/test/test.txt".into(),
3090 old_text: None,
3091 new_text: "foo".into(),
3092 meta: None,
3093 },
3094 }],
3095 locations: vec![],
3096 raw_input: None,
3097 raw_output: None,
3098 meta: None,
3099 }),
3100 cx,
3101 )
3102 })
3103 .unwrap()
3104 .unwrap();
3105 Ok(acp::PromptResponse {
3106 stop_reason: acp::StopReason::EndTurn,
3107 meta: None,
3108 })
3109 }
3110 .boxed_local()
3111 }
3112 }));
3113
3114 let thread = cx
3115 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3116 .await
3117 .unwrap();
3118
3119 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3120 .await
3121 .unwrap();
3122
3123 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3124 }
3125
    /// Exercises user-message checkpoints end to end.
    ///
    /// - A prompt whose turn creates a file is rendered as `## User (checkpoint)`.
    /// - A prompt whose turn creates no file is rendered as a plain `## User`.
    /// - Restoring the checkpoint of a message truncates the thread from that
    ///   message onward and removes the files created since.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt writes a new empty file
        // `/test/file-N` before the fake agent replies. The reply echoes the
        // prompt text in upper case.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk {
                                    content: content.text.to_uppercase().into(),
                                    meta: None,
                                }),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First prompt changes the file system, so its message shows a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second prompt also changes the file system.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" message; restoring it drops that message and
        // everything after it, and removes `file-1`, which its turn created.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3307
3308 #[gpui::test]
3309 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3310 use std::sync::atomic::AtomicUsize;
3311 init_test(cx);
3312
3313 let fs = FakeFs::new(cx.executor());
3314 let project = Project::test(fs, None, cx).await;
3315
3316 // Create a connection that simulates refusal after tool result
3317 let prompt_count = Arc::new(AtomicUsize::new(0));
3318 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3319 let prompt_count = prompt_count.clone();
3320 move |_request, thread, mut cx| {
3321 let count = prompt_count.fetch_add(1, SeqCst);
3322 async move {
3323 if count == 0 {
3324 // First prompt: Generate a tool call with result
3325 thread.update(&mut cx, |thread, cx| {
3326 thread
3327 .handle_session_update(
3328 acp::SessionUpdate::ToolCall(acp::ToolCall {
3329 id: acp::ToolCallId("tool1".into()),
3330 title: "Test Tool".into(),
3331 kind: acp::ToolKind::Fetch,
3332 status: acp::ToolCallStatus::Completed,
3333 content: vec![],
3334 locations: vec![],
3335 raw_input: Some(serde_json::json!({"query": "test"})),
3336 raw_output: Some(
3337 serde_json::json!({"result": "inappropriate content"}),
3338 ),
3339 meta: None,
3340 }),
3341 cx,
3342 )
3343 .unwrap();
3344 })?;
3345
3346 // Now return refusal because of the tool result
3347 Ok(acp::PromptResponse {
3348 stop_reason: acp::StopReason::Refusal,
3349 meta: None,
3350 })
3351 } else {
3352 Ok(acp::PromptResponse {
3353 stop_reason: acp::StopReason::EndTurn,
3354 meta: None,
3355 })
3356 }
3357 }
3358 .boxed_local()
3359 }
3360 }));
3361
3362 let thread = cx
3363 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3364 .await
3365 .unwrap();
3366
3367 // Track if we see a Refusal event
3368 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3369 let saw_refusal_event_captured = saw_refusal_event.clone();
3370 thread.update(cx, |_thread, cx| {
3371 cx.subscribe(
3372 &thread,
3373 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3374 if matches!(event, AcpThreadEvent::Refusal) {
3375 *saw_refusal_event_captured.lock().unwrap() = true;
3376 }
3377 },
3378 )
3379 .detach();
3380 });
3381
3382 // Send a user message - this will trigger tool call and then refusal
3383 let send_task = thread.update(cx, |thread, cx| {
3384 thread.send(
3385 vec![acp::ContentBlock::Text(acp::TextContent {
3386 text: "Hello".into(),
3387 annotations: None,
3388 meta: None,
3389 })],
3390 cx,
3391 )
3392 });
3393 cx.background_executor.spawn(send_task).detach();
3394 cx.run_until_parked();
3395
3396 // Verify that:
3397 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3398 // 2. The user message was NOT truncated
3399 assert!(
3400 *saw_refusal_event.lock().unwrap(),
3401 "Refusal event should be emitted for tool result refusals"
3402 );
3403
3404 thread.read_with(cx, |thread, _| {
3405 let entries = thread.entries();
3406 assert!(entries.len() >= 2, "Should have user message and tool call");
3407
3408 // Verify user message is still there
3409 assert!(
3410 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3411 "User message should not be truncated"
3412 );
3413
3414 // Verify tool call is there with result
3415 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3416 assert!(
3417 tool_call.raw_output.is_some(),
3418 "Tool call should have output"
3419 );
3420 } else {
3421 panic!("Expected tool call at index 1");
3422 }
3423 });
3424 }
3425
3426 #[gpui::test]
3427 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3428 init_test(cx);
3429
3430 let fs = FakeFs::new(cx.executor());
3431 let project = Project::test(fs, None, cx).await;
3432
3433 let refuse_next = Arc::new(AtomicBool::new(false));
3434 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3435 let refuse_next = refuse_next.clone();
3436 move |_request, _thread, _cx| {
3437 if refuse_next.load(SeqCst) {
3438 async move {
3439 Ok(acp::PromptResponse {
3440 stop_reason: acp::StopReason::Refusal,
3441 meta: None,
3442 })
3443 }
3444 .boxed_local()
3445 } else {
3446 async move {
3447 Ok(acp::PromptResponse {
3448 stop_reason: acp::StopReason::EndTurn,
3449 meta: None,
3450 })
3451 }
3452 .boxed_local()
3453 }
3454 }
3455 }));
3456
3457 let thread = cx
3458 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3459 .await
3460 .unwrap();
3461
3462 // Track if we see a Refusal event
3463 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3464 let saw_refusal_event_captured = saw_refusal_event.clone();
3465 thread.update(cx, |_thread, cx| {
3466 cx.subscribe(
3467 &thread,
3468 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3469 if matches!(event, AcpThreadEvent::Refusal) {
3470 *saw_refusal_event_captured.lock().unwrap() = true;
3471 }
3472 },
3473 )
3474 .detach();
3475 });
3476
3477 // Send a message that will be refused
3478 refuse_next.store(true, SeqCst);
3479 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3480 .await
3481 .unwrap();
3482
3483 // Verify that a Refusal event WAS emitted for user prompt refusal
3484 assert!(
3485 *saw_refusal_event.lock().unwrap(),
3486 "Refusal event should be emitted for user prompt refusals"
3487 );
3488
3489 // Verify the message was truncated (user prompt refusal)
3490 thread.read_with(cx, |thread, cx| {
3491 assert_eq!(thread.to_markdown(cx), "");
3492 });
3493 }
3494
3495 #[gpui::test]
3496 async fn test_refusal(cx: &mut TestAppContext) {
3497 init_test(cx);
3498 let fs = FakeFs::new(cx.background_executor.clone());
3499 fs.insert_tree(path!("/"), json!({})).await;
3500 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3501
3502 let refuse_next = Arc::new(AtomicBool::new(false));
3503 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3504 let refuse_next = refuse_next.clone();
3505 move |request, thread, mut cx| {
3506 let refuse_next = refuse_next.clone();
3507 async move {
3508 if refuse_next.load(SeqCst) {
3509 return Ok(acp::PromptResponse {
3510 stop_reason: acp::StopReason::Refusal,
3511 meta: None,
3512 });
3513 }
3514
3515 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3516 panic!("expected text content block");
3517 };
3518 thread.update(&mut cx, |thread, cx| {
3519 thread
3520 .handle_session_update(
3521 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk {
3522 content: content.text.to_uppercase().into(),
3523 meta: None,
3524 }),
3525 cx,
3526 )
3527 .unwrap();
3528 })?;
3529 Ok(acp::PromptResponse {
3530 stop_reason: acp::StopReason::EndTurn,
3531 meta: None,
3532 })
3533 }
3534 .boxed_local()
3535 }
3536 }));
3537 let thread = cx
3538 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3539 .await
3540 .unwrap();
3541
3542 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3543 .await
3544 .unwrap();
3545 thread.read_with(cx, |thread, cx| {
3546 assert_eq!(
3547 thread.to_markdown(cx),
3548 indoc! {"
3549 ## User
3550
3551 hello
3552
3553 ## Assistant
3554
3555 HELLO
3556
3557 "}
3558 );
3559 });
3560
3561 // Simulate refusing the second message. The message should be truncated
3562 // when a user prompt is refused.
3563 refuse_next.store(true, SeqCst);
3564 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3565 .await
3566 .unwrap();
3567 thread.read_with(cx, |thread, cx| {
3568 assert_eq!(
3569 thread.to_markdown(cx),
3570 indoc! {"
3571 ## User
3572
3573 hello
3574
3575 ## Assistant
3576
3577 HELLO
3578
3579 "}
3580 );
3581 });
3582 }
3583
3584 async fn run_until_first_tool_call(
3585 thread: &Entity<AcpThread>,
3586 cx: &mut TestAppContext,
3587 ) -> usize {
3588 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3589
3590 let subscription = cx.update(|cx| {
3591 cx.subscribe(thread, move |thread, _, cx| {
3592 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3593 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3594 return tx.try_send(ix).unwrap();
3595 }
3596 }
3597 })
3598 });
3599
3600 select! {
3601 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3602 panic!("Timeout waiting for tool call")
3603 }
3604 ix = rx.next().fuse() => {
3605 drop(subscription);
3606 ix.unwrap()
3607 }
3608 }
3609 }
3610
3611 #[derive(Clone, Default)]
3612 struct FakeAgentConnection {
3613 auth_methods: Vec<acp::AuthMethod>,
3614 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3615 on_user_message: Option<
3616 Rc<
3617 dyn Fn(
3618 acp::PromptRequest,
3619 WeakEntity<AcpThread>,
3620 AsyncApp,
3621 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3622 + 'static,
3623 >,
3624 >,
3625 }
3626
3627 impl FakeAgentConnection {
3628 fn new() -> Self {
3629 Self {
3630 auth_methods: Vec::new(),
3631 on_user_message: None,
3632 sessions: Arc::default(),
3633 }
3634 }
3635
3636 #[expect(unused)]
3637 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3638 self.auth_methods = auth_methods;
3639 self
3640 }
3641
3642 fn on_user_message(
3643 mut self,
3644 handler: impl Fn(
3645 acp::PromptRequest,
3646 WeakEntity<AcpThread>,
3647 AsyncApp,
3648 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3649 + 'static,
3650 ) -> Self {
3651 self.on_user_message.replace(Rc::new(handler));
3652 self
3653 }
3654 }
3655
3656 impl AgentConnection for FakeAgentConnection {
3657 fn telemetry_id(&self) -> &'static str {
3658 "fake"
3659 }
3660
3661 fn auth_methods(&self) -> &[acp::AuthMethod] {
3662 &self.auth_methods
3663 }
3664
3665 fn new_thread(
3666 self: Rc<Self>,
3667 project: Entity<Project>,
3668 _cwd: &Path,
3669 cx: &mut App,
3670 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3671 let session_id = acp::SessionId(
3672 rand::rng()
3673 .sample_iter(&distr::Alphanumeric)
3674 .take(7)
3675 .map(char::from)
3676 .collect::<String>()
3677 .into(),
3678 );
3679 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3680 let thread = cx.new(|cx| {
3681 AcpThread::new(
3682 "Test",
3683 self.clone(),
3684 project,
3685 action_log,
3686 session_id.clone(),
3687 watch::Receiver::constant(acp::PromptCapabilities {
3688 image: true,
3689 audio: true,
3690 embedded_context: true,
3691 meta: None,
3692 }),
3693 cx,
3694 )
3695 });
3696 self.sessions.lock().insert(session_id, thread.downgrade());
3697 Task::ready(Ok(thread))
3698 }
3699
3700 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3701 if self.auth_methods().iter().any(|m| m.id == method) {
3702 Task::ready(Ok(()))
3703 } else {
3704 Task::ready(Err(anyhow!("Invalid Auth Method")))
3705 }
3706 }
3707
3708 fn prompt(
3709 &self,
3710 _id: Option<UserMessageId>,
3711 params: acp::PromptRequest,
3712 cx: &mut App,
3713 ) -> Task<gpui::Result<acp::PromptResponse>> {
3714 let sessions = self.sessions.lock();
3715 let thread = sessions.get(¶ms.session_id).unwrap();
3716 if let Some(handler) = &self.on_user_message {
3717 let handler = handler.clone();
3718 let thread = thread.clone();
3719 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3720 } else {
3721 Task::ready(Ok(acp::PromptResponse {
3722 stop_reason: acp::StopReason::EndTurn,
3723 meta: None,
3724 }))
3725 }
3726 }
3727
3728 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3729 let sessions = self.sessions.lock();
3730 let thread = sessions.get(session_id).unwrap().clone();
3731
3732 cx.spawn(async move |cx| {
3733 thread
3734 .update(cx, |thread, cx| thread.cancel(cx))
3735 .unwrap()
3736 .await
3737 })
3738 .detach();
3739 }
3740
3741 fn truncate(
3742 &self,
3743 session_id: &acp::SessionId,
3744 _cx: &App,
3745 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3746 Some(Rc::new(FakeAgentSessionEditor {
3747 _session_id: session_id.clone(),
3748 }))
3749 }
3750
3751 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3752 self
3753 }
3754 }
3755
3756 struct FakeAgentSessionEditor {
3757 _session_id: acp::SessionId,
3758 }
3759
3760 impl AgentSessionTruncate for FakeAgentSessionEditor {
3761 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3762 Task::ready(Ok(()))
3763 }
3764 }
3765
3766 #[gpui::test]
3767 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3768 init_test(cx);
3769
3770 let fs = FakeFs::new(cx.executor());
3771 let project = Project::test(fs, [], cx).await;
3772 let connection = Rc::new(FakeAgentConnection::new());
3773 let thread = cx
3774 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3775 .await
3776 .unwrap();
3777
3778 // Try to update a tool call that doesn't exist
3779 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3780 thread.update(cx, |thread, cx| {
3781 let result = thread.handle_session_update(
3782 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3783 id: nonexistent_id.clone(),
3784 fields: acp::ToolCallUpdateFields {
3785 status: Some(acp::ToolCallStatus::Completed),
3786 ..Default::default()
3787 },
3788 meta: None,
3789 }),
3790 cx,
3791 );
3792
3793 // The update should succeed (not return an error)
3794 assert!(result.is_ok());
3795
3796 // There should now be exactly one entry in the thread
3797 assert_eq!(thread.entries.len(), 1);
3798
3799 // The entry should be a failed tool call
3800 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3801 assert_eq!(tool_call.id, nonexistent_id);
3802 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3803 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3804
3805 // Check that the content contains the error message
3806 assert_eq!(tool_call.content.len(), 1);
3807 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3808 match content_block {
3809 ContentBlock::Markdown { markdown } => {
3810 let markdown_text = markdown.read(cx).source();
3811 assert!(markdown_text.contains("Tool call not found"));
3812 }
3813 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3814 ContentBlock::ResourceLink { .. } => {
3815 panic!("Expected markdown content, got resource link")
3816 }
3817 }
3818 } else {
3819 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3820 }
3821 } else {
3822 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3823 }
3824 });
3825 }
3826
3827 /// Tests that restoring a checkpoint properly cleans up terminals that were
3828 /// created after that checkpoint, and cancels any in-progress generation.
3829 ///
3830 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3831 /// that were started after that checkpoint should be terminated, and any in-progress
3832 /// AI generation should be canceled.
3833 #[gpui::test]
3834 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3835 init_test(cx);
3836
3837 let fs = FakeFs::new(cx.executor());
3838 let project = Project::test(fs, [], cx).await;
3839 let connection = Rc::new(FakeAgentConnection::new());
3840 let thread = cx
3841 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3842 .await
3843 .unwrap();
3844
3845 // Send first user message to create a checkpoint
3846 cx.update(|cx| {
3847 thread.update(cx, |thread, cx| {
3848 thread.send(vec!["first message".into()], cx)
3849 })
3850 })
3851 .await
3852 .unwrap();
3853
3854 // Send second message (creates another checkpoint) - we'll restore to this one
3855 cx.update(|cx| {
3856 thread.update(cx, |thread, cx| {
3857 thread.send(vec!["second message".into()], cx)
3858 })
3859 })
3860 .await
3861 .unwrap();
3862
3863 // Create 2 terminals BEFORE the checkpoint that have completed running
3864 let terminal_id_1 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
3865 let mock_terminal_1 = cx.new(|cx| {
3866 let builder = ::terminal::TerminalBuilder::new_display_only(
3867 ::terminal::terminal_settings::CursorShape::default(),
3868 ::terminal::terminal_settings::AlternateScroll::On,
3869 None,
3870 0,
3871 )
3872 .unwrap();
3873 builder.subscribe(cx)
3874 });
3875
3876 thread.update(cx, |thread, cx| {
3877 thread.on_terminal_provider_event(
3878 TerminalProviderEvent::Created {
3879 terminal_id: terminal_id_1.clone(),
3880 label: "echo 'first'".to_string(),
3881 cwd: Some(PathBuf::from("/test")),
3882 output_byte_limit: None,
3883 terminal: mock_terminal_1.clone(),
3884 },
3885 cx,
3886 );
3887 });
3888
3889 thread.update(cx, |thread, cx| {
3890 thread.on_terminal_provider_event(
3891 TerminalProviderEvent::Output {
3892 terminal_id: terminal_id_1.clone(),
3893 data: b"first\n".to_vec(),
3894 },
3895 cx,
3896 );
3897 });
3898
3899 thread.update(cx, |thread, cx| {
3900 thread.on_terminal_provider_event(
3901 TerminalProviderEvent::Exit {
3902 terminal_id: terminal_id_1.clone(),
3903 status: acp::TerminalExitStatus {
3904 exit_code: Some(0),
3905 signal: None,
3906 meta: None,
3907 },
3908 },
3909 cx,
3910 );
3911 });
3912
3913 let terminal_id_2 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
3914 let mock_terminal_2 = cx.new(|cx| {
3915 let builder = ::terminal::TerminalBuilder::new_display_only(
3916 ::terminal::terminal_settings::CursorShape::default(),
3917 ::terminal::terminal_settings::AlternateScroll::On,
3918 None,
3919 0,
3920 )
3921 .unwrap();
3922 builder.subscribe(cx)
3923 });
3924
3925 thread.update(cx, |thread, cx| {
3926 thread.on_terminal_provider_event(
3927 TerminalProviderEvent::Created {
3928 terminal_id: terminal_id_2.clone(),
3929 label: "echo 'second'".to_string(),
3930 cwd: Some(PathBuf::from("/test")),
3931 output_byte_limit: None,
3932 terminal: mock_terminal_2.clone(),
3933 },
3934 cx,
3935 );
3936 });
3937
3938 thread.update(cx, |thread, cx| {
3939 thread.on_terminal_provider_event(
3940 TerminalProviderEvent::Output {
3941 terminal_id: terminal_id_2.clone(),
3942 data: b"second\n".to_vec(),
3943 },
3944 cx,
3945 );
3946 });
3947
3948 thread.update(cx, |thread, cx| {
3949 thread.on_terminal_provider_event(
3950 TerminalProviderEvent::Exit {
3951 terminal_id: terminal_id_2.clone(),
3952 status: acp::TerminalExitStatus {
3953 exit_code: Some(0),
3954 signal: None,
3955 meta: None,
3956 },
3957 },
3958 cx,
3959 );
3960 });
3961
3962 // Get the second message ID to restore to
3963 let second_message_id = thread.read_with(cx, |thread, _| {
3964 // At this point we have:
3965 // - Index 0: First user message (with checkpoint)
3966 // - Index 1: Second user message (with checkpoint)
3967 // No assistant responses because FakeAgentConnection just returns EndTurn
3968 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3969 panic!("expected user message at index 1");
3970 };
3971 message.id.clone().unwrap()
3972 });
3973
3974 // Create a terminal AFTER the checkpoint we'll restore to.
3975 // This simulates the AI agent starting a long-running terminal command.
3976 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
3977 let mock_terminal = cx.new(|cx| {
3978 let builder = ::terminal::TerminalBuilder::new_display_only(
3979 ::terminal::terminal_settings::CursorShape::default(),
3980 ::terminal::terminal_settings::AlternateScroll::On,
3981 None,
3982 0,
3983 )
3984 .unwrap();
3985 builder.subscribe(cx)
3986 });
3987
3988 // Register the terminal as created
3989 thread.update(cx, |thread, cx| {
3990 thread.on_terminal_provider_event(
3991 TerminalProviderEvent::Created {
3992 terminal_id: terminal_id.clone(),
3993 label: "sleep 1000".to_string(),
3994 cwd: Some(PathBuf::from("/test")),
3995 output_byte_limit: None,
3996 terminal: mock_terminal.clone(),
3997 },
3998 cx,
3999 );
4000 });
4001
4002 // Simulate the terminal producing output (still running)
4003 thread.update(cx, |thread, cx| {
4004 thread.on_terminal_provider_event(
4005 TerminalProviderEvent::Output {
4006 terminal_id: terminal_id.clone(),
4007 data: b"terminal is running...\n".to_vec(),
4008 },
4009 cx,
4010 );
4011 });
4012
4013 // Create a tool call entry that references this terminal
4014 // This represents the agent requesting a terminal command
4015 thread.update(cx, |thread, cx| {
4016 thread
4017 .handle_session_update(
4018 acp::SessionUpdate::ToolCall(acp::ToolCall {
4019 id: acp::ToolCallId("terminal-tool-1".into()),
4020 title: "Running command".into(),
4021 kind: acp::ToolKind::Execute,
4022 status: acp::ToolCallStatus::InProgress,
4023 content: vec![acp::ToolCallContent::Terminal {
4024 terminal_id: terminal_id.clone(),
4025 }],
4026 locations: vec![],
4027 raw_input: Some(
4028 serde_json::json!({"command": "sleep 1000", "cd": "/test"}),
4029 ),
4030 raw_output: None,
4031 meta: None,
4032 }),
4033 cx,
4034 )
4035 .unwrap();
4036 });
4037
4038 // Verify terminal exists and is in the thread
4039 let terminal_exists_before =
4040 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4041 assert!(
4042 terminal_exists_before,
4043 "Terminal should exist before checkpoint restore"
4044 );
4045
4046 // Verify the terminal's underlying task is still running (not completed)
4047 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
4048 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
4049 terminal_entity.read_with(cx, |term, _cx| {
4050 term.output().is_none() // output is None means it's still running
4051 })
4052 });
4053 assert!(
4054 terminal_running_before,
4055 "Terminal should be running before checkpoint restore"
4056 );
4057
4058 // Verify we have the expected entries before restore
4059 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4060 assert!(
4061 entry_count_before > 1,
4062 "Should have multiple entries before restore"
4063 );
4064
4065 // Restore the checkpoint to the second message.
4066 // This should:
4067 // 1. Cancel any in-progress generation (via the cancel() call)
4068 // 2. Remove the terminal that was created after that point
4069 thread
4070 .update(cx, |thread, cx| {
4071 thread.restore_checkpoint(second_message_id, cx)
4072 })
4073 .await
4074 .unwrap();
4075
4076 // Verify that no send_task is in progress after restore
4077 // (cancel() clears the send_task)
4078 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4079 assert!(
4080 !has_send_task_after,
4081 "Should not have a send_task after restore (cancel should have cleared it)"
4082 );
4083
4084 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4085 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4086 assert_eq!(
4087 entry_count, 1,
4088 "Should have 1 entry after restore (only the first user message)"
4089 );
4090
4091 // Verify the 2 completed terminals from before the checkpoint still exist
4092 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4093 thread.terminals.contains_key(&terminal_id_1)
4094 });
4095 assert!(
4096 terminal_1_exists,
4097 "Terminal 1 (from before checkpoint) should still exist"
4098 );
4099
4100 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4101 thread.terminals.contains_key(&terminal_id_2)
4102 });
4103 assert!(
4104 terminal_2_exists,
4105 "Terminal 2 (from before checkpoint) should still exist"
4106 );
4107
4108 // Verify they're still in completed state
4109 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4110 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4111 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4112 });
4113 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4114
4115 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4116 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4117 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4118 });
4119 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4120
4121 // Verify the running terminal (created after checkpoint) was removed
4122 let terminal_3_exists =
4123 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4124 assert!(
4125 !terminal_3_exists,
4126 "Terminal 3 (created after checkpoint) should have been removed"
4127 );
4128
4129 // Verify total count is 2 (the two from before the checkpoint)
4130 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4131 assert_eq!(
4132 terminal_count, 2,
4133 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4134 );
4135 }
4136}