1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
/// A message authored by the user, as stored in an [`AcpThread`].
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message. Chunks pushed through
    /// `handle_session_update` carry `None`.
    pub id: Option<UserMessageId>,
    /// Rendered content accumulated from every chunk of this message.
    pub content: ContentBlock,
    /// The raw ACP blocks this message was built from, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether the message is indented. New chunks are only merged into an
    /// existing message when this flag matches.
    pub indented: bool,
}

/// A git checkpoint recorded alongside a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// When true, the message's markdown heading reads `## User (checkpoint)`.
    pub show: bool,
}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
75#[derive(Debug, PartialEq)]
76pub struct AssistantMessage {
77 pub chunks: Vec<AssistantMessageChunk>,
78 pub indented: bool,
79}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
93#[derive(Debug, PartialEq)]
94pub enum AssistantMessageChunk {
95 Message { block: ContentBlock },
96 Thought { block: ContentBlock },
97}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
121#[derive(Debug)]
122pub enum AgentThreadEntry {
123 UserMessage(UserMessage),
124 AssistantMessage(AssistantMessage),
125 ToolCall(ToolCall),
126}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
/// A tool call made by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// Label built from the first line of the tool call title; an ellipsis is
    /// appended when the title spans several lines.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content; ACP items without a local representation are dropped.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up with the same index;
    /// `None` (or a missing index) means the location is not resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, when one could be produced.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    pub raw_output: Option<serde_json::Value>,
}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer
379 .update(cx, |buffer, _| {
380 let snapshot = buffer.snapshot();
381 if let Some(row) = location.line {
382 let column = snapshot.indent_size_for_line(row).len;
383 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
384 snapshot.anchor_before(point)
385 } else {
386 Anchor::min_for_buffer(snapshot.remote_id())
387 }
388 })
389 .ok()?;
390
391 Some(ResolvedLocation { buffer, position })
392 }
393
394 fn resolve_locations(
395 &self,
396 project: Entity<Project>,
397 cx: &mut App,
398 ) -> Task<Vec<Option<ResolvedLocation>>> {
399 let locations = self.locations.clone();
400 project.update(cx, |_, cx| {
401 cx.spawn(async move |project, cx| {
402 let mut new_locations = Vec::new();
403 for location in locations {
404 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
405 }
406 new_locations
407 })
408 })
409 }
410}
411
412// Separate so we can hold a strong reference to the buffer
413// for saving on the thread
414#[derive(Clone, Debug, PartialEq, Eq)]
415struct ResolvedLocation {
416 buffer: Entity<Buffer>,
417 position: Anchor,
418}
419
420impl From<&ResolvedLocation> for AgentLocation {
421 fn from(value: &ResolvedLocation) -> Self {
422 Self {
423 buffer: value.buffer.downgrade(),
424 position: value.position,
425 }
426 }
427}
428
429#[derive(Debug)]
430pub enum ToolCallStatus {
431 /// The tool call hasn't started running yet, but we start showing it to
432 /// the user.
433 Pending,
434 /// The tool call is waiting for confirmation from the user.
435 WaitingForConfirmation {
436 options: Vec<acp::PermissionOption>,
437 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
438 },
439 /// The tool call is currently running.
440 InProgress,
441 /// The tool call completed successfully.
442 Completed,
443 /// The tool call failed.
444 Failed,
445 /// The user rejected the tool call.
446 Rejected,
447 /// The user canceled generation so the tool call was canceled.
448 Canceled,
449}
450
451impl From<acp::ToolCallStatus> for ToolCallStatus {
452 fn from(status: acp::ToolCallStatus) -> Self {
453 match status {
454 acp::ToolCallStatus::Pending => Self::Pending,
455 acp::ToolCallStatus::InProgress => Self::InProgress,
456 acp::ToolCallStatus::Completed => Self::Completed,
457 acp::ToolCallStatus::Failed => Self::Failed,
458 _ => Self::Pending,
459 }
460 }
461}
462
463impl Display for ToolCallStatus {
464 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465 write!(
466 f,
467 "{}",
468 match self {
469 ToolCallStatus::Pending => "Pending",
470 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
471 ToolCallStatus::InProgress => "In Progress",
472 ToolCallStatus::Completed => "Completed",
473 ToolCallStatus::Failed => "Failed",
474 ToolCallStatus::Rejected => "Rejected",
475 ToolCallStatus::Canceled => "Canceled",
476 }
477 )
478 }
479}
480
481#[derive(Debug, PartialEq, Clone)]
482pub enum ContentBlock {
483 Empty,
484 Markdown { markdown: Entity<Markdown> },
485 ResourceLink { resource_link: acp::ResourceLink },
486}
487
488impl ContentBlock {
489 pub fn new(
490 block: acp::ContentBlock,
491 language_registry: &Arc<LanguageRegistry>,
492 path_style: PathStyle,
493 cx: &mut App,
494 ) -> Self {
495 let mut this = Self::Empty;
496 this.append(block, language_registry, path_style, cx);
497 this
498 }
499
500 pub fn new_combined(
501 blocks: impl IntoIterator<Item = acp::ContentBlock>,
502 language_registry: Arc<LanguageRegistry>,
503 path_style: PathStyle,
504 cx: &mut App,
505 ) -> Self {
506 let mut this = Self::Empty;
507 for block in blocks {
508 this.append(block, &language_registry, path_style, cx);
509 }
510 this
511 }
512
513 pub fn append(
514 &mut self,
515 block: acp::ContentBlock,
516 language_registry: &Arc<LanguageRegistry>,
517 path_style: PathStyle,
518 cx: &mut App,
519 ) {
520 if matches!(self, ContentBlock::Empty)
521 && let acp::ContentBlock::ResourceLink(resource_link) = block
522 {
523 *self = ContentBlock::ResourceLink { resource_link };
524 return;
525 }
526
527 let new_content = self.block_string_contents(block, path_style);
528
529 match self {
530 ContentBlock::Empty => {
531 *self = Self::create_markdown_block(new_content, language_registry, cx);
532 }
533 ContentBlock::Markdown { markdown } => {
534 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
535 }
536 ContentBlock::ResourceLink { resource_link } => {
537 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
538 let combined = format!("{}\n{}", existing_content, new_content);
539
540 *self = Self::create_markdown_block(combined, language_registry, cx);
541 }
542 }
543 }
544
545 fn create_markdown_block(
546 content: String,
547 language_registry: &Arc<LanguageRegistry>,
548 cx: &mut App,
549 ) -> ContentBlock {
550 ContentBlock::Markdown {
551 markdown: cx
552 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
553 }
554 }
555
556 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
557 match block {
558 acp::ContentBlock::Text(text_content) => text_content.text,
559 acp::ContentBlock::ResourceLink(resource_link) => {
560 Self::resource_link_md(&resource_link.uri, path_style)
561 }
562 acp::ContentBlock::Resource(acp::EmbeddedResource {
563 resource:
564 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
565 uri,
566 ..
567 }),
568 ..
569 }) => Self::resource_link_md(&uri, path_style),
570 acp::ContentBlock::Image(image) => Self::image_md(&image),
571 _ => String::new(),
572 }
573 }
574
575 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
576 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
577 uri.as_link().to_string()
578 } else {
579 uri.to_string()
580 }
581 }
582
583 fn image_md(_image: &acp::ImageContent) -> String {
584 "`Image`".into()
585 }
586
587 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
588 match self {
589 ContentBlock::Empty => "",
590 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
591 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
592 }
593 }
594
595 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
596 match self {
597 ContentBlock::Empty => None,
598 ContentBlock::Markdown { markdown } => Some(markdown),
599 ContentBlock::ResourceLink { .. } => None,
600 }
601 }
602
603 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
604 match self {
605 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
606 _ => None,
607 }
608 }
609}
610
611#[derive(Debug)]
612pub enum ToolCallContent {
613 ContentBlock(ContentBlock),
614 Diff(Entity<Diff>),
615 Terminal(Entity<Terminal>),
616}
617
618impl ToolCallContent {
619 pub fn from_acp(
620 content: acp::ToolCallContent,
621 language_registry: Arc<LanguageRegistry>,
622 path_style: PathStyle,
623 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
624 cx: &mut App,
625 ) -> Result<Option<Self>> {
626 match content {
627 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
628 Ok(Some(Self::ContentBlock(ContentBlock::new(
629 content,
630 &language_registry,
631 path_style,
632 cx,
633 ))))
634 }
635 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
636 Diff::finalized(
637 diff.path.to_string_lossy().into_owned(),
638 diff.old_text,
639 diff.new_text,
640 language_registry,
641 cx,
642 )
643 })))),
644 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
645 .get(&terminal_id)
646 .cloned()
647 .map(|terminal| Some(Self::Terminal(terminal)))
648 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
649 _ => Ok(None),
650 }
651 }
652
653 pub fn update_from_acp(
654 &mut self,
655 new: acp::ToolCallContent,
656 language_registry: Arc<LanguageRegistry>,
657 path_style: PathStyle,
658 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
659 cx: &mut App,
660 ) -> Result<bool> {
661 let needs_update = match (&self, &new) {
662 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
663 old_diff.read(cx).needs_update(
664 new_diff.old_text.as_deref().unwrap_or(""),
665 &new_diff.new_text,
666 cx,
667 )
668 }
669 _ => true,
670 };
671
672 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
673 if needs_update {
674 *self = update;
675 }
676 Ok(true)
677 } else {
678 Ok(false)
679 }
680 }
681
682 pub fn to_markdown(&self, cx: &App) -> String {
683 match self {
684 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
685 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
686 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
687 }
688 }
689}
690
691#[derive(Debug, PartialEq)]
692pub enum ToolCallUpdate {
693 UpdateFields(acp::ToolCallUpdate),
694 UpdateDiff(ToolCallUpdateDiff),
695 UpdateTerminal(ToolCallUpdateTerminal),
696}
697
698impl ToolCallUpdate {
699 fn id(&self) -> &acp::ToolCallId {
700 match self {
701 Self::UpdateFields(update) => &update.tool_call_id,
702 Self::UpdateDiff(diff) => &diff.id,
703 Self::UpdateTerminal(terminal) => &terminal.id,
704 }
705 }
706}
707
708impl From<acp::ToolCallUpdate> for ToolCallUpdate {
709 fn from(update: acp::ToolCallUpdate) -> Self {
710 Self::UpdateFields(update)
711 }
712}
713
714impl From<ToolCallUpdateDiff> for ToolCallUpdate {
715 fn from(diff: ToolCallUpdateDiff) -> Self {
716 Self::UpdateDiff(diff)
717 }
718}
719
720#[derive(Debug, PartialEq)]
721pub struct ToolCallUpdateDiff {
722 pub id: acp::ToolCallId,
723 pub diff: Entity<Diff>,
724}
725
726impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
727 fn from(terminal: ToolCallUpdateTerminal) -> Self {
728 Self::UpdateTerminal(terminal)
729 }
730}
731
732#[derive(Debug, PartialEq)]
733pub struct ToolCallUpdateTerminal {
734 pub id: acp::ToolCallId,
735 pub terminal: Entity<Terminal>,
736}
737
738#[derive(Debug, Default)]
739pub struct Plan {
740 pub entries: Vec<PlanEntry>,
741}
742
743#[derive(Debug)]
744pub struct PlanStats<'a> {
745 pub in_progress_entry: Option<&'a PlanEntry>,
746 pub pending: u32,
747 pub completed: u32,
748}
749
750impl Plan {
751 pub fn is_empty(&self) -> bool {
752 self.entries.is_empty()
753 }
754
755 pub fn stats(&self) -> PlanStats<'_> {
756 let mut stats = PlanStats {
757 in_progress_entry: None,
758 pending: 0,
759 completed: 0,
760 };
761
762 for entry in &self.entries {
763 match &entry.status {
764 acp::PlanEntryStatus::Pending => {
765 stats.pending += 1;
766 }
767 acp::PlanEntryStatus::InProgress => {
768 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
769 }
770 acp::PlanEntryStatus::Completed => {
771 stats.completed += 1;
772 }
773 _ => {}
774 }
775 }
776
777 stats
778 }
779}
780
781#[derive(Debug)]
782pub struct PlanEntry {
783 pub content: Entity<Markdown>,
784 pub priority: acp::PlanEntryPriority,
785 pub status: acp::PlanEntryStatus,
786}
787
788impl PlanEntry {
789 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
790 Self {
791 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
792 priority: entry.priority,
793 status: entry.status,
794 }
795 }
796}
797
/// Token usage of a thread compared against a maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit. 0 means the limit is unknown (e.g. no model is selected).
    pub max_tokens: u64,
    /// Tokens currently used.
    pub used_tokens: u64,
}
803
804impl TokenUsage {
805 pub fn ratio(&self) -> TokenUsageRatio {
806 #[cfg(debug_assertions)]
807 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
808 .unwrap_or("0.8".to_string())
809 .parse()
810 .unwrap();
811 #[cfg(not(debug_assertions))]
812 let warning_threshold: f32 = 0.8;
813
814 // When the maximum is unknown because there is no selected model,
815 // avoid showing the token limit warning.
816 if self.max_tokens == 0 {
817 TokenUsageRatio::Normal
818 } else if self.used_tokens >= self.max_tokens {
819 TokenUsageRatio::Exceeded
820 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
821 TokenUsageRatio::Warning
822 } else {
823 TokenUsageRatio::Normal
824 }
825 }
826}
827
828#[derive(Debug, Clone, PartialEq, Eq)]
829pub enum TokenUsageRatio {
830 Normal,
831 Warning,
832 Exceeded,
833}
834
/// Details carried by [`AcpThreadEvent::Retry`].
#[derive(Debug, Clone)]
pub struct RetryStatus {
    // NOTE(review): presumably the error that triggered this retry — confirm.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}

/// State of a single agent session.
///
/// Holds its entries, plan, terminals, and the task currently generating a
/// response.
pub struct AcpThread {
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While `Some`, `status()` reports `ThreadStatus::Generating`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    prompt_capabilities: acp::PromptCapabilities,
    /// Background task that copies new values from the capabilities channel
    /// given to `AcpThread::new` into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals not registered yet; written to the
    /// terminal once its `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
861
862impl From<&AcpThread> for ActionLogTelemetry {
863 fn from(value: &AcpThread) -> Self {
864 Self {
865 agent_telemetry_id: value.connection().telemetry_id(),
866 session_id: value.session_id.0.clone(),
867 }
868 }
869}
870
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The prompt capabilities changed; read them via `prompt_capabilities()`.
    PromptCapabilitiesUpdated,
    Refusal,
    /// Forwarded from an `AvailableCommandsUpdate` session update.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Forwarded from a `CurrentModeUpdate` session update.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
890
/// Terminal lifecycle notifications, consumed by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created. Output buffered for its id is written to it.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes. These are buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title, shown as the terminal's breadcrumb text. Ignored for
    /// unknown terminals.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited. The status is buffered if the terminal is not
    /// registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

// NOTE(review): presumably commands sent from the thread to the terminal
// provider; no consumer is visible here — confirm.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    Close {
        terminal_id: acp::TerminalId,
    },
}
929
930impl AcpThread {
931 pub fn on_terminal_provider_event(
932 &mut self,
933 event: TerminalProviderEvent,
934 cx: &mut Context<Self>,
935 ) {
936 match event {
937 TerminalProviderEvent::Created {
938 terminal_id,
939 label,
940 cwd,
941 output_byte_limit,
942 terminal,
943 } => {
944 let entity = self.register_terminal_created(
945 terminal_id.clone(),
946 label,
947 cwd,
948 output_byte_limit,
949 terminal,
950 cx,
951 );
952
953 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
954 for data in chunks.drain(..) {
955 entity.update(cx, |term, cx| {
956 term.inner().update(cx, |inner, cx| {
957 inner.write_output(&data, cx);
958 })
959 });
960 }
961 }
962
963 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
964 entity.update(cx, |_term, cx| {
965 cx.notify();
966 });
967 }
968
969 cx.notify();
970 }
971 TerminalProviderEvent::Output { terminal_id, data } => {
972 if let Some(entity) = self.terminals.get(&terminal_id) {
973 entity.update(cx, |term, cx| {
974 term.inner().update(cx, |inner, cx| {
975 inner.write_output(&data, cx);
976 })
977 });
978 } else {
979 self.pending_terminal_output
980 .entry(terminal_id)
981 .or_default()
982 .push(data);
983 }
984 }
985 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
986 if let Some(entity) = self.terminals.get(&terminal_id) {
987 entity.update(cx, |term, cx| {
988 term.inner().update(cx, |inner, cx| {
989 inner.breadcrumb_text = title;
990 cx.emit(::terminal::Event::BreadcrumbsChanged);
991 })
992 });
993 }
994 }
995 TerminalProviderEvent::Exit {
996 terminal_id,
997 status,
998 } => {
999 if let Some(entity) = self.terminals.get(&terminal_id) {
1000 entity.update(cx, |_term, cx| {
1001 cx.notify();
1002 });
1003 } else {
1004 self.pending_terminal_exit.insert(terminal_id, status);
1005 }
1006 }
1007 }
1008 }
1009}
1010
1011#[derive(PartialEq, Eq, Debug)]
1012pub enum ThreadStatus {
1013 Idle,
1014 Generating,
1015}
1016
1017#[derive(Debug, Clone)]
1018pub enum LoadError {
1019 Unsupported {
1020 command: SharedString,
1021 current_version: SharedString,
1022 minimum_version: SharedString,
1023 },
1024 FailedToInstall(SharedString),
1025 Exited {
1026 status: ExitStatus,
1027 },
1028 Other(SharedString),
1029}
1030
1031impl Display for LoadError {
1032 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1033 match self {
1034 LoadError::Unsupported {
1035 command: path,
1036 current_version,
1037 minimum_version,
1038 } => {
1039 write!(
1040 f,
1041 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1042 )
1043 }
1044 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1045 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1046 LoadError::Other(msg) => write!(f, "{msg}"),
1047 }
1048 }
1049}
1050
1051impl Error for LoadError {}
1052
1053impl AcpThread {
1054 pub fn new(
1055 title: impl Into<SharedString>,
1056 connection: Rc<dyn AgentConnection>,
1057 project: Entity<Project>,
1058 action_log: Entity<ActionLog>,
1059 session_id: acp::SessionId,
1060 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1061 cx: &mut Context<Self>,
1062 ) -> Self {
1063 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1064 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1065 loop {
1066 let caps = prompt_capabilities_rx.recv().await?;
1067 this.update(cx, |this, cx| {
1068 this.prompt_capabilities = caps;
1069 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1070 })?;
1071 }
1072 });
1073
1074 Self {
1075 action_log,
1076 shared_buffers: Default::default(),
1077 entries: Default::default(),
1078 plan: Default::default(),
1079 title: title.into(),
1080 project,
1081 send_task: None,
1082 connection,
1083 session_id,
1084 token_usage: None,
1085 prompt_capabilities,
1086 _observe_prompt_capabilities: task,
1087 terminals: HashMap::default(),
1088 pending_terminal_output: HashMap::default(),
1089 pending_terminal_exit: HashMap::default(),
1090 }
1091 }
1092
1093 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1094 self.prompt_capabilities.clone()
1095 }
1096
1097 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1098 &self.connection
1099 }
1100
1101 pub fn action_log(&self) -> &Entity<ActionLog> {
1102 &self.action_log
1103 }
1104
1105 pub fn project(&self) -> &Entity<Project> {
1106 &self.project
1107 }
1108
1109 pub fn title(&self) -> SharedString {
1110 self.title.clone()
1111 }
1112
1113 pub fn entries(&self) -> &[AgentThreadEntry] {
1114 &self.entries
1115 }
1116
1117 pub fn session_id(&self) -> &acp::SessionId {
1118 &self.session_id
1119 }
1120
1121 pub fn status(&self) -> ThreadStatus {
1122 if self.send_task.is_some() {
1123 ThreadStatus::Generating
1124 } else {
1125 ThreadStatus::Idle
1126 }
1127 }
1128
1129 pub fn token_usage(&self) -> Option<&TokenUsage> {
1130 self.token_usage.as_ref()
1131 }
1132
1133 pub fn has_pending_edit_tool_calls(&self) -> bool {
1134 for entry in self.entries.iter().rev() {
1135 match entry {
1136 AgentThreadEntry::UserMessage(_) => return false,
1137 AgentThreadEntry::ToolCall(
1138 call @ ToolCall {
1139 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1140 ..
1141 },
1142 ) if call.diffs().next().is_some() => {
1143 return true;
1144 }
1145 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1146 }
1147 }
1148
1149 false
1150 }
1151
1152 pub fn used_tools_since_last_user_message(&self) -> bool {
1153 for entry in self.entries.iter().rev() {
1154 match entry {
1155 AgentThreadEntry::UserMessage(..) => return false,
1156 AgentThreadEntry::AssistantMessage(..) => continue,
1157 AgentThreadEntry::ToolCall(..) => return true,
1158 }
1159 }
1160
1161 false
1162 }
1163
1164 pub fn handle_session_update(
1165 &mut self,
1166 update: acp::SessionUpdate,
1167 cx: &mut Context<Self>,
1168 ) -> Result<(), acp::Error> {
1169 match update {
1170 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1171 self.push_user_content_block(None, content, cx);
1172 }
1173 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1174 self.push_assistant_content_block(content, false, cx);
1175 }
1176 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1177 self.push_assistant_content_block(content, true, cx);
1178 }
1179 acp::SessionUpdate::ToolCall(tool_call) => {
1180 self.upsert_tool_call(tool_call, cx)?;
1181 }
1182 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1183 self.update_tool_call(tool_call_update, cx)?;
1184 }
1185 acp::SessionUpdate::Plan(plan) => {
1186 self.update_plan(plan, cx);
1187 }
1188 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1189 available_commands,
1190 ..
1191 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1192 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1193 current_mode_id,
1194 ..
1195 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1196 _ => {}
1197 }
1198 Ok(())
1199 }
1200
1201 pub fn push_user_content_block(
1202 &mut self,
1203 message_id: Option<UserMessageId>,
1204 chunk: acp::ContentBlock,
1205 cx: &mut Context<Self>,
1206 ) {
1207 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1208 }
1209
    /// Appends a user content chunk to the transcript.
    ///
    /// If the last entry is a user message with the same `indented` flag, the chunk
    /// is merged into it. In that case `message_id`, when `Some`, replaces the stored
    /// id, and `EntryUpdated` is emitted. Otherwise a new user message entry
    /// (without a checkpoint) is pushed.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read before `last_mut` takes a mutable borrow of `entries`.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Keep the existing id unless the caller supplied one.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1250
1251 pub fn push_assistant_content_block(
1252 &mut self,
1253 chunk: acp::ContentBlock,
1254 is_thought: bool,
1255 cx: &mut Context<Self>,
1256 ) {
1257 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1258 }
1259
    /// Appends an assistant content chunk to the transcript.
    ///
    /// The chunk is a message, or a thought when `is_thought` is set.
    ///
    /// If the last entry is an assistant message with the same `indented` flag, the
    /// chunk is merged into that entry. It is appended to the trailing chunk when
    /// that chunk is of the same kind (message vs. thought); otherwise it starts a
    /// new chunk. If no such entry exists, a new assistant message entry is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Read before `last_mut` takes a mutable borrow of `entries`.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): the update event is emitted before the chunk is mutated;
            // this presumably relies on emitted events being delivered after this
            // call returns. Confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or there are no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1310
1311 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1312 self.entries.push(entry);
1313 cx.emit(AcpThreadEvent::NewEntry);
1314 }
1315
1316 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1317 self.connection.set_title(&self.session_id, cx).is_some()
1318 }
1319
1320 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1321 if title != self.title {
1322 self.title = title.clone();
1323 cx.emit(AcpThreadEvent::TitleUpdated);
1324 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1325 return set_title.run(title, cx);
1326 }
1327 }
1328 Task::ready(Ok(()))
1329 }
1330
1331 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1332 self.token_usage = usage;
1333 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1334 }
1335
1336 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1337 cx.emit(AcpThreadEvent::Retry(status));
1338 }
1339
    /// Applies `update` to the tool call identified by `update.id()` and emits
    /// `EntryUpdated` for it.
    ///
    /// - Field updates are merged into the existing call; its locations are
    ///   re-resolved when the update carries new ones.
    /// - Diff and terminal updates replace the call's content.
    ///
    /// If no tool call with that id exists, a failed "Tool call not found" entry is
    /// pushed instead and `Ok(())` is returned.
    ///
    /// # Errors
    /// Propagates errors from merging updated fields into the call.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only ever returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1401
1402 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1403 pub fn upsert_tool_call(
1404 &mut self,
1405 tool_call: acp::ToolCall,
1406 cx: &mut Context<Self>,
1407 ) -> Result<(), acp::Error> {
1408 let status = tool_call.status.into();
1409 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1410 }
1411
1412 /// Fails if id does not match an existing entry.
1413 pub fn upsert_tool_call_inner(
1414 &mut self,
1415 update: acp::ToolCallUpdate,
1416 status: ToolCallStatus,
1417 cx: &mut Context<Self>,
1418 ) -> Result<(), acp::Error> {
1419 let language_registry = self.project.read(cx).languages().clone();
1420 let path_style = self.project.read(cx).path_style(cx);
1421 let id = update.tool_call_id.clone();
1422
1423 let agent_telemetry_id = self.connection().telemetry_id();
1424 let session = self.session_id();
1425 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1426 let status = if matches!(status, ToolCallStatus::Completed) {
1427 "completed"
1428 } else {
1429 "failed"
1430 };
1431 telemetry::event!(
1432 "Agent Tool Call Completed",
1433 agent_telemetry_id,
1434 session,
1435 status
1436 );
1437 }
1438
1439 if let Some(ix) = self.index_for_tool_call(&id) {
1440 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1441 unreachable!()
1442 };
1443
1444 call.update_fields(
1445 update.fields,
1446 language_registry,
1447 path_style,
1448 &self.terminals,
1449 cx,
1450 )?;
1451 call.status = status;
1452
1453 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1454 } else {
1455 let call = ToolCall::from_acp(
1456 update.try_into()?,
1457 status,
1458 language_registry,
1459 self.project.read(cx).path_style(cx),
1460 &self.terminals,
1461 cx,
1462 )?;
1463 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1464 };
1465
1466 self.resolve_locations(id, cx);
1467 Ok(())
1468 }
1469
1470 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1471 self.entries
1472 .iter()
1473 .enumerate()
1474 .rev()
1475 .find_map(|(index, entry)| {
1476 if let AgentThreadEntry::ToolCall(tool_call) = entry
1477 && &tool_call.id == id
1478 {
1479 Some(index)
1480 } else {
1481 None
1482 }
1483 })
1484 }
1485
1486 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1487 // The tool call we are looking for is typically the last one, or very close to the end.
1488 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1489 self.entries
1490 .iter_mut()
1491 .enumerate()
1492 .rev()
1493 .find_map(|(index, tool_call)| {
1494 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1495 && &tool_call.id == id
1496 {
1497 Some((index, tool_call))
1498 } else {
1499 None
1500 }
1501 })
1502 }
1503
1504 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1505 self.entries
1506 .iter()
1507 .enumerate()
1508 .rev()
1509 .find_map(|(index, tool_call)| {
1510 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1511 && &tool_call.id == id
1512 {
1513 Some((index, tool_call))
1514 } else {
1515 None
1516 }
1517 })
1518 }
1519
    /// Resolves the locations of the tool call `id` in a detached background task.
    /// Does nothing if no such tool call exists.
    ///
    /// Once the locations are resolved:
    /// - a snapshot of each referenced buffer is stored in `shared_buffers`;
    /// - the project's agent location moves to the last resolved location, unless
    ///   that would only move it leftwards on the same row of the same buffer;
    /// - `EntryUpdated` is emitted if the call's resolved locations changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // Look the call up again: entries may have changed while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // Ignore this so that, when we get updates from the edit tool,
                            // the position doesn't reset to the start of the line.
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Only notify when something actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1576
1577 pub fn request_tool_call_authorization(
1578 &mut self,
1579 tool_call: acp::ToolCallUpdate,
1580 options: Vec<acp::PermissionOption>,
1581 respect_always_allow_setting: bool,
1582 cx: &mut Context<Self>,
1583 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1584 let (tx, rx) = oneshot::channel();
1585
1586 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1587 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1588 // some tools would (incorrectly) continue to auto-accept.
1589 if let Some(allow_once_option) = options.iter().find_map(|option| {
1590 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1591 Some(option.option_id.clone())
1592 } else {
1593 None
1594 }
1595 }) {
1596 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1597 return Ok(async {
1598 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1599 allow_once_option,
1600 ))
1601 }
1602 .boxed());
1603 }
1604 }
1605
1606 let status = ToolCallStatus::WaitingForConfirmation {
1607 options,
1608 respond_tx: tx,
1609 };
1610
1611 self.upsert_tool_call_inner(tool_call, status, cx)?;
1612 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1613
1614 let fut = async {
1615 match rx.await {
1616 Ok(option) => acp::RequestPermissionOutcome::Selected(
1617 acp::SelectedPermissionOutcome::new(option),
1618 ),
1619 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1620 }
1621 }
1622 .boxed();
1623
1624 Ok(fut)
1625 }
1626
1627 pub fn authorize_tool_call(
1628 &mut self,
1629 id: acp::ToolCallId,
1630 option_id: acp::PermissionOptionId,
1631 option_kind: acp::PermissionOptionKind,
1632 cx: &mut Context<Self>,
1633 ) {
1634 let Some((ix, call)) = self.tool_call_mut(&id) else {
1635 return;
1636 };
1637
1638 let new_status = match option_kind {
1639 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1640 ToolCallStatus::Rejected
1641 }
1642 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1643 ToolCallStatus::InProgress
1644 }
1645 _ => ToolCallStatus::InProgress,
1646 };
1647
1648 let curr_status = mem::replace(&mut call.status, new_status);
1649
1650 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1651 respond_tx.send(option_id).log_err();
1652 } else if cfg!(debug_assertions) {
1653 panic!("tried to authorize an already authorized tool call");
1654 }
1655
1656 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1657 }
1658
1659 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1660 let mut first_tool_call = None;
1661
1662 for entry in self.entries.iter().rev() {
1663 match &entry {
1664 AgentThreadEntry::ToolCall(call) => {
1665 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1666 first_tool_call = Some(call);
1667 } else {
1668 continue;
1669 }
1670 }
1671 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1672 // Reached the beginning of the turn.
1673 // If we had pending permission requests in the previous turn, they have been cancelled.
1674 break;
1675 }
1676 }
1677 }
1678
1679 first_tool_call
1680 }
1681
1682 pub fn plan(&self) -> &Plan {
1683 &self.plan
1684 }
1685
1686 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1687 let new_entries_len = request.entries.len();
1688 let mut new_entries = request.entries.into_iter();
1689
1690 // Reuse existing markdown to prevent flickering
1691 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1692 let PlanEntry {
1693 content,
1694 priority,
1695 status,
1696 } = old;
1697 content.update(cx, |old, cx| {
1698 old.replace(new.content, cx);
1699 });
1700 *priority = new.priority;
1701 *status = new.status;
1702 }
1703 for new in new_entries {
1704 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1705 }
1706 self.plan.entries.truncate(new_entries_len);
1707
1708 cx.notify();
1709 }
1710
1711 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1712 self.plan
1713 .entries
1714 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1715 cx.notify();
1716 }
1717
1718 #[cfg(any(test, feature = "test-support"))]
1719 pub fn send_raw(
1720 &mut self,
1721 message: &str,
1722 cx: &mut Context<Self>,
1723 ) -> BoxFuture<'static, Result<()>> {
1724 self.send(vec![message.into()], cx)
1725 }
1726
    /// Sends a user prompt made of `message` to the agent as a new turn.
    ///
    /// Steps, in order:
    /// 1. The prompt is pushed as a new user message entry.
    /// 2. A git checkpoint of the working tree is taken and attached to the last user
    ///    message.
    /// 3. The request is forwarded to the connection.
    ///
    /// A `UserMessageId` is only assigned when the connection supports truncation.
    /// The work runs through `run_turn`, which first cancels any in-flight turn.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Ids are only handed out when the connection can truncate (rewind) to them.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint is logged and tolerated; the message then simply
            // carries no checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1779
1780 pub fn can_resume(&self, cx: &App) -> bool {
1781 self.connection.resume(&self.session_id, cx).is_some()
1782 }
1783
    /// Resumes the session through the connection's resume capability, running it as
    /// a new turn via `run_turn`.
    ///
    /// The returned future fails with "resuming a session is not supported" when the
    /// connection offers no resume capability.
    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .resume(&this.session_id, cx)
                    .map(|resume| resume.run(cx))
            })?
            .context("resuming a session is not supported")?
            .await
        })
    }
1795
    /// Runs one agent turn produced by `f`.
    ///
    /// Returns a future that resolves once the turn has finished and been
    /// post-processed.
    ///
    /// Before starting, completed plan entries are cleared, and any in-flight turn is
    /// cancelled and awaited. When `f` finishes:
    /// - the last user message's checkpoint visibility is refreshed;
    /// - the agent location is cleared;
    /// - the outcome is reported via events: `Error` (and the error is returned)
    ///   when `f` fails; otherwise `Stopped`, preceded by `Refusal` when the stop
    ///   reason is a refusal.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn itself is owned by `send_task` so that `cancel` can take and await
        // it; its result is forwarded through the oneshot channel.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the sender was dropped without sending.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1891
1892 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1893 let Some(send_task) = self.send_task.take() else {
1894 return Task::ready(());
1895 };
1896
1897 for entry in self.entries.iter_mut() {
1898 if let AgentThreadEntry::ToolCall(call) = entry {
1899 let cancel = matches!(
1900 call.status,
1901 ToolCallStatus::Pending
1902 | ToolCallStatus::WaitingForConfirmation { .. }
1903 | ToolCallStatus::InProgress
1904 );
1905
1906 if cancel {
1907 call.status = ToolCallStatus::Canceled;
1908 }
1909 }
1910 }
1911
1912 self.connection.cancel(&self.session_id, cx);
1913
1914 // Wait for the send task to complete
1915 cx.foreground_executor().spawn(send_task)
1916 }
1917
    /// Restores the git working tree to the checkpoint stored on the user message
    /// `id` (if one exists).
    ///
    /// Any in-flight turn is cancelled, the thread is rewound to before that message
    /// (see `rewind`), and then the message's git checkpoint, if any, is restored.
    ///
    /// NOTE(review): `rewind` is started here, before `cancel_task` has been
    /// awaited. Its truncation may therefore begin while the cancelled turn is still
    /// finishing. Confirm this ordering is intended.
    ///
    /// # Errors
    /// Returns "message not found" immediately if no user message has this id.
    /// Otherwise propagates errors from the rewind or the git restore.
    pub fn restore_checkpoint(
        &mut self,
        id: UserMessageId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some((_, message)) = self.user_message_mut(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        // Cancel any in-progress generation before restoring
        let cancel_task = self.cancel(cx);
        let rewind = self.rewind(id.clone(), cx);
        let git_store = self.project.read(cx).git_store().clone();

        cx.spawn(async move |_, cx| {
            cancel_task.await;
            rewind.await?;
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            Ok(())
        })
    }
1950
    /// Rewinds this thread to before the user message with the given `id`.
    ///
    /// Unlike `restore_checkpoint`, this method does not restore from git. Steps:
    /// 1. The session is truncated through the connection.
    /// 2. If a user message with `id` is still present, it and all subsequent entries
    ///    are removed, and the terminals they referenced are killed and unregistered.
    /// 3. `reject_all_edits` is called on the action log. This happens even when no
    ///    matching message was found.
    ///
    /// # Errors
    /// Returns "not supported" if the connection cannot truncate. Otherwise
    /// propagates the truncation failure.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
1992
    /// Decides whether the last user message's checkpoint should be shown.
    ///
    /// A fresh git checkpoint is compared with the one stored on the last user
    /// message, and that checkpoint's `show` flag is set when they differ.
    /// `EntryUpdated` is then emitted.
    ///
    /// Edge cases:
    /// - Returns immediately if there is no user message, or it has no id or no
    ///   checkpoint.
    /// - A failure to take the new checkpoint is logged and ends the task with
    ///   `Ok(())`.
    /// - A failed comparison is treated as "equal".
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let Some((_, message)) = self.last_user_message() else {
            return Task::ready(Ok(()));
        };
        let Some(user_message_id) = message.id.clone() else {
            return Task::ready(Ok(()));
        };
        let Some(checkpoint) = message.checkpoint.as_ref() else {
            return Task::ready(Ok(()));
        };
        let old_checkpoint = checkpoint.git_checkpoint.clone();

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            let Some(new_checkpoint) = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err()
            else {
                return Ok(());
            };

            let equal = git_store
                .update(cx, |git, cx| {
                    git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                })?
                .await
                .unwrap_or(true);

            // Re-find the message by id: entries may have changed while awaiting.
            this.update(cx, |this, cx| {
                if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
                    if let Some(checkpoint) = message.checkpoint.as_mut() {
                        checkpoint.show = !equal;
                        cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    }
                }
            })?;

            Ok(())
        })
    }
2036
2037 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2038 self.entries
2039 .iter_mut()
2040 .enumerate()
2041 .rev()
2042 .find_map(|(ix, entry)| {
2043 if let AgentThreadEntry::UserMessage(message) = entry {
2044 Some((ix, message))
2045 } else {
2046 None
2047 }
2048 })
2049 }
2050
2051 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2052 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2053 if let AgentThreadEntry::UserMessage(message) = entry {
2054 if message.id.as_ref() == Some(id) {
2055 Some((ix, message))
2056 } else {
2057 None
2058 }
2059 } else {
2060 None
2061 }
2062 })
2063 }
2064
2065 pub fn read_text_file(
2066 &self,
2067 path: PathBuf,
2068 line: Option<u32>,
2069 limit: Option<u32>,
2070 reuse_shared_snapshot: bool,
2071 cx: &mut Context<Self>,
2072 ) -> Task<Result<String, acp::Error>> {
2073 // Args are 1-based, move to 0-based
2074 let line = line.unwrap_or_default().saturating_sub(1);
2075 let limit = limit.unwrap_or(u32::MAX);
2076 let project = self.project.clone();
2077 let action_log = self.action_log.clone();
2078 cx.spawn(async move |this, cx| {
2079 let load = project
2080 .update(cx, |project, cx| {
2081 let path = project
2082 .project_path_for_absolute_path(&path, cx)
2083 .ok_or_else(|| {
2084 acp::Error::resource_not_found(Some(path.display().to_string()))
2085 })?;
2086 Ok(project.open_buffer(path, cx))
2087 })
2088 .map_err(|e| acp::Error::internal_error().data(e.to_string()))
2089 .flatten()?;
2090
2091 let buffer = load.await?;
2092
2093 let snapshot = if reuse_shared_snapshot {
2094 this.read_with(cx, |this, _| {
2095 this.shared_buffers.get(&buffer.clone()).cloned()
2096 })
2097 .log_err()
2098 .flatten()
2099 } else {
2100 None
2101 };
2102
2103 let snapshot = if let Some(snapshot) = snapshot {
2104 snapshot
2105 } else {
2106 action_log.update(cx, |action_log, cx| {
2107 action_log.buffer_read(buffer.clone(), cx);
2108 })?;
2109
2110 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
2111 this.update(cx, |this, _| {
2112 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2113 })?;
2114 snapshot
2115 };
2116
2117 let max_point = snapshot.max_point();
2118 let start_position = Point::new(line, 0);
2119
2120 if start_position > max_point {
2121 return Err(acp::Error::invalid_params().data(format!(
2122 "Attempting to read beyond the end of the file, line {}:{}",
2123 max_point.row + 1,
2124 max_point.column
2125 )));
2126 }
2127
2128 let start = snapshot.anchor_before(start_position);
2129 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2130
2131 project.update(cx, |project, cx| {
2132 project.set_agent_location(
2133 Some(AgentLocation {
2134 buffer: buffer.downgrade(),
2135 position: start,
2136 }),
2137 cx,
2138 );
2139 })?;
2140
2141 Ok(snapshot.text_for_range(start..end).collect::<String>())
2142 })
2143 }
2144
    /// Replaces the contents of the file at absolute `path` with `content` and saves
    /// the buffer.
    ///
    /// Steps:
    /// 1. Edits are computed with `text_diff`, against the buffer snapshot recorded
    ///    in `shared_buffers` if there is one, else the current snapshot.
    /// 2. The agent location is moved to the end of the last edit (or the buffer
    ///    start when there are no edits).
    /// 3. The edits are applied and reported to the action log.
    /// 4. If the buffer's language settings enable format-on-save, the buffer is
    ///    formatted and reported again; formatting errors are logged and ignored.
    ///
    /// # Errors
    /// Fails with "invalid path" if `path` is not in the project, and propagates
    /// errors from opening or saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot previously shared with the agent; presumably so the
            // diff is taken against the text the agent based `content` on.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diffing happens off the foreground thread; anchors keep the edits valid
            // against the live buffer.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Report the formatter's changes as agent edits as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2241
    /// Spawns `command` with `args` in a new terminal and registers it with this
    /// thread under a freshly generated id.
    ///
    /// Environment: the project's directory environment for `cwd` (empty when no
    /// `cwd` is given), with `PAGER` set to the empty string and `extra_env` applied
    /// on top.
    ///
    /// Shell: the command is wrapped in the remote host's default shell, or the
    /// local default preferring bash, with stdin redirected to /dev/null.
    ///
    /// `output_byte_limit` is passed through to `Terminal::new`.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.environment().update(cx, |env, cx| {
                    env.directory_environment(dir.as_path().into(), cx)
                })
            }),
            None => Task::ready(None).shared(),
        };
        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Disables paging for `git` and hopefully other commands
            env.insert("PAGER".into(), "".into());
            // Caller-supplied variables win over everything above.
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();
        let is_windows = project.read(cx).path_style(cx).is_windows();

        let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                let shell = project
                    .update(cx, |project, cx| {
                        project
                            .remote_client()
                            .and_then(|r| r.read(cx).default_system_shell())
                    })?
                    .unwrap_or_else(|| get_default_system_shell_preferring_bash());
                let (task_command, task_args) =
                    ShellBuilder::new(&Shell::Program(shell), is_windows)
                        .redirect_stdin_to_dev_null()
                        .build(Some(command.clone()), &args);
                let terminal = project
                    .update(cx, |project, cx| {
                        project.create_terminal_task(
                            task::SpawnInTerminal {
                                command: Some(task_command),
                                args: task_args,
                                cwd: cwd.clone(),
                                env,
                                ..Default::default()
                            },
                            cx,
                        )
                    })?
                    .await?;

                cx.new(|cx| {
                    Terminal::new(
                        terminal_id,
                        &format!("{} {}", command, args.join(" ")),
                        cwd,
                        output_byte_limit.map(|l| l as usize),
                        terminal,
                        language_registry,
                        cx,
                    )
                })
            }
        });

        // Register the terminal so later `terminal`/`kill_terminal` calls can find it.
        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2326
2327 pub fn kill_terminal(
2328 &mut self,
2329 terminal_id: acp::TerminalId,
2330 cx: &mut Context<Self>,
2331 ) -> Result<()> {
2332 self.terminals
2333 .get(&terminal_id)
2334 .context("Terminal not found")?
2335 .update(cx, |terminal, cx| {
2336 terminal.kill(cx);
2337 });
2338
2339 Ok(())
2340 }
2341
2342 pub fn release_terminal(
2343 &mut self,
2344 terminal_id: acp::TerminalId,
2345 cx: &mut Context<Self>,
2346 ) -> Result<()> {
2347 self.terminals
2348 .remove(&terminal_id)
2349 .context("Terminal not found")?
2350 .update(cx, |terminal, cx| {
2351 terminal.kill(cx);
2352 });
2353
2354 Ok(())
2355 }
2356
2357 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2358 self.terminals
2359 .get(&terminal_id)
2360 .context("Terminal not found")
2361 .cloned()
2362 }
2363
2364 pub fn to_markdown(&self, cx: &App) -> String {
2365 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2366 }
2367
2368 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2369 cx.emit(AcpThreadEvent::LoadError(error));
2370 }
2371
2372 pub fn register_terminal_created(
2373 &mut self,
2374 terminal_id: acp::TerminalId,
2375 command_label: String,
2376 working_dir: Option<PathBuf>,
2377 output_byte_limit: Option<u64>,
2378 terminal: Entity<::terminal::Terminal>,
2379 cx: &mut Context<Self>,
2380 ) -> Entity<Terminal> {
2381 let language_registry = self.project.read(cx).languages().clone();
2382
2383 let entity = cx.new(|cx| {
2384 Terminal::new(
2385 terminal_id.clone(),
2386 &command_label,
2387 working_dir.clone(),
2388 output_byte_limit.map(|l| l as usize),
2389 terminal,
2390 language_registry,
2391 cx,
2392 )
2393 });
2394 self.terminals.insert(terminal_id.clone(), entity.clone());
2395 entity
2396 }
2397}
2398
2399fn markdown_for_raw_output(
2400 raw_output: &serde_json::Value,
2401 language_registry: &Arc<LanguageRegistry>,
2402 cx: &mut App,
2403) -> Option<Entity<Markdown>> {
2404 match raw_output {
2405 serde_json::Value::Null => None,
2406 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2407 Markdown::new(
2408 value.to_string().into(),
2409 Some(language_registry.clone()),
2410 None,
2411 cx,
2412 )
2413 })),
2414 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2415 Markdown::new(
2416 value.to_string().into(),
2417 Some(language_registry.clone()),
2418 None,
2419 cx,
2420 )
2421 })),
2422 serde_json::Value::String(value) => Some(cx.new(|cx| {
2423 Markdown::new(
2424 value.clone().into(),
2425 Some(language_registry.clone()),
2426 None,
2427 cx,
2428 )
2429 })),
2430 value => Some(cx.new(|cx| {
2431 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2432
2433 Markdown::new(
2434 format!("```json\n{}\n```", pretty_json).into(),
2435 Some(language_registry.clone()),
2436 None,
2437 cx,
2438 )
2439 })),
2440 }
2441}
2442
2443#[cfg(test)]
2444mod tests {
2445 use super::*;
2446 use anyhow::anyhow;
2447 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2448 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2449 use indoc::indoc;
2450 use project::{FakeFs, Fs};
2451 use rand::{distr, prelude::*};
2452 use serde_json::json;
2453 use settings::SettingsStore;
2454 use smol::stream::StreamExt as _;
2455 use std::{
2456 any::Any,
2457 cell::RefCell,
2458 path::Path,
2459 rc::Rc,
2460 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2461 time::Duration,
2462 };
2463 use util::path;
2464
2465 fn init_test(cx: &mut TestAppContext) {
2466 env_logger::try_init().ok();
2467 cx.update(|cx| {
2468 let settings_store = SettingsStore::test(cx);
2469 cx.set_global(settings_store);
2470 });
2471 }
2472
2473 #[gpui::test]
2474 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2475 init_test(cx);
2476
2477 let fs = FakeFs::new(cx.executor());
2478 let project = Project::test(fs, [], cx).await;
2479 let connection = Rc::new(FakeAgentConnection::new());
2480 let thread = cx
2481 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2482 .await
2483 .unwrap();
2484
2485 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2486
2487 // Send Output BEFORE Created - should be buffered by acp_thread
2488 thread.update(cx, |thread, cx| {
2489 thread.on_terminal_provider_event(
2490 TerminalProviderEvent::Output {
2491 terminal_id: terminal_id.clone(),
2492 data: b"hello buffered".to_vec(),
2493 },
2494 cx,
2495 );
2496 });
2497
2498 // Create a display-only terminal and then send Created
2499 let lower = cx.new(|cx| {
2500 let builder = ::terminal::TerminalBuilder::new_display_only(
2501 ::terminal::terminal_settings::CursorShape::default(),
2502 ::terminal::terminal_settings::AlternateScroll::On,
2503 None,
2504 0,
2505 )
2506 .unwrap();
2507 builder.subscribe(cx)
2508 });
2509
2510 thread.update(cx, |thread, cx| {
2511 thread.on_terminal_provider_event(
2512 TerminalProviderEvent::Created {
2513 terminal_id: terminal_id.clone(),
2514 label: "Buffered Test".to_string(),
2515 cwd: None,
2516 output_byte_limit: None,
2517 terminal: lower.clone(),
2518 },
2519 cx,
2520 );
2521 });
2522
2523 // After Created, buffered Output should have been flushed into the renderer
2524 let content = thread.read_with(cx, |thread, cx| {
2525 let term = thread.terminal(terminal_id.clone()).unwrap();
2526 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2527 });
2528
2529 assert!(
2530 content.contains("hello buffered"),
2531 "expected buffered output to render, got: {content}"
2532 );
2533 }
2534
2535 #[gpui::test]
2536 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2537 init_test(cx);
2538
2539 let fs = FakeFs::new(cx.executor());
2540 let project = Project::test(fs, [], cx).await;
2541 let connection = Rc::new(FakeAgentConnection::new());
2542 let thread = cx
2543 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2544 .await
2545 .unwrap();
2546
2547 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2548
2549 // Send Output BEFORE Created
2550 thread.update(cx, |thread, cx| {
2551 thread.on_terminal_provider_event(
2552 TerminalProviderEvent::Output {
2553 terminal_id: terminal_id.clone(),
2554 data: b"pre-exit data".to_vec(),
2555 },
2556 cx,
2557 );
2558 });
2559
2560 // Send Exit BEFORE Created
2561 thread.update(cx, |thread, cx| {
2562 thread.on_terminal_provider_event(
2563 TerminalProviderEvent::Exit {
2564 terminal_id: terminal_id.clone(),
2565 status: acp::TerminalExitStatus::new().exit_code(0),
2566 },
2567 cx,
2568 );
2569 });
2570
2571 // Now create a display-only lower-level terminal and send Created
2572 let lower = cx.new(|cx| {
2573 let builder = ::terminal::TerminalBuilder::new_display_only(
2574 ::terminal::terminal_settings::CursorShape::default(),
2575 ::terminal::terminal_settings::AlternateScroll::On,
2576 None,
2577 0,
2578 )
2579 .unwrap();
2580 builder.subscribe(cx)
2581 });
2582
2583 thread.update(cx, |thread, cx| {
2584 thread.on_terminal_provider_event(
2585 TerminalProviderEvent::Created {
2586 terminal_id: terminal_id.clone(),
2587 label: "Buffered Exit Test".to_string(),
2588 cwd: None,
2589 output_byte_limit: None,
2590 terminal: lower.clone(),
2591 },
2592 cx,
2593 );
2594 });
2595
2596 // Output should be present after Created (flushed from buffer)
2597 let content = thread.read_with(cx, |thread, cx| {
2598 let term = thread.terminal(terminal_id.clone()).unwrap();
2599 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2600 });
2601
2602 assert!(
2603 content.contains("pre-exit data"),
2604 "expected pre-exit data to render, got: {content}"
2605 );
2606 }
2607
2608 #[gpui::test]
2609 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2610 init_test(cx);
2611
2612 let fs = FakeFs::new(cx.executor());
2613 let project = Project::test(fs, [], cx).await;
2614 let connection = Rc::new(FakeAgentConnection::new());
2615 let thread = cx
2616 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2617 .await
2618 .unwrap();
2619
2620 // Test creating a new user message
2621 thread.update(cx, |thread, cx| {
2622 thread.push_user_content_block(None, "Hello, ".into(), cx);
2623 });
2624
2625 thread.update(cx, |thread, cx| {
2626 assert_eq!(thread.entries.len(), 1);
2627 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2628 assert_eq!(user_msg.id, None);
2629 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2630 } else {
2631 panic!("Expected UserMessage");
2632 }
2633 });
2634
2635 // Test appending to existing user message
2636 let message_1_id = UserMessageId::new();
2637 thread.update(cx, |thread, cx| {
2638 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2639 });
2640
2641 thread.update(cx, |thread, cx| {
2642 assert_eq!(thread.entries.len(), 1);
2643 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2644 assert_eq!(user_msg.id, Some(message_1_id));
2645 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2646 } else {
2647 panic!("Expected UserMessage");
2648 }
2649 });
2650
2651 // Test creating new user message after assistant message
2652 thread.update(cx, |thread, cx| {
2653 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2654 });
2655
2656 let message_2_id = UserMessageId::new();
2657 thread.update(cx, |thread, cx| {
2658 thread.push_user_content_block(
2659 Some(message_2_id.clone()),
2660 "New user message".into(),
2661 cx,
2662 );
2663 });
2664
2665 thread.update(cx, |thread, cx| {
2666 assert_eq!(thread.entries.len(), 3);
2667 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2668 assert_eq!(user_msg.id, Some(message_2_id));
2669 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2670 } else {
2671 panic!("Expected UserMessage at index 2");
2672 }
2673 });
2674 }
2675
2676 #[gpui::test]
2677 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2678 init_test(cx);
2679
2680 let fs = FakeFs::new(cx.executor());
2681 let project = Project::test(fs, [], cx).await;
2682 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2683 |_, thread, mut cx| {
2684 async move {
2685 thread.update(&mut cx, |thread, cx| {
2686 thread
2687 .handle_session_update(
2688 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2689 "Thinking ".into(),
2690 )),
2691 cx,
2692 )
2693 .unwrap();
2694 thread
2695 .handle_session_update(
2696 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2697 "hard!".into(),
2698 )),
2699 cx,
2700 )
2701 .unwrap();
2702 })?;
2703 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2704 }
2705 .boxed_local()
2706 },
2707 ));
2708
2709 let thread = cx
2710 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2711 .await
2712 .unwrap();
2713
2714 thread
2715 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2716 .await
2717 .unwrap();
2718
2719 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2720 assert_eq!(
2721 output,
2722 indoc! {r#"
2723 ## User
2724
2725 Hello from Zed!
2726
2727 ## Assistant
2728
2729 <thinking>
2730 Thinking hard!
2731 </thinking>
2732
2733 "#}
2734 );
2735 }
2736
2737 #[gpui::test]
2738 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2739 init_test(cx);
2740
2741 let fs = FakeFs::new(cx.executor());
2742 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2743 .await;
2744 let project = Project::test(fs.clone(), [], cx).await;
2745 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2746 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2747 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2748 move |_, thread, mut cx| {
2749 let read_file_tx = read_file_tx.clone();
2750 async move {
2751 let content = thread
2752 .update(&mut cx, |thread, cx| {
2753 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2754 })
2755 .unwrap()
2756 .await
2757 .unwrap();
2758 assert_eq!(content, "one\ntwo\nthree\n");
2759 read_file_tx.take().unwrap().send(()).unwrap();
2760 thread
2761 .update(&mut cx, |thread, cx| {
2762 thread.write_text_file(
2763 path!("/tmp/foo").into(),
2764 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2765 cx,
2766 )
2767 })
2768 .unwrap()
2769 .await
2770 .unwrap();
2771 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2772 }
2773 .boxed_local()
2774 },
2775 ));
2776
2777 let (worktree, pathbuf) = project
2778 .update(cx, |project, cx| {
2779 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2780 })
2781 .await
2782 .unwrap();
2783 let buffer = project
2784 .update(cx, |project, cx| {
2785 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2786 })
2787 .await
2788 .unwrap();
2789
2790 let thread = cx
2791 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2792 .await
2793 .unwrap();
2794
2795 let request = thread.update(cx, |thread, cx| {
2796 thread.send_raw("Extend the count in /tmp/foo", cx)
2797 });
2798 read_file_rx.await.ok();
2799 buffer.update(cx, |buffer, cx| {
2800 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2801 });
2802 cx.run_until_parked();
2803 assert_eq!(
2804 buffer.read_with(cx, |buffer, _| buffer.text()),
2805 "zero\none\ntwo\nthree\nfour\nfive\n"
2806 );
2807 assert_eq!(
2808 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2809 "zero\none\ntwo\nthree\nfour\nfive\n"
2810 );
2811 request.await.unwrap();
2812 }
2813
2814 #[gpui::test]
2815 async fn test_reading_from_line(cx: &mut TestAppContext) {
2816 init_test(cx);
2817
2818 let fs = FakeFs::new(cx.executor());
2819 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2820 .await;
2821 let project = Project::test(fs.clone(), [], cx).await;
2822 project
2823 .update(cx, |project, cx| {
2824 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2825 })
2826 .await
2827 .unwrap();
2828
2829 let connection = Rc::new(FakeAgentConnection::new());
2830
2831 let thread = cx
2832 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2833 .await
2834 .unwrap();
2835
2836 // Whole file
2837 let content = thread
2838 .update(cx, |thread, cx| {
2839 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2840 })
2841 .await
2842 .unwrap();
2843
2844 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2845
2846 // Only start line
2847 let content = thread
2848 .update(cx, |thread, cx| {
2849 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2850 })
2851 .await
2852 .unwrap();
2853
2854 assert_eq!(content, "three\nfour\n");
2855
2856 // Only limit
2857 let content = thread
2858 .update(cx, |thread, cx| {
2859 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2860 })
2861 .await
2862 .unwrap();
2863
2864 assert_eq!(content, "one\ntwo\n");
2865
2866 // Range
2867 let content = thread
2868 .update(cx, |thread, cx| {
2869 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2870 })
2871 .await
2872 .unwrap();
2873
2874 assert_eq!(content, "two\nthree\n");
2875
2876 // Invalid
2877 let err = thread
2878 .update(cx, |thread, cx| {
2879 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2880 })
2881 .await
2882 .unwrap_err();
2883
2884 assert_eq!(
2885 err.to_string(),
2886 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2887 );
2888 }
2889
2890 #[gpui::test]
2891 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2892 init_test(cx);
2893
2894 let fs = FakeFs::new(cx.executor());
2895 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2896 let project = Project::test(fs.clone(), [], cx).await;
2897 project
2898 .update(cx, |project, cx| {
2899 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2900 })
2901 .await
2902 .unwrap();
2903
2904 let connection = Rc::new(FakeAgentConnection::new());
2905
2906 let thread = cx
2907 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2908 .await
2909 .unwrap();
2910
2911 // Whole file
2912 let content = thread
2913 .update(cx, |thread, cx| {
2914 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2915 })
2916 .await
2917 .unwrap();
2918
2919 assert_eq!(content, "");
2920
2921 // Only start line
2922 let content = thread
2923 .update(cx, |thread, cx| {
2924 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2925 })
2926 .await
2927 .unwrap();
2928
2929 assert_eq!(content, "");
2930
2931 // Only limit
2932 let content = thread
2933 .update(cx, |thread, cx| {
2934 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2935 })
2936 .await
2937 .unwrap();
2938
2939 assert_eq!(content, "");
2940
2941 // Range
2942 let content = thread
2943 .update(cx, |thread, cx| {
2944 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2945 })
2946 .await
2947 .unwrap();
2948
2949 assert_eq!(content, "");
2950
2951 // Invalid
2952 let err = thread
2953 .update(cx, |thread, cx| {
2954 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2955 })
2956 .await
2957 .unwrap_err();
2958
2959 assert_eq!(
2960 err.to_string(),
2961 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2962 );
2963 }
2964 #[gpui::test]
2965 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2966 init_test(cx);
2967
2968 let fs = FakeFs::new(cx.executor());
2969 fs.insert_tree(path!("/tmp"), json!({})).await;
2970 let project = Project::test(fs.clone(), [], cx).await;
2971 project
2972 .update(cx, |project, cx| {
2973 project.find_or_create_worktree(path!("/tmp"), true, cx)
2974 })
2975 .await
2976 .unwrap();
2977
2978 let connection = Rc::new(FakeAgentConnection::new());
2979
2980 let thread = cx
2981 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2982 .await
2983 .unwrap();
2984
2985 // Out of project file
2986 let err = thread
2987 .update(cx, |thread, cx| {
2988 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2989 })
2990 .await
2991 .unwrap_err();
2992
2993 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
2994 }
2995
2996 #[gpui::test]
2997 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2998 init_test(cx);
2999
3000 let fs = FakeFs::new(cx.executor());
3001 let project = Project::test(fs, [], cx).await;
3002 let id = acp::ToolCallId::new("test");
3003
3004 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3005 let id = id.clone();
3006 move |_, thread, mut cx| {
3007 let id = id.clone();
3008 async move {
3009 thread
3010 .update(&mut cx, |thread, cx| {
3011 thread.handle_session_update(
3012 acp::SessionUpdate::ToolCall(
3013 acp::ToolCall::new(id.clone(), "Label")
3014 .kind(acp::ToolKind::Fetch)
3015 .status(acp::ToolCallStatus::InProgress),
3016 ),
3017 cx,
3018 )
3019 })
3020 .unwrap()
3021 .unwrap();
3022 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3023 }
3024 .boxed_local()
3025 }
3026 }));
3027
3028 let thread = cx
3029 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3030 .await
3031 .unwrap();
3032
3033 let request = thread.update(cx, |thread, cx| {
3034 thread.send_raw("Fetch https://example.com", cx)
3035 });
3036
3037 run_until_first_tool_call(&thread, cx).await;
3038
3039 thread.read_with(cx, |thread, _| {
3040 assert!(matches!(
3041 thread.entries[1],
3042 AgentThreadEntry::ToolCall(ToolCall {
3043 status: ToolCallStatus::InProgress,
3044 ..
3045 })
3046 ));
3047 });
3048
3049 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3050
3051 thread.read_with(cx, |thread, _| {
3052 assert!(matches!(
3053 &thread.entries[1],
3054 AgentThreadEntry::ToolCall(ToolCall {
3055 status: ToolCallStatus::Canceled,
3056 ..
3057 })
3058 ));
3059 });
3060
3061 thread
3062 .update(cx, |thread, cx| {
3063 thread.handle_session_update(
3064 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3065 id,
3066 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3067 )),
3068 cx,
3069 )
3070 })
3071 .unwrap();
3072
3073 request.await.unwrap();
3074
3075 thread.read_with(cx, |thread, _| {
3076 assert!(matches!(
3077 thread.entries[1],
3078 AgentThreadEntry::ToolCall(ToolCall {
3079 status: ToolCallStatus::Completed,
3080 ..
3081 })
3082 ));
3083 });
3084 }
3085
3086 #[gpui::test]
3087 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3088 init_test(cx);
3089 let fs = FakeFs::new(cx.background_executor.clone());
3090 fs.insert_tree(path!("/test"), json!({})).await;
3091 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3092
3093 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3094 move |_, thread, mut cx| {
3095 async move {
3096 thread
3097 .update(&mut cx, |thread, cx| {
3098 thread.handle_session_update(
3099 acp::SessionUpdate::ToolCall(
3100 acp::ToolCall::new("test", "Label")
3101 .kind(acp::ToolKind::Edit)
3102 .status(acp::ToolCallStatus::Completed)
3103 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3104 "/test/test.txt",
3105 "foo",
3106 ))]),
3107 ),
3108 cx,
3109 )
3110 })
3111 .unwrap()
3112 .unwrap();
3113 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3114 }
3115 .boxed_local()
3116 }
3117 }));
3118
3119 let thread = cx
3120 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3121 .await
3122 .unwrap();
3123
3124 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3125 .await
3126 .unwrap();
3127
3128 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3129 }
3130
3131 #[gpui::test(iterations = 10)]
3132 async fn test_checkpoints(cx: &mut TestAppContext) {
3133 init_test(cx);
3134 let fs = FakeFs::new(cx.background_executor.clone());
3135 fs.insert_tree(
3136 path!("/test"),
3137 json!({
3138 ".git": {}
3139 }),
3140 )
3141 .await;
3142 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3143
3144 let simulate_changes = Arc::new(AtomicBool::new(true));
3145 let next_filename = Arc::new(AtomicUsize::new(0));
3146 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3147 let simulate_changes = simulate_changes.clone();
3148 let next_filename = next_filename.clone();
3149 let fs = fs.clone();
3150 move |request, thread, mut cx| {
3151 let fs = fs.clone();
3152 let simulate_changes = simulate_changes.clone();
3153 let next_filename = next_filename.clone();
3154 async move {
3155 if simulate_changes.load(SeqCst) {
3156 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3157 fs.write(Path::new(&filename), b"").await?;
3158 }
3159
3160 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3161 panic!("expected text content block");
3162 };
3163 thread.update(&mut cx, |thread, cx| {
3164 thread
3165 .handle_session_update(
3166 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3167 content.text.to_uppercase().into(),
3168 )),
3169 cx,
3170 )
3171 .unwrap();
3172 })?;
3173 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3174 }
3175 .boxed_local()
3176 }
3177 }));
3178 let thread = cx
3179 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3180 .await
3181 .unwrap();
3182
3183 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3184 .await
3185 .unwrap();
3186 thread.read_with(cx, |thread, cx| {
3187 assert_eq!(
3188 thread.to_markdown(cx),
3189 indoc! {"
3190 ## User (checkpoint)
3191
3192 Lorem
3193
3194 ## Assistant
3195
3196 LOREM
3197
3198 "}
3199 );
3200 });
3201 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3202
3203 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3204 .await
3205 .unwrap();
3206 thread.read_with(cx, |thread, cx| {
3207 assert_eq!(
3208 thread.to_markdown(cx),
3209 indoc! {"
3210 ## User (checkpoint)
3211
3212 Lorem
3213
3214 ## Assistant
3215
3216 LOREM
3217
3218 ## User (checkpoint)
3219
3220 ipsum
3221
3222 ## Assistant
3223
3224 IPSUM
3225
3226 "}
3227 );
3228 });
3229 assert_eq!(
3230 fs.files(),
3231 vec![
3232 Path::new(path!("/test/file-0")),
3233 Path::new(path!("/test/file-1"))
3234 ]
3235 );
3236
3237 // Checkpoint isn't stored when there are no changes.
3238 simulate_changes.store(false, SeqCst);
3239 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3240 .await
3241 .unwrap();
3242 thread.read_with(cx, |thread, cx| {
3243 assert_eq!(
3244 thread.to_markdown(cx),
3245 indoc! {"
3246 ## User (checkpoint)
3247
3248 Lorem
3249
3250 ## Assistant
3251
3252 LOREM
3253
3254 ## User (checkpoint)
3255
3256 ipsum
3257
3258 ## Assistant
3259
3260 IPSUM
3261
3262 ## User
3263
3264 dolor
3265
3266 ## Assistant
3267
3268 DOLOR
3269
3270 "}
3271 );
3272 });
3273 assert_eq!(
3274 fs.files(),
3275 vec![
3276 Path::new(path!("/test/file-0")),
3277 Path::new(path!("/test/file-1"))
3278 ]
3279 );
3280
3281 // Rewinding the conversation truncates the history and restores the checkpoint.
3282 thread
3283 .update(cx, |thread, cx| {
3284 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3285 panic!("unexpected entries {:?}", thread.entries)
3286 };
3287 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3288 })
3289 .await
3290 .unwrap();
3291 thread.read_with(cx, |thread, cx| {
3292 assert_eq!(
3293 thread.to_markdown(cx),
3294 indoc! {"
3295 ## User (checkpoint)
3296
3297 Lorem
3298
3299 ## Assistant
3300
3301 LOREM
3302
3303 "}
3304 );
3305 });
3306 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3307 }
3308
3309 #[gpui::test]
3310 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3311 use std::sync::atomic::AtomicUsize;
3312 init_test(cx);
3313
3314 let fs = FakeFs::new(cx.executor());
3315 let project = Project::test(fs, None, cx).await;
3316
3317 // Create a connection that simulates refusal after tool result
3318 let prompt_count = Arc::new(AtomicUsize::new(0));
3319 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3320 let prompt_count = prompt_count.clone();
3321 move |_request, thread, mut cx| {
3322 let count = prompt_count.fetch_add(1, SeqCst);
3323 async move {
3324 if count == 0 {
3325 // First prompt: Generate a tool call with result
3326 thread.update(&mut cx, |thread, cx| {
3327 thread
3328 .handle_session_update(
3329 acp::SessionUpdate::ToolCall(
3330 acp::ToolCall::new("tool1", "Test Tool")
3331 .kind(acp::ToolKind::Fetch)
3332 .status(acp::ToolCallStatus::Completed)
3333 .raw_input(serde_json::json!({"query": "test"}))
3334 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3335 ),
3336 cx,
3337 )
3338 .unwrap();
3339 })?;
3340
3341 // Now return refusal because of the tool result
3342 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3343 } else {
3344 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3345 }
3346 }
3347 .boxed_local()
3348 }
3349 }));
3350
3351 let thread = cx
3352 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3353 .await
3354 .unwrap();
3355
3356 // Track if we see a Refusal event
3357 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3358 let saw_refusal_event_captured = saw_refusal_event.clone();
3359 thread.update(cx, |_thread, cx| {
3360 cx.subscribe(
3361 &thread,
3362 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3363 if matches!(event, AcpThreadEvent::Refusal) {
3364 *saw_refusal_event_captured.lock().unwrap() = true;
3365 }
3366 },
3367 )
3368 .detach();
3369 });
3370
3371 // Send a user message - this will trigger tool call and then refusal
3372 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3373 cx.background_executor.spawn(send_task).detach();
3374 cx.run_until_parked();
3375
3376 // Verify that:
3377 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3378 // 2. The user message was NOT truncated
3379 assert!(
3380 *saw_refusal_event.lock().unwrap(),
3381 "Refusal event should be emitted for tool result refusals"
3382 );
3383
3384 thread.read_with(cx, |thread, _| {
3385 let entries = thread.entries();
3386 assert!(entries.len() >= 2, "Should have user message and tool call");
3387
3388 // Verify user message is still there
3389 assert!(
3390 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3391 "User message should not be truncated"
3392 );
3393
3394 // Verify tool call is there with result
3395 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3396 assert!(
3397 tool_call.raw_output.is_some(),
3398 "Tool call should have output"
3399 );
3400 } else {
3401 panic!("Expected tool call at index 1");
3402 }
3403 });
3404 }
3405
3406 #[gpui::test]
3407 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3408 init_test(cx);
3409
3410 let fs = FakeFs::new(cx.executor());
3411 let project = Project::test(fs, None, cx).await;
3412
3413 let refuse_next = Arc::new(AtomicBool::new(false));
3414 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3415 let refuse_next = refuse_next.clone();
3416 move |_request, _thread, _cx| {
3417 if refuse_next.load(SeqCst) {
3418 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3419 .boxed_local()
3420 } else {
3421 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3422 .boxed_local()
3423 }
3424 }
3425 }));
3426
3427 let thread = cx
3428 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3429 .await
3430 .unwrap();
3431
3432 // Track if we see a Refusal event
3433 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3434 let saw_refusal_event_captured = saw_refusal_event.clone();
3435 thread.update(cx, |_thread, cx| {
3436 cx.subscribe(
3437 &thread,
3438 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3439 if matches!(event, AcpThreadEvent::Refusal) {
3440 *saw_refusal_event_captured.lock().unwrap() = true;
3441 }
3442 },
3443 )
3444 .detach();
3445 });
3446
3447 // Send a message that will be refused
3448 refuse_next.store(true, SeqCst);
3449 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3450 .await
3451 .unwrap();
3452
3453 // Verify that a Refusal event WAS emitted for user prompt refusal
3454 assert!(
3455 *saw_refusal_event.lock().unwrap(),
3456 "Refusal event should be emitted for user prompt refusals"
3457 );
3458
3459 // Verify the message was truncated (user prompt refusal)
3460 thread.read_with(cx, |thread, cx| {
3461 assert_eq!(thread.to_markdown(cx), "");
3462 });
3463 }
3464
3465 #[gpui::test]
3466 async fn test_refusal(cx: &mut TestAppContext) {
3467 init_test(cx);
3468 let fs = FakeFs::new(cx.background_executor.clone());
3469 fs.insert_tree(path!("/"), json!({})).await;
3470 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3471
3472 let refuse_next = Arc::new(AtomicBool::new(false));
3473 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3474 let refuse_next = refuse_next.clone();
3475 move |request, thread, mut cx| {
3476 let refuse_next = refuse_next.clone();
3477 async move {
3478 if refuse_next.load(SeqCst) {
3479 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3480 }
3481
3482 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3483 panic!("expected text content block");
3484 };
3485 thread.update(&mut cx, |thread, cx| {
3486 thread
3487 .handle_session_update(
3488 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3489 content.text.to_uppercase().into(),
3490 )),
3491 cx,
3492 )
3493 .unwrap();
3494 })?;
3495 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3496 }
3497 .boxed_local()
3498 }
3499 }));
3500 let thread = cx
3501 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3502 .await
3503 .unwrap();
3504
3505 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3506 .await
3507 .unwrap();
3508 thread.read_with(cx, |thread, cx| {
3509 assert_eq!(
3510 thread.to_markdown(cx),
3511 indoc! {"
3512 ## User
3513
3514 hello
3515
3516 ## Assistant
3517
3518 HELLO
3519
3520 "}
3521 );
3522 });
3523
3524 // Simulate refusing the second message. The message should be truncated
3525 // when a user prompt is refused.
3526 refuse_next.store(true, SeqCst);
3527 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3528 .await
3529 .unwrap();
3530 thread.read_with(cx, |thread, cx| {
3531 assert_eq!(
3532 thread.to_markdown(cx),
3533 indoc! {"
3534 ## User
3535
3536 hello
3537
3538 ## Assistant
3539
3540 HELLO
3541
3542 "}
3543 );
3544 });
3545 }
3546
3547 async fn run_until_first_tool_call(
3548 thread: &Entity<AcpThread>,
3549 cx: &mut TestAppContext,
3550 ) -> usize {
3551 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3552
3553 let subscription = cx.update(|cx| {
3554 cx.subscribe(thread, move |thread, _, cx| {
3555 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3556 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3557 return tx.try_send(ix).unwrap();
3558 }
3559 }
3560 })
3561 });
3562
3563 select! {
3564 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3565 panic!("Timeout waiting for tool call")
3566 }
3567 ix = rx.next().fuse() => {
3568 drop(subscription);
3569 ix.unwrap()
3570 }
3571 }
3572 }
3573
3574 #[derive(Clone, Default)]
3575 struct FakeAgentConnection {
3576 auth_methods: Vec<acp::AuthMethod>,
3577 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3578 on_user_message: Option<
3579 Rc<
3580 dyn Fn(
3581 acp::PromptRequest,
3582 WeakEntity<AcpThread>,
3583 AsyncApp,
3584 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3585 + 'static,
3586 >,
3587 >,
3588 }
3589
3590 impl FakeAgentConnection {
3591 fn new() -> Self {
3592 Self {
3593 auth_methods: Vec::new(),
3594 on_user_message: None,
3595 sessions: Arc::default(),
3596 }
3597 }
3598
3599 #[expect(unused)]
3600 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3601 self.auth_methods = auth_methods;
3602 self
3603 }
3604
3605 fn on_user_message(
3606 mut self,
3607 handler: impl Fn(
3608 acp::PromptRequest,
3609 WeakEntity<AcpThread>,
3610 AsyncApp,
3611 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3612 + 'static,
3613 ) -> Self {
3614 self.on_user_message.replace(Rc::new(handler));
3615 self
3616 }
3617 }
3618
3619 impl AgentConnection for FakeAgentConnection {
3620 fn telemetry_id(&self) -> SharedString {
3621 "fake".into()
3622 }
3623
3624 fn auth_methods(&self) -> &[acp::AuthMethod] {
3625 &self.auth_methods
3626 }
3627
3628 fn new_thread(
3629 self: Rc<Self>,
3630 project: Entity<Project>,
3631 _cwd: &Path,
3632 cx: &mut App,
3633 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3634 let session_id = acp::SessionId::new(
3635 rand::rng()
3636 .sample_iter(&distr::Alphanumeric)
3637 .take(7)
3638 .map(char::from)
3639 .collect::<String>(),
3640 );
3641 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3642 let thread = cx.new(|cx| {
3643 AcpThread::new(
3644 "Test",
3645 self.clone(),
3646 project,
3647 action_log,
3648 session_id.clone(),
3649 watch::Receiver::constant(
3650 acp::PromptCapabilities::new()
3651 .image(true)
3652 .audio(true)
3653 .embedded_context(true),
3654 ),
3655 cx,
3656 )
3657 });
3658 self.sessions.lock().insert(session_id, thread.downgrade());
3659 Task::ready(Ok(thread))
3660 }
3661
3662 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3663 if self.auth_methods().iter().any(|m| m.id == method) {
3664 Task::ready(Ok(()))
3665 } else {
3666 Task::ready(Err(anyhow!("Invalid Auth Method")))
3667 }
3668 }
3669
3670 fn prompt(
3671 &self,
3672 _id: Option<UserMessageId>,
3673 params: acp::PromptRequest,
3674 cx: &mut App,
3675 ) -> Task<gpui::Result<acp::PromptResponse>> {
3676 let sessions = self.sessions.lock();
3677 let thread = sessions.get(¶ms.session_id).unwrap();
3678 if let Some(handler) = &self.on_user_message {
3679 let handler = handler.clone();
3680 let thread = thread.clone();
3681 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3682 } else {
3683 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3684 }
3685 }
3686
3687 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3688 let sessions = self.sessions.lock();
3689 let thread = sessions.get(session_id).unwrap().clone();
3690
3691 cx.spawn(async move |cx| {
3692 thread
3693 .update(cx, |thread, cx| thread.cancel(cx))
3694 .unwrap()
3695 .await
3696 })
3697 .detach();
3698 }
3699
3700 fn truncate(
3701 &self,
3702 session_id: &acp::SessionId,
3703 _cx: &App,
3704 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3705 Some(Rc::new(FakeAgentSessionEditor {
3706 _session_id: session_id.clone(),
3707 }))
3708 }
3709
3710 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3711 self
3712 }
3713 }
3714
3715 struct FakeAgentSessionEditor {
3716 _session_id: acp::SessionId,
3717 }
3718
3719 impl AgentSessionTruncate for FakeAgentSessionEditor {
3720 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3721 Task::ready(Ok(()))
3722 }
3723 }
3724
3725 #[gpui::test]
3726 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3727 init_test(cx);
3728
3729 let fs = FakeFs::new(cx.executor());
3730 let project = Project::test(fs, [], cx).await;
3731 let connection = Rc::new(FakeAgentConnection::new());
3732 let thread = cx
3733 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3734 .await
3735 .unwrap();
3736
3737 // Try to update a tool call that doesn't exist
3738 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3739 thread.update(cx, |thread, cx| {
3740 let result = thread.handle_session_update(
3741 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3742 nonexistent_id.clone(),
3743 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3744 )),
3745 cx,
3746 );
3747
3748 // The update should succeed (not return an error)
3749 assert!(result.is_ok());
3750
3751 // There should now be exactly one entry in the thread
3752 assert_eq!(thread.entries.len(), 1);
3753
3754 // The entry should be a failed tool call
3755 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3756 assert_eq!(tool_call.id, nonexistent_id);
3757 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3758 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3759
3760 // Check that the content contains the error message
3761 assert_eq!(tool_call.content.len(), 1);
3762 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3763 match content_block {
3764 ContentBlock::Markdown { markdown } => {
3765 let markdown_text = markdown.read(cx).source();
3766 assert!(markdown_text.contains("Tool call not found"));
3767 }
3768 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3769 ContentBlock::ResourceLink { .. } => {
3770 panic!("Expected markdown content, got resource link")
3771 }
3772 }
3773 } else {
3774 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3775 }
3776 } else {
3777 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3778 }
3779 });
3780 }
3781
3782 /// Tests that restoring a checkpoint properly cleans up terminals that were
3783 /// created after that checkpoint, and cancels any in-progress generation.
3784 ///
3785 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3786 /// that were started after that checkpoint should be terminated, and any in-progress
3787 /// AI generation should be canceled.
3788 #[gpui::test]
3789 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3790 init_test(cx);
3791
3792 let fs = FakeFs::new(cx.executor());
3793 let project = Project::test(fs, [], cx).await;
3794 let connection = Rc::new(FakeAgentConnection::new());
3795 let thread = cx
3796 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3797 .await
3798 .unwrap();
3799
3800 // Send first user message to create a checkpoint
3801 cx.update(|cx| {
3802 thread.update(cx, |thread, cx| {
3803 thread.send(vec!["first message".into()], cx)
3804 })
3805 })
3806 .await
3807 .unwrap();
3808
3809 // Send second message (creates another checkpoint) - we'll restore to this one
3810 cx.update(|cx| {
3811 thread.update(cx, |thread, cx| {
3812 thread.send(vec!["second message".into()], cx)
3813 })
3814 })
3815 .await
3816 .unwrap();
3817
3818 // Create 2 terminals BEFORE the checkpoint that have completed running
3819 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3820 let mock_terminal_1 = cx.new(|cx| {
3821 let builder = ::terminal::TerminalBuilder::new_display_only(
3822 ::terminal::terminal_settings::CursorShape::default(),
3823 ::terminal::terminal_settings::AlternateScroll::On,
3824 None,
3825 0,
3826 )
3827 .unwrap();
3828 builder.subscribe(cx)
3829 });
3830
3831 thread.update(cx, |thread, cx| {
3832 thread.on_terminal_provider_event(
3833 TerminalProviderEvent::Created {
3834 terminal_id: terminal_id_1.clone(),
3835 label: "echo 'first'".to_string(),
3836 cwd: Some(PathBuf::from("/test")),
3837 output_byte_limit: None,
3838 terminal: mock_terminal_1.clone(),
3839 },
3840 cx,
3841 );
3842 });
3843
3844 thread.update(cx, |thread, cx| {
3845 thread.on_terminal_provider_event(
3846 TerminalProviderEvent::Output {
3847 terminal_id: terminal_id_1.clone(),
3848 data: b"first\n".to_vec(),
3849 },
3850 cx,
3851 );
3852 });
3853
3854 thread.update(cx, |thread, cx| {
3855 thread.on_terminal_provider_event(
3856 TerminalProviderEvent::Exit {
3857 terminal_id: terminal_id_1.clone(),
3858 status: acp::TerminalExitStatus::new().exit_code(0),
3859 },
3860 cx,
3861 );
3862 });
3863
3864 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3865 let mock_terminal_2 = cx.new(|cx| {
3866 let builder = ::terminal::TerminalBuilder::new_display_only(
3867 ::terminal::terminal_settings::CursorShape::default(),
3868 ::terminal::terminal_settings::AlternateScroll::On,
3869 None,
3870 0,
3871 )
3872 .unwrap();
3873 builder.subscribe(cx)
3874 });
3875
3876 thread.update(cx, |thread, cx| {
3877 thread.on_terminal_provider_event(
3878 TerminalProviderEvent::Created {
3879 terminal_id: terminal_id_2.clone(),
3880 label: "echo 'second'".to_string(),
3881 cwd: Some(PathBuf::from("/test")),
3882 output_byte_limit: None,
3883 terminal: mock_terminal_2.clone(),
3884 },
3885 cx,
3886 );
3887 });
3888
3889 thread.update(cx, |thread, cx| {
3890 thread.on_terminal_provider_event(
3891 TerminalProviderEvent::Output {
3892 terminal_id: terminal_id_2.clone(),
3893 data: b"second\n".to_vec(),
3894 },
3895 cx,
3896 );
3897 });
3898
3899 thread.update(cx, |thread, cx| {
3900 thread.on_terminal_provider_event(
3901 TerminalProviderEvent::Exit {
3902 terminal_id: terminal_id_2.clone(),
3903 status: acp::TerminalExitStatus::new().exit_code(0),
3904 },
3905 cx,
3906 );
3907 });
3908
3909 // Get the second message ID to restore to
3910 let second_message_id = thread.read_with(cx, |thread, _| {
3911 // At this point we have:
3912 // - Index 0: First user message (with checkpoint)
3913 // - Index 1: Second user message (with checkpoint)
3914 // No assistant responses because FakeAgentConnection just returns EndTurn
3915 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3916 panic!("expected user message at index 1");
3917 };
3918 message.id.clone().unwrap()
3919 });
3920
3921 // Create a terminal AFTER the checkpoint we'll restore to.
3922 // This simulates the AI agent starting a long-running terminal command.
3923 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3924 let mock_terminal = cx.new(|cx| {
3925 let builder = ::terminal::TerminalBuilder::new_display_only(
3926 ::terminal::terminal_settings::CursorShape::default(),
3927 ::terminal::terminal_settings::AlternateScroll::On,
3928 None,
3929 0,
3930 )
3931 .unwrap();
3932 builder.subscribe(cx)
3933 });
3934
3935 // Register the terminal as created
3936 thread.update(cx, |thread, cx| {
3937 thread.on_terminal_provider_event(
3938 TerminalProviderEvent::Created {
3939 terminal_id: terminal_id.clone(),
3940 label: "sleep 1000".to_string(),
3941 cwd: Some(PathBuf::from("/test")),
3942 output_byte_limit: None,
3943 terminal: mock_terminal.clone(),
3944 },
3945 cx,
3946 );
3947 });
3948
3949 // Simulate the terminal producing output (still running)
3950 thread.update(cx, |thread, cx| {
3951 thread.on_terminal_provider_event(
3952 TerminalProviderEvent::Output {
3953 terminal_id: terminal_id.clone(),
3954 data: b"terminal is running...\n".to_vec(),
3955 },
3956 cx,
3957 );
3958 });
3959
3960 // Create a tool call entry that references this terminal
3961 // This represents the agent requesting a terminal command
3962 thread.update(cx, |thread, cx| {
3963 thread
3964 .handle_session_update(
3965 acp::SessionUpdate::ToolCall(
3966 acp::ToolCall::new("terminal-tool-1", "Running command")
3967 .kind(acp::ToolKind::Execute)
3968 .status(acp::ToolCallStatus::InProgress)
3969 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
3970 terminal_id.clone(),
3971 ))])
3972 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
3973 ),
3974 cx,
3975 )
3976 .unwrap();
3977 });
3978
3979 // Verify terminal exists and is in the thread
3980 let terminal_exists_before =
3981 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
3982 assert!(
3983 terminal_exists_before,
3984 "Terminal should exist before checkpoint restore"
3985 );
3986
3987 // Verify the terminal's underlying task is still running (not completed)
3988 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
3989 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
3990 terminal_entity.read_with(cx, |term, _cx| {
3991 term.output().is_none() // output is None means it's still running
3992 })
3993 });
3994 assert!(
3995 terminal_running_before,
3996 "Terminal should be running before checkpoint restore"
3997 );
3998
3999 // Verify we have the expected entries before restore
4000 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4001 assert!(
4002 entry_count_before > 1,
4003 "Should have multiple entries before restore"
4004 );
4005
4006 // Restore the checkpoint to the second message.
4007 // This should:
4008 // 1. Cancel any in-progress generation (via the cancel() call)
4009 // 2. Remove the terminal that was created after that point
4010 thread
4011 .update(cx, |thread, cx| {
4012 thread.restore_checkpoint(second_message_id, cx)
4013 })
4014 .await
4015 .unwrap();
4016
4017 // Verify that no send_task is in progress after restore
4018 // (cancel() clears the send_task)
4019 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4020 assert!(
4021 !has_send_task_after,
4022 "Should not have a send_task after restore (cancel should have cleared it)"
4023 );
4024
4025 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4026 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4027 assert_eq!(
4028 entry_count, 1,
4029 "Should have 1 entry after restore (only the first user message)"
4030 );
4031
4032 // Verify the 2 completed terminals from before the checkpoint still exist
4033 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4034 thread.terminals.contains_key(&terminal_id_1)
4035 });
4036 assert!(
4037 terminal_1_exists,
4038 "Terminal 1 (from before checkpoint) should still exist"
4039 );
4040
4041 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4042 thread.terminals.contains_key(&terminal_id_2)
4043 });
4044 assert!(
4045 terminal_2_exists,
4046 "Terminal 2 (from before checkpoint) should still exist"
4047 );
4048
4049 // Verify they're still in completed state
4050 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4051 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4052 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4053 });
4054 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4055
4056 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4057 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4058 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4059 });
4060 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4061
4062 // Verify the running terminal (created after checkpoint) was removed
4063 let terminal_3_exists =
4064 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4065 assert!(
4066 !terminal_3_exists,
4067 "Terminal 3 (created after checkpoint) should have been removed"
4068 );
4069
4070 // Verify total count is 2 (the two from before the checkpoint)
4071 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4072 assert_eq!(
4073 terminal_count, 2,
4074 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4075 );
4076 }
4077
4078 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4079 /// even when a new user message is added while the async checkpoint comparison is in progress.
4080 ///
4081 /// This is a regression test for a bug where update_last_checkpoint would fail with
4082 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4083 /// update_last_checkpoint started and when its async closure ran.
4084 #[gpui::test]
4085 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4086 init_test(cx);
4087
4088 let fs = FakeFs::new(cx.executor());
4089 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4090 .await;
4091 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4092
4093 let handler_done = Arc::new(AtomicBool::new(false));
4094 let handler_done_clone = handler_done.clone();
4095 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4096 move |_, _thread, _cx| {
4097 handler_done_clone.store(true, SeqCst);
4098 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4099 },
4100 ));
4101
4102 let thread = cx
4103 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
4104 .await
4105 .unwrap();
4106
4107 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4108 let send_task = cx.background_executor.spawn(send_future);
4109
4110 // Tick until handler completes, then a few more to let update_last_checkpoint start
4111 while !handler_done.load(SeqCst) {
4112 cx.executor().tick();
4113 }
4114 for _ in 0..5 {
4115 cx.executor().tick();
4116 }
4117
4118 thread.update(cx, |thread, cx| {
4119 thread.push_entry(
4120 AgentThreadEntry::UserMessage(UserMessage {
4121 id: Some(UserMessageId::new()),
4122 content: ContentBlock::Empty,
4123 chunks: vec!["Injected message (no checkpoint)".into()],
4124 checkpoint: None,
4125 indented: false,
4126 }),
4127 cx,
4128 );
4129 });
4130
4131 cx.run_until_parked();
4132 let result = send_task.await;
4133
4134 assert!(
4135 result.is_ok(),
4136 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4137 result.err()
4138 );
4139 }
4140}