1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
/// A user-authored entry in an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one was supplied.
    pub id: Option<UserMessageId>,
    /// Content accumulated from every chunk appended to this message.
    pub content: ContentBlock,
    /// The raw protocol content blocks, in the order they were appended.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
49
/// A git store checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying checkpoint value from the project's git store.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the markdown export titles the message "User (checkpoint)".
    pub show: bool,
}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
/// An assistant-authored entry in an agent thread.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The message and thought chunks that make up this entry, in order.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Value reported by `AgentThreadEntry::is_indented` for this entry.
    pub indented: bool,
}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content; exported to markdown as-is.
    Message { block: ContentBlock },
    /// Thought content; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message (including thoughts) produced by the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool invocation together with its content and status.
    ToolCall(ToolCall),
}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
/// A tool invocation shown in an agent thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Title of the call; multi-line titles are cut to their first line plus "…".
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported by the protocol.
    pub kind: acp::ToolKind,
    /// Content produced by the call (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Current status of the call.
    pub status: ToolCallStatus,
    /// Locations reported for the call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`; `AgentThreadEntry::location`
    /// reads both lists at the same index. Starts out empty.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown produced from `raw_input` via `markdown_for_raw_output`.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer
379 .update(cx, |buffer, _| {
380 let snapshot = buffer.snapshot();
381 if let Some(row) = location.line {
382 let column = snapshot.indent_size_for_line(row).len;
383 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
384 snapshot.anchor_before(point)
385 } else {
386 Anchor::min_for_buffer(snapshot.remote_id())
387 }
388 })
389 .ok()?;
390
391 Some(ResolvedLocation { buffer, position })
392 }
393
394 fn resolve_locations(
395 &self,
396 project: Entity<Project>,
397 cx: &mut App,
398 ) -> Task<Vec<Option<ResolvedLocation>>> {
399 let locations = self.locations.clone();
400 project.update(cx, |_, cx| {
401 cx.spawn(async move |project, cx| {
402 let mut new_locations = Vec::new();
403 for location in locations {
404 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
405 }
406 new_locations
407 })
408 })
409 }
410}
411
// Separate so we can hold a strong reference to the buffer
// for saving on the thread
/// A tool-call location resolved to an open buffer and an anchor in it.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer the location points into.
    buffer: Entity<Buffer>,
    /// Anchor within `buffer`.
    position: Anchor,
}
419
420impl From<&ResolvedLocation> for AgentLocation {
421 fn from(value: &ResolvedLocation) -> Self {
422 Self {
423 buffer: value.buffer.downgrade(),
424 position: value.position,
425 }
426 }
427}
428
/// Status of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the selected option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
450
451impl From<acp::ToolCallStatus> for ToolCallStatus {
452 fn from(status: acp::ToolCallStatus) -> Self {
453 match status {
454 acp::ToolCallStatus::Pending => Self::Pending,
455 acp::ToolCallStatus::InProgress => Self::InProgress,
456 acp::ToolCallStatus::Completed => Self::Completed,
457 acp::ToolCallStatus::Failed => Self::Failed,
458 _ => Self::Pending,
459 }
460 }
461}
462
463impl Display for ToolCallStatus {
464 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465 write!(
466 f,
467 "{}",
468 match self {
469 ToolCallStatus::Pending => "Pending",
470 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
471 ToolCallStatus::InProgress => "In Progress",
472 ToolCallStatus::Completed => "Completed",
473 ToolCallStatus::Failed => "Failed",
474 ToolCallStatus::Rejected => "Rejected",
475 ToolCallStatus::Canceled => "Canceled",
476 }
477 )
478 }
479}
480
/// Displayable content built up from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text accumulated from the appended blocks.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link; converted to markdown once more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
487
488impl ContentBlock {
489 pub fn new(
490 block: acp::ContentBlock,
491 language_registry: &Arc<LanguageRegistry>,
492 path_style: PathStyle,
493 cx: &mut App,
494 ) -> Self {
495 let mut this = Self::Empty;
496 this.append(block, language_registry, path_style, cx);
497 this
498 }
499
500 pub fn new_combined(
501 blocks: impl IntoIterator<Item = acp::ContentBlock>,
502 language_registry: Arc<LanguageRegistry>,
503 path_style: PathStyle,
504 cx: &mut App,
505 ) -> Self {
506 let mut this = Self::Empty;
507 for block in blocks {
508 this.append(block, &language_registry, path_style, cx);
509 }
510 this
511 }
512
513 pub fn append(
514 &mut self,
515 block: acp::ContentBlock,
516 language_registry: &Arc<LanguageRegistry>,
517 path_style: PathStyle,
518 cx: &mut App,
519 ) {
520 if matches!(self, ContentBlock::Empty)
521 && let acp::ContentBlock::ResourceLink(resource_link) = block
522 {
523 *self = ContentBlock::ResourceLink { resource_link };
524 return;
525 }
526
527 let new_content = self.block_string_contents(block, path_style);
528
529 match self {
530 ContentBlock::Empty => {
531 *self = Self::create_markdown_block(new_content, language_registry, cx);
532 }
533 ContentBlock::Markdown { markdown } => {
534 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
535 }
536 ContentBlock::ResourceLink { resource_link } => {
537 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
538 let combined = format!("{}\n{}", existing_content, new_content);
539
540 *self = Self::create_markdown_block(combined, language_registry, cx);
541 }
542 }
543 }
544
545 fn create_markdown_block(
546 content: String,
547 language_registry: &Arc<LanguageRegistry>,
548 cx: &mut App,
549 ) -> ContentBlock {
550 ContentBlock::Markdown {
551 markdown: cx
552 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
553 }
554 }
555
556 fn block_string_contents(&self, block: acp::ContentBlock, path_style: PathStyle) -> String {
557 match block {
558 acp::ContentBlock::Text(text_content) => text_content.text,
559 acp::ContentBlock::ResourceLink(resource_link) => {
560 Self::resource_link_md(&resource_link.uri, path_style)
561 }
562 acp::ContentBlock::Resource(acp::EmbeddedResource {
563 resource:
564 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
565 uri,
566 ..
567 }),
568 ..
569 }) => Self::resource_link_md(&uri, path_style),
570 acp::ContentBlock::Image(image) => Self::image_md(&image),
571 _ => String::new(),
572 }
573 }
574
575 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
576 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
577 uri.as_link().to_string()
578 } else {
579 uri.to_string()
580 }
581 }
582
583 fn image_md(_image: &acp::ImageContent) -> String {
584 "`Image`".into()
585 }
586
587 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
588 match self {
589 ContentBlock::Empty => "",
590 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
591 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
592 }
593 }
594
595 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
596 match self {
597 ContentBlock::Empty => None,
598 ContentBlock::Markdown { markdown } => Some(markdown),
599 ContentBlock::ResourceLink { .. } => None,
600 }
601 }
602
603 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
604 match self {
605 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
606 _ => None,
607 }
608 }
609}
610
/// One item of content attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text-like content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal registered with the thread.
    Terminal(Entity<Terminal>),
}
617
618impl ToolCallContent {
619 pub fn from_acp(
620 content: acp::ToolCallContent,
621 language_registry: Arc<LanguageRegistry>,
622 path_style: PathStyle,
623 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
624 cx: &mut App,
625 ) -> Result<Option<Self>> {
626 match content {
627 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
628 Ok(Some(Self::ContentBlock(ContentBlock::new(
629 content,
630 &language_registry,
631 path_style,
632 cx,
633 ))))
634 }
635 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
636 Diff::finalized(
637 diff.path.to_string_lossy().into_owned(),
638 diff.old_text,
639 diff.new_text,
640 language_registry,
641 cx,
642 )
643 })))),
644 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
645 .get(&terminal_id)
646 .cloned()
647 .map(|terminal| Some(Self::Terminal(terminal)))
648 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
649 _ => Ok(None),
650 }
651 }
652
653 pub fn update_from_acp(
654 &mut self,
655 new: acp::ToolCallContent,
656 language_registry: Arc<LanguageRegistry>,
657 path_style: PathStyle,
658 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
659 cx: &mut App,
660 ) -> Result<bool> {
661 let needs_update = match (&self, &new) {
662 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
663 old_diff.read(cx).needs_update(
664 new_diff.old_text.as_deref().unwrap_or(""),
665 &new_diff.new_text,
666 cx,
667 )
668 }
669 _ => true,
670 };
671
672 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
673 if needs_update {
674 *self = update;
675 }
676 Ok(true)
677 } else {
678 Ok(false)
679 }
680 }
681
682 pub fn to_markdown(&self, cx: &App) -> String {
683 match self {
684 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
685 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
686 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
687 }
688 }
689}
690
/// An update targeting an existing tool call.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level update of the call's fields.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
697
698impl ToolCallUpdate {
699 fn id(&self) -> &acp::ToolCallId {
700 match self {
701 Self::UpdateFields(update) => &update.tool_call_id,
702 Self::UpdateDiff(diff) => &diff.id,
703 Self::UpdateTerminal(terminal) => &terminal.id,
704 }
705 }
706}
707
impl From<acp::ToolCallUpdate> for ToolCallUpdate {
    /// Wraps a protocol update in the `UpdateFields` variant.
    fn from(update: acp::ToolCallUpdate) -> Self {
        Self::UpdateFields(update)
    }
}
713
impl From<ToolCallUpdateDiff> for ToolCallUpdate {
    /// Wraps a diff update in the `UpdateDiff` variant.
    fn from(diff: ToolCallUpdateDiff) -> Self {
        Self::UpdateDiff(diff)
    }
}
719
/// A tool-call update that carries a diff entity.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
725
impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
    /// Wraps a terminal update in the `UpdateTerminal` variant.
    fn from(terminal: ToolCallUpdateTerminal) -> Self {
        Self::UpdateTerminal(terminal)
    }
}
731
/// A tool-call update that carries a terminal entity.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
737
/// The agent's plan: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan entries, in order.
    pub entries: Vec<PlanEntry>,
}
742
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
749
750impl Plan {
751 pub fn is_empty(&self) -> bool {
752 self.entries.is_empty()
753 }
754
755 pub fn stats(&self) -> PlanStats<'_> {
756 let mut stats = PlanStats {
757 in_progress_entry: None,
758 pending: 0,
759 completed: 0,
760 };
761
762 for entry in &self.entries {
763 match &entry.status {
764 acp::PlanEntryStatus::Pending => {
765 stats.pending += 1;
766 }
767 acp::PlanEntryStatus::InProgress => {
768 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
769 }
770 acp::PlanEntryStatus::Completed => {
771 stats.completed += 1;
772 }
773 _ => {}
774 }
775 }
776
777 stats
778 }
779}
780
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, held as a markdown entity.
    pub content: Entity<Markdown>,
    /// Priority as reported by the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported by the protocol.
    pub status: acp::PlanEntryStatus,
}
787
788impl PlanEntry {
789 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
790 Self {
791 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
792 priority: entry.priority,
793 status: entry.status,
794 }
795 }
796}
797
/// Token consumption of a thread relative to its limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" by `ratio`.
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
803
804impl TokenUsage {
805 pub fn ratio(&self) -> TokenUsageRatio {
806 #[cfg(debug_assertions)]
807 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
808 .unwrap_or("0.8".to_string())
809 .parse()
810 .unwrap();
811 #[cfg(not(debug_assertions))]
812 let warning_threshold: f32 = 0.8;
813
814 // When the maximum is unknown because there is no selected model,
815 // avoid showing the token limit warning.
816 if self.max_tokens == 0 {
817 TokenUsageRatio::Normal
818 } else if self.used_tokens >= self.max_tokens {
819 TokenUsageRatio::Exceeded
820 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
821 TokenUsageRatio::Warning
822 } else {
823 TokenUsageRatio::Normal
824 }
825 }
826}
827
/// Classification of token usage produced by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    Normal,
    /// Usage reached the warning threshold but not the maximum.
    Warning,
    /// Usage reached or passed the maximum.
    Exceeded,
}
834
/// Payload of `AcpThreadEvent::Retry`.
// NOTE(review): the field descriptions below are inferred from the names;
// nothing in this part of the file produces a `RetryStatus`. Confirm against
// the emitter.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the number of the current attempt.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    /// Presumably when the retry wait started.
    pub started_at: Instant,
    /// Presumably how long the retry wait lasts.
    pub duration: Duration,
}
843
/// State of a single agent thread: its entries, plan, connection and the
/// terminals created for it.
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// User messages, assistant messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with this thread.
    action_log: Entity<ActionLog>,
    // NOTE(review): usage of `shared_buffers` is not visible in this part of
    // the file.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while the thread is generating; `status` reports `Generating`
    /// exactly when this is set.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Protocol session id of this thread.
    session_id: acp::SessionId,
    /// Latest known token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the watch channel.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that stores incoming prompt capabilities and emits
    /// `PromptCapabilitiesUpdated` for each one.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals registered with this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output chunks received for terminal ids that are not registered yet;
    /// replayed when the terminal is created.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids that are not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
861
862impl From<&AcpThread> for ActionLogTelemetry {
863 fn from(value: &AcpThread) -> Self {
864 Self {
865 agent_telemetry_id: value.connection().telemetry_id(),
866 session_id: value.session_id.0.clone(),
867 }
868 }
869}
870
/// Events emitted by [`AcpThread`].
// NOTE(review): only some variants are emitted in the part of the file
// reviewed here; undocumented variants are emitted elsewhere.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at the given index changed (e.g. a chunk was appended to it).
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// A new value arrived on the prompt-capabilities channel and was stored.
    PromptCapabilitiesUpdated,
    Refusal,
    /// Forwarded from an `AvailableCommandsUpdate` session update.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Forwarded from a `CurrentModeUpdate` session update.
    ModeUpdated(acp::SessionModeId),
    /// Forwarded from a `ConfigOptionUpdate` session update.
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

// Lets `AcpThread` emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
891
/// Terminal lifecycle events handled by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered with the thread.
    /// Output and exit events buffered for the same id are consumed at this point.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered when the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title for a registered terminal; ignored for unknown ids.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// A terminal exited; the status is buffered when the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
914
/// Commands addressed to a terminal by id.
// NOTE(review): neither the sender nor the consumer of these commands is
// visible in this part of the file; the variant descriptions follow the names.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Presumably writes `bytes` to the terminal's input.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Presumably resizes the terminal to `cols` x `rows`.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Presumably closes the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
930
931impl AcpThread {
932 pub fn on_terminal_provider_event(
933 &mut self,
934 event: TerminalProviderEvent,
935 cx: &mut Context<Self>,
936 ) {
937 match event {
938 TerminalProviderEvent::Created {
939 terminal_id,
940 label,
941 cwd,
942 output_byte_limit,
943 terminal,
944 } => {
945 let entity = self.register_terminal_created(
946 terminal_id.clone(),
947 label,
948 cwd,
949 output_byte_limit,
950 terminal,
951 cx,
952 );
953
954 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
955 for data in chunks.drain(..) {
956 entity.update(cx, |term, cx| {
957 term.inner().update(cx, |inner, cx| {
958 inner.write_output(&data, cx);
959 })
960 });
961 }
962 }
963
964 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
965 entity.update(cx, |_term, cx| {
966 cx.notify();
967 });
968 }
969
970 cx.notify();
971 }
972 TerminalProviderEvent::Output { terminal_id, data } => {
973 if let Some(entity) = self.terminals.get(&terminal_id) {
974 entity.update(cx, |term, cx| {
975 term.inner().update(cx, |inner, cx| {
976 inner.write_output(&data, cx);
977 })
978 });
979 } else {
980 self.pending_terminal_output
981 .entry(terminal_id)
982 .or_default()
983 .push(data);
984 }
985 }
986 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
987 if let Some(entity) = self.terminals.get(&terminal_id) {
988 entity.update(cx, |term, cx| {
989 term.inner().update(cx, |inner, cx| {
990 inner.breadcrumb_text = title;
991 cx.emit(::terminal::Event::BreadcrumbsChanged);
992 })
993 });
994 }
995 }
996 TerminalProviderEvent::Exit {
997 terminal_id,
998 status,
999 } => {
1000 if let Some(entity) = self.terminals.get(&terminal_id) {
1001 entity.update(cx, |_term, cx| {
1002 cx.notify();
1003 });
1004 } else {
1005 self.pending_terminal_exit.insert(terminal_id, status);
1006 }
1007 }
1008 }
1009 }
1010}
1011
/// Whether a thread is currently generating, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is active.
    Idle,
    /// A send task is active.
    Generating,
}
1017
/// Errors carried by `AcpThreadEvent::LoadError`; see the `Display` impl for
/// the user-facing messages.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported one.
    Unsupported {
        /// Where the unsupported version came from (printed after "from").
        command: SharedString,
        /// The version that was found.
        current_version: SharedString,
        /// The minimum version that is required.
        minimum_version: SharedString,
    },
    /// Installation failed; carries the failure message.
    FailedToInstall(SharedString),
    /// The server exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure; carries the message.
    Other(SharedString),
}
1031
1032impl Display for LoadError {
1033 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1034 match self {
1035 LoadError::Unsupported {
1036 command: path,
1037 current_version,
1038 minimum_version,
1039 } => {
1040 write!(
1041 f,
1042 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1043 )
1044 }
1045 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1046 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1047 LoadError::Other(msg) => write!(f, "{msg}"),
1048 }
1049 }
1050}
1051
1052impl Error for LoadError {}
1053
1054impl AcpThread {
1055 pub fn new(
1056 title: impl Into<SharedString>,
1057 connection: Rc<dyn AgentConnection>,
1058 project: Entity<Project>,
1059 action_log: Entity<ActionLog>,
1060 session_id: acp::SessionId,
1061 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1062 cx: &mut Context<Self>,
1063 ) -> Self {
1064 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1065 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1066 loop {
1067 let caps = prompt_capabilities_rx.recv().await?;
1068 this.update(cx, |this, cx| {
1069 this.prompt_capabilities = caps;
1070 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1071 })?;
1072 }
1073 });
1074
1075 Self {
1076 action_log,
1077 shared_buffers: Default::default(),
1078 entries: Default::default(),
1079 plan: Default::default(),
1080 title: title.into(),
1081 project,
1082 send_task: None,
1083 connection,
1084 session_id,
1085 token_usage: None,
1086 prompt_capabilities,
1087 _observe_prompt_capabilities: task,
1088 terminals: HashMap::default(),
1089 pending_terminal_output: HashMap::default(),
1090 pending_terminal_exit: HashMap::default(),
1091 }
1092 }
1093
    /// Returns a copy of the most recently stored prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities.clone()
    }
1097
    /// Returns the agent connection this thread was created with.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
1101
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1105
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1109
    /// Returns a copy of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
1113
    /// Returns all entries of the thread, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
1117
    /// Returns the protocol session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
1121
1122 pub fn status(&self) -> ThreadStatus {
1123 if self.send_task.is_some() {
1124 ThreadStatus::Generating
1125 } else {
1126 ThreadStatus::Idle
1127 }
1128 }
1129
    /// Returns the latest known token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
1133
1134 pub fn has_pending_edit_tool_calls(&self) -> bool {
1135 for entry in self.entries.iter().rev() {
1136 match entry {
1137 AgentThreadEntry::UserMessage(_) => return false,
1138 AgentThreadEntry::ToolCall(
1139 call @ ToolCall {
1140 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1141 ..
1142 },
1143 ) if call.diffs().next().is_some() => {
1144 return true;
1145 }
1146 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1147 }
1148 }
1149
1150 false
1151 }
1152
1153 pub fn used_tools_since_last_user_message(&self) -> bool {
1154 for entry in self.entries.iter().rev() {
1155 match entry {
1156 AgentThreadEntry::UserMessage(..) => return false,
1157 AgentThreadEntry::AssistantMessage(..) => continue,
1158 AgentThreadEntry::ToolCall(..) => return true,
1159 }
1160 }
1161
1162 false
1163 }
1164
1165 pub fn handle_session_update(
1166 &mut self,
1167 update: acp::SessionUpdate,
1168 cx: &mut Context<Self>,
1169 ) -> Result<(), acp::Error> {
1170 match update {
1171 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1172 self.push_user_content_block(None, content, cx);
1173 }
1174 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1175 self.push_assistant_content_block(content, false, cx);
1176 }
1177 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1178 self.push_assistant_content_block(content, true, cx);
1179 }
1180 acp::SessionUpdate::ToolCall(tool_call) => {
1181 self.upsert_tool_call(tool_call, cx)?;
1182 }
1183 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1184 self.update_tool_call(tool_call_update, cx)?;
1185 }
1186 acp::SessionUpdate::Plan(plan) => {
1187 self.update_plan(plan, cx);
1188 }
1189 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1190 available_commands,
1191 ..
1192 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1193 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1194 current_mode_id,
1195 ..
1196 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1197 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1198 config_options,
1199 ..
1200 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1201 _ => {}
1202 }
1203 Ok(())
1204 }
1205
1206 pub fn push_user_content_block(
1207 &mut self,
1208 message_id: Option<UserMessageId>,
1209 chunk: acp::ContentBlock,
1210 cx: &mut Context<Self>,
1211 ) {
1212 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1213 }
1214
1215 pub fn push_user_content_block_with_indent(
1216 &mut self,
1217 message_id: Option<UserMessageId>,
1218 chunk: acp::ContentBlock,
1219 indented: bool,
1220 cx: &mut Context<Self>,
1221 ) {
1222 let language_registry = self.project.read(cx).languages().clone();
1223 let path_style = self.project.read(cx).path_style(cx);
1224 let entries_len = self.entries.len();
1225
1226 if let Some(last_entry) = self.entries.last_mut()
1227 && let AgentThreadEntry::UserMessage(UserMessage {
1228 id,
1229 content,
1230 chunks,
1231 indented: existing_indented,
1232 ..
1233 }) = last_entry
1234 && *existing_indented == indented
1235 {
1236 *id = message_id.or(id.take());
1237 content.append(chunk.clone(), &language_registry, path_style, cx);
1238 chunks.push(chunk);
1239 let idx = entries_len - 1;
1240 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1241 } else {
1242 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1243 self.push_entry(
1244 AgentThreadEntry::UserMessage(UserMessage {
1245 id: message_id,
1246 content,
1247 chunks: vec![chunk],
1248 checkpoint: None,
1249 indented,
1250 }),
1251 cx,
1252 );
1253 }
1254 }
1255
1256 pub fn push_assistant_content_block(
1257 &mut self,
1258 chunk: acp::ContentBlock,
1259 is_thought: bool,
1260 cx: &mut Context<Self>,
1261 ) {
1262 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1263 }
1264
1265 pub fn push_assistant_content_block_with_indent(
1266 &mut self,
1267 chunk: acp::ContentBlock,
1268 is_thought: bool,
1269 indented: bool,
1270 cx: &mut Context<Self>,
1271 ) {
1272 let language_registry = self.project.read(cx).languages().clone();
1273 let path_style = self.project.read(cx).path_style(cx);
1274 let entries_len = self.entries.len();
1275 if let Some(last_entry) = self.entries.last_mut()
1276 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1277 chunks,
1278 indented: existing_indented,
1279 }) = last_entry
1280 && *existing_indented == indented
1281 {
1282 let idx = entries_len - 1;
1283 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1284 match (chunks.last_mut(), is_thought) {
1285 (Some(AssistantMessageChunk::Message { block }), false)
1286 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1287 block.append(chunk, &language_registry, path_style, cx)
1288 }
1289 _ => {
1290 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1291 if is_thought {
1292 chunks.push(AssistantMessageChunk::Thought { block })
1293 } else {
1294 chunks.push(AssistantMessageChunk::Message { block })
1295 }
1296 }
1297 }
1298 } else {
1299 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1300 let chunk = if is_thought {
1301 AssistantMessageChunk::Thought { block }
1302 } else {
1303 AssistantMessageChunk::Message { block }
1304 };
1305
1306 self.push_entry(
1307 AgentThreadEntry::AssistantMessage(AssistantMessage {
1308 chunks: vec![chunk],
1309 indented,
1310 }),
1311 cx,
1312 );
1313 }
1314 }
1315
1316 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1317 self.entries.push(entry);
1318 cx.emit(AcpThreadEvent::NewEntry);
1319 }
1320
1321 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1322 self.connection.set_title(&self.session_id, cx).is_some()
1323 }
1324
1325 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1326 if title != self.title {
1327 self.title = title.clone();
1328 cx.emit(AcpThreadEvent::TitleUpdated);
1329 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1330 return set_title.run(title, cx);
1331 }
1332 }
1333 Task::ready(Ok(()))
1334 }
1335
1336 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1337 self.token_usage = usage;
1338 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1339 }
1340
1341 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1342 cx.emit(AcpThreadEvent::Retry(status));
1343 }
1344
1345 pub fn update_tool_call(
1346 &mut self,
1347 update: impl Into<ToolCallUpdate>,
1348 cx: &mut Context<Self>,
1349 ) -> Result<()> {
1350 let update = update.into();
1351 let languages = self.project.read(cx).languages().clone();
1352 let path_style = self.project.read(cx).path_style(cx);
1353
1354 let ix = match self.index_for_tool_call(update.id()) {
1355 Some(ix) => ix,
1356 None => {
1357 // Tool call not found - create a failed tool call entry
1358 let failed_tool_call = ToolCall {
1359 id: update.id().clone(),
1360 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1361 kind: acp::ToolKind::Fetch,
1362 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1363 "Tool call not found".into(),
1364 &languages,
1365 path_style,
1366 cx,
1367 ))],
1368 status: ToolCallStatus::Failed,
1369 locations: Vec::new(),
1370 resolved_locations: Vec::new(),
1371 raw_input: None,
1372 raw_input_markdown: None,
1373 raw_output: None,
1374 };
1375 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1376 return Ok(());
1377 }
1378 };
1379 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1380 unreachable!()
1381 };
1382
1383 match update {
1384 ToolCallUpdate::UpdateFields(update) => {
1385 let location_updated = update.fields.locations.is_some();
1386 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1387 if location_updated {
1388 self.resolve_locations(update.tool_call_id, cx);
1389 }
1390 }
1391 ToolCallUpdate::UpdateDiff(update) => {
1392 call.content.clear();
1393 call.content.push(ToolCallContent::Diff(update.diff));
1394 }
1395 ToolCallUpdate::UpdateTerminal(update) => {
1396 call.content.clear();
1397 call.content
1398 .push(ToolCallContent::Terminal(update.terminal));
1399 }
1400 }
1401
1402 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1403
1404 Ok(())
1405 }
1406
1407 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1408 pub fn upsert_tool_call(
1409 &mut self,
1410 tool_call: acp::ToolCall,
1411 cx: &mut Context<Self>,
1412 ) -> Result<(), acp::Error> {
1413 let status = tool_call.status.into();
1414 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1415 }
1416
1417 /// Fails if id does not match an existing entry.
1418 pub fn upsert_tool_call_inner(
1419 &mut self,
1420 update: acp::ToolCallUpdate,
1421 status: ToolCallStatus,
1422 cx: &mut Context<Self>,
1423 ) -> Result<(), acp::Error> {
1424 let language_registry = self.project.read(cx).languages().clone();
1425 let path_style = self.project.read(cx).path_style(cx);
1426 let id = update.tool_call_id.clone();
1427
1428 let agent_telemetry_id = self.connection().telemetry_id();
1429 let session = self.session_id();
1430 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1431 let status = if matches!(status, ToolCallStatus::Completed) {
1432 "completed"
1433 } else {
1434 "failed"
1435 };
1436 telemetry::event!(
1437 "Agent Tool Call Completed",
1438 agent_telemetry_id,
1439 session,
1440 status
1441 );
1442 }
1443
1444 if let Some(ix) = self.index_for_tool_call(&id) {
1445 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1446 unreachable!()
1447 };
1448
1449 call.update_fields(
1450 update.fields,
1451 language_registry,
1452 path_style,
1453 &self.terminals,
1454 cx,
1455 )?;
1456 call.status = status;
1457
1458 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1459 } else {
1460 let call = ToolCall::from_acp(
1461 update.try_into()?,
1462 status,
1463 language_registry,
1464 self.project.read(cx).path_style(cx),
1465 &self.terminals,
1466 cx,
1467 )?;
1468 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1469 };
1470
1471 self.resolve_locations(id, cx);
1472 Ok(())
1473 }
1474
1475 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1476 self.entries
1477 .iter()
1478 .enumerate()
1479 .rev()
1480 .find_map(|(index, entry)| {
1481 if let AgentThreadEntry::ToolCall(tool_call) = entry
1482 && &tool_call.id == id
1483 {
1484 Some(index)
1485 } else {
1486 None
1487 }
1488 })
1489 }
1490
1491 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1492 // The tool call we are looking for is typically the last one, or very close to the end.
1493 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1494 self.entries
1495 .iter_mut()
1496 .enumerate()
1497 .rev()
1498 .find_map(|(index, tool_call)| {
1499 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1500 && &tool_call.id == id
1501 {
1502 Some((index, tool_call))
1503 } else {
1504 None
1505 }
1506 })
1507 }
1508
1509 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1510 self.entries
1511 .iter()
1512 .enumerate()
1513 .rev()
1514 .find_map(|(index, tool_call)| {
1515 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1516 && &tool_call.id == id
1517 {
1518 Some((index, tool_call))
1519 } else {
1520 None
1521 }
1522 })
1523 }
1524
/// Resolves the locations of the tool call with the given `id` in the
/// background and stores the result on the tool call.
///
/// Does nothing if no tool call with that id exists. Once resolution
/// finishes, a snapshot of every resolved buffer is recorded in
/// `shared_buffers`, the project's agent location may be moved to the last
/// resolved location, and `EntryUpdated` is emitted if the resolved
/// locations changed. The spawned task is detached.
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
    let project = self.project.clone();
    let Some((_, tool_call)) = self.tool_call_mut(&id) else {
        return;
    };
    let task = tool_call.resolve_locations(project, cx);
    cx.spawn(async move |this, cx| {
        let resolved_locations = task.await;

        this.update(cx, |this, cx| {
            let project = this.project.clone();

            // Remember the buffer contents as they were when the locations resolved.
            for location in resolved_locations.iter().flatten() {
                this.shared_buffers
                    .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
            }
            // The tool call may have been removed while resolution was in flight.
            let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                return;
            };

            if let Some(Some(location)) = resolved_locations.last() {
                project.update(cx, |project, cx| {
                    let should_ignore = if let Some(agent_location) = project
                        .agent_location()
                        .filter(|agent_location| agent_location.buffer == location.buffer)
                    {
                        let snapshot = location.buffer.read(cx).snapshot();
                        let old_position = agent_location.position.to_point(&snapshot);
                        let new_position = location.position.to_point(&snapshot);

                        // Ignore this so that when we get updates from the edit tool
                        // the position doesn't reset to the start of the line.
                        old_position.row == new_position.row
                            && old_position.column > new_position.column
                    } else {
                        false
                    };
                    if !should_ignore {
                        project.set_agent_location(Some(location.into()), cx);
                    }
                });
            }

            let resolved_locations = resolved_locations
                .iter()
                .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                .collect::<Vec<_>>();

            // Only notify when the resolved locations actually changed.
            if tool_call.resolved_locations != resolved_locations {
                tool_call.resolved_locations = resolved_locations;
                cx.emit(AcpThreadEvent::EntryUpdated(ix));
            }
        })
    })
    .detach();
}
1581
1582 pub fn request_tool_call_authorization(
1583 &mut self,
1584 tool_call: acp::ToolCallUpdate,
1585 options: Vec<acp::PermissionOption>,
1586 respect_always_allow_setting: bool,
1587 cx: &mut Context<Self>,
1588 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1589 let (tx, rx) = oneshot::channel();
1590
1591 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1592 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1593 // some tools would (incorrectly) continue to auto-accept.
1594 if let Some(allow_once_option) = options.iter().find_map(|option| {
1595 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1596 Some(option.option_id.clone())
1597 } else {
1598 None
1599 }
1600 }) {
1601 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1602 return Ok(async {
1603 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1604 allow_once_option,
1605 ))
1606 }
1607 .boxed());
1608 }
1609 }
1610
1611 let status = ToolCallStatus::WaitingForConfirmation {
1612 options,
1613 respond_tx: tx,
1614 };
1615
1616 self.upsert_tool_call_inner(tool_call, status, cx)?;
1617 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1618
1619 let fut = async {
1620 match rx.await {
1621 Ok(option) => acp::RequestPermissionOutcome::Selected(
1622 acp::SelectedPermissionOutcome::new(option),
1623 ),
1624 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1625 }
1626 }
1627 .boxed();
1628
1629 Ok(fut)
1630 }
1631
1632 pub fn authorize_tool_call(
1633 &mut self,
1634 id: acp::ToolCallId,
1635 option_id: acp::PermissionOptionId,
1636 option_kind: acp::PermissionOptionKind,
1637 cx: &mut Context<Self>,
1638 ) {
1639 let Some((ix, call)) = self.tool_call_mut(&id) else {
1640 return;
1641 };
1642
1643 let new_status = match option_kind {
1644 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1645 ToolCallStatus::Rejected
1646 }
1647 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1648 ToolCallStatus::InProgress
1649 }
1650 _ => ToolCallStatus::InProgress,
1651 };
1652
1653 let curr_status = mem::replace(&mut call.status, new_status);
1654
1655 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1656 respond_tx.send(option_id).log_err();
1657 } else if cfg!(debug_assertions) {
1658 panic!("tried to authorize an already authorized tool call");
1659 }
1660
1661 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1662 }
1663
1664 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1665 let mut first_tool_call = None;
1666
1667 for entry in self.entries.iter().rev() {
1668 match &entry {
1669 AgentThreadEntry::ToolCall(call) => {
1670 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1671 first_tool_call = Some(call);
1672 } else {
1673 continue;
1674 }
1675 }
1676 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1677 // Reached the beginning of the turn.
1678 // If we had pending permission requests in the previous turn, they have been cancelled.
1679 break;
1680 }
1681 }
1682 }
1683
1684 first_tool_call
1685 }
1686
1687 pub fn plan(&self) -> &Plan {
1688 &self.plan
1689 }
1690
1691 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1692 let new_entries_len = request.entries.len();
1693 let mut new_entries = request.entries.into_iter();
1694
1695 // Reuse existing markdown to prevent flickering
1696 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1697 let PlanEntry {
1698 content,
1699 priority,
1700 status,
1701 } = old;
1702 content.update(cx, |old, cx| {
1703 old.replace(new.content, cx);
1704 });
1705 *priority = new.priority;
1706 *status = new.status;
1707 }
1708 for new in new_entries {
1709 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1710 }
1711 self.plan.entries.truncate(new_entries_len);
1712
1713 cx.notify();
1714 }
1715
1716 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1717 self.plan
1718 .entries
1719 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1720 cx.notify();
1721 }
1722
/// Test helper: sends `message` as a single content block via `send`.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let blocks: Vec<acp::ContentBlock> = vec![message.into()];
    self.send(blocks, cx)
}
1731
/// Sends a user message to the agent and starts a new turn.
///
/// The message is pushed as a user entry, a git checkpoint is captured and
/// attached to that entry (hidden by default), and the prompt request is
/// then forwarded to the connection. The returned future resolves when the
/// turn driven by `run_turn` finishes.
pub fn send(
    &mut self,
    message: Vec<acp::ContentBlock>,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let block = ContentBlock::new_combined(
        message.clone(),
        self.project.read(cx).languages().clone(),
        self.project.read(cx).path_style(cx),
        cx,
    );
    let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
    let git_store = self.project.read(cx).git_store().clone();

    // Only assign a message id when the connection offers a truncate handle.
    let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
        Some(UserMessageId::new())
    } else {
        None
    };

    self.run_turn(cx, async move |this, cx| {
        // Show the user's message right away; a failed update is ignored here.
        this.update(cx, |this, cx| {
            this.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id.clone(),
                    content: block,
                    chunks: message,
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        })
        .ok();

        // A checkpoint failure is logged and does not abort the turn.
        let old_checkpoint = git_store
            .update(cx, |git, cx| git.checkpoint(cx))?
            .await
            .context("failed to get old checkpoint")
            .log_err();
        this.update(cx, |this, cx| {
            if let Some((_ix, message)) = this.last_user_message() {
                message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                    git_checkpoint,
                    show: false,
                });
            }
            this.connection.prompt(message_id, request, cx)
        })?
        .await
    })
}
1784
1785 pub fn can_resume(&self, cx: &App) -> bool {
1786 self.connection.resume(&self.session_id, cx).is_some()
1787 }
1788
1789 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1790 self.run_turn(cx, async move |this, cx| {
1791 this.update(cx, |this, cx| {
1792 this.connection
1793 .resume(&this.session_id, cx)
1794 .map(|resume| resume.run(cx))
1795 })?
1796 .context("resuming a session is not supported")?
1797 .await
1798 })
1799 }
1800
/// Runs one turn: cancels any turn in progress, executes `f`, and processes
/// its outcome.
///
/// Completed plan entries are cleared first. `f` runs inside `send_task`
/// after the previous turn's cancellation has finished. The returned future
/// waits for `f`'s result, refreshes the last checkpoint, clears the agent
/// location, and then emits `Error`, `Refusal` and/or `Stopped` as
/// appropriate. It resolves to `Err` only when `f` returned an error (or
/// when the thread or checkpoint update fails).
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
    f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
    self.clear_completed_plan_entries(cx);

    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);

    // The turn itself: wait for the previous turn to wind down, then run `f`
    // and hand its result to the observer task below.
    self.send_task = Some(cx.spawn(async move |this, cx| {
        cancel_task.await;
        tx.send(f(this, cx).await).ok();
    }));

    cx.spawn(async move |this, cx| {
        // `Err` here means the sender was dropped before sending a result.
        let response = rx.await;

        this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
            .await?;

        this.update(cx, |this, cx| {
            this.project
                .update(cx, |project, cx| project.set_agent_location(None, cx));
            match response {
                Ok(Err(e)) => {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                result => {
                    let canceled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Cancelled,
                            ..
                        }))
                    );

                    // We only take the task if the current prompt wasn't canceled.
                    //
                    // This prompt may have been canceled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be canceled.
                    if !canceled {
                        this.send_task.take();
                    }

                    // Handle refusal - distinguish between user prompt and tool call refusals
                    if let Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Refusal,
                        ..
                    })) = result
                    {
                        if let Some((user_msg_ix, _)) = this.last_user_message() {
                            // Check if there's a completed tool call with results after the last user message
                            // This indicates the refusal is in response to tool output, not the user's prompt
                            let has_completed_tool_call_after_user_msg =
                                this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                    if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                        // Check if the tool call has completed and has output
                                        matches!(tool_call.status, ToolCallStatus::Completed)
                                            && tool_call.raw_output.is_some()
                                    } else {
                                        false
                                    }
                                });

                            if has_completed_tool_call_after_user_msg {
                                // Refusal is due to tool output - don't truncate, just notify
                                // The model refused based on what the tool returned
                                cx.emit(AcpThreadEvent::Refusal);
                            } else {
                                // User prompt was refused - truncate back to before the user message
                                let range = user_msg_ix..this.entries.len();
                                if range.start < range.end {
                                    this.entries.truncate(user_msg_ix);
                                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                }
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        } else {
                            // No user message found, treat as general refusal
                            cx.emit(AcpThreadEvent::Refusal);
                        }
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1896
1897 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1898 let Some(send_task) = self.send_task.take() else {
1899 return Task::ready(());
1900 };
1901
1902 for entry in self.entries.iter_mut() {
1903 if let AgentThreadEntry::ToolCall(call) = entry {
1904 let cancel = matches!(
1905 call.status,
1906 ToolCallStatus::Pending
1907 | ToolCallStatus::WaitingForConfirmation { .. }
1908 | ToolCallStatus::InProgress
1909 );
1910
1911 if cancel {
1912 call.status = ToolCallStatus::Canceled;
1913 }
1914 }
1915 }
1916
1917 self.connection.cancel(&self.session_id, cx);
1918
1919 // Wait for the send task to complete
1920 cx.foreground_executor().spawn(send_task)
1921 }
1922
1923 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1924 pub fn restore_checkpoint(
1925 &mut self,
1926 id: UserMessageId,
1927 cx: &mut Context<Self>,
1928 ) -> Task<Result<()>> {
1929 let Some((_, message)) = self.user_message_mut(&id) else {
1930 return Task::ready(Err(anyhow!("message not found")));
1931 };
1932
1933 let checkpoint = message
1934 .checkpoint
1935 .as_ref()
1936 .map(|c| c.git_checkpoint.clone());
1937
1938 // Cancel any in-progress generation before restoring
1939 let cancel_task = self.cancel(cx);
1940 let rewind = self.rewind(id.clone(), cx);
1941 let git_store = self.project.read(cx).git_store().clone();
1942
1943 cx.spawn(async move |_, cx| {
1944 cancel_task.await;
1945 rewind.await?;
1946 if let Some(checkpoint) = checkpoint {
1947 git_store
1948 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1949 .await?;
1950 }
1951
1952 Ok(())
1953 })
1954 }
1955
/// Rewinds this thread to before the user message with the given `id`,
/// removing it and all subsequent entries while rejecting all pending
/// action_log edits.
/// Unlike `restore_checkpoint`, this method does not restore from git.
///
/// Fails with "not supported" when the connection provides no truncate
/// handle. Terminals referenced by the removed entries are killed and
/// unregistered.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
    let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
        return Task::ready(Err(anyhow!("not supported")));
    };

    let telemetry = ActionLogTelemetry::from(&*self);
    cx.spawn(async move |this, cx| {
        // Truncate on the agent side first; local entries are only removed
        // once that has succeeded.
        cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
        this.update(cx, |this, cx| {
            if let Some((ix, _)) = this.user_message_mut(&id) {
                // Collect all terminals from entries that will be removed
                let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                    .iter()
                    .flat_map(|entry| entry.terminals())
                    .filter_map(|terminal| terminal.read(cx).id().clone().into())
                    .collect();

                let range = ix..this.entries.len();
                this.entries.truncate(ix);
                cx.emit(AcpThreadEvent::EntriesRemoved(range));

                // Kill and remove the terminals
                for terminal_id in terminals_to_remove {
                    if let Some(terminal) = this.terminals.remove(&terminal_id) {
                        terminal.update(cx, |terminal, cx| {
                            terminal.kill(cx);
                        });
                    }
                }
            }
            // Edits are rejected even if the message was not found locally.
            this.action_log().update(cx, |action_log, cx| {
                action_log.reject_all_edits(Some(telemetry), cx)
            })
        })?
        .await;
        Ok(())
    })
}
1997
/// Decides whether the checkpoint on the last user message should be shown.
///
/// A fresh git checkpoint is taken and compared with the one stored on the
/// last user message; the stored checkpoint's `show` flag is set when the two
/// differ, and `EntryUpdated` is emitted for that message. Returns a ready
/// `Ok(())` task when there is no last user message, or when it has no id or
/// no checkpoint.
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
    let git_store = self.project.read(cx).git_store().clone();

    let Some((_, message)) = self.last_user_message() else {
        return Task::ready(Ok(()));
    };
    let Some(user_message_id) = message.id.clone() else {
        return Task::ready(Ok(()));
    };
    let Some(checkpoint) = message.checkpoint.as_ref() else {
        return Task::ready(Ok(()));
    };
    let old_checkpoint = checkpoint.git_checkpoint.clone();

    let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
    cx.spawn(async move |this, cx| {
        // Failing to take the new checkpoint is logged and treated as a no-op.
        let Some(new_checkpoint) = new_checkpoint
            .await
            .context("failed to get new checkpoint")
            .log_err()
        else {
            return Ok(());
        };

        // If the comparison itself fails, treat the checkpoints as equal so
        // the checkpoint stays hidden.
        let equal = git_store
            .update(cx, |git, cx| {
                git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
            })?
            .await
            .unwrap_or(true);

        // Look the message up again by id: entries may have changed meanwhile.
        this.update(cx, |this, cx| {
            if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
                if let Some(checkpoint) = message.checkpoint.as_mut() {
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            }
        })?;

        Ok(())
    })
}
2041
2042 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2043 self.entries
2044 .iter_mut()
2045 .enumerate()
2046 .rev()
2047 .find_map(|(ix, entry)| {
2048 if let AgentThreadEntry::UserMessage(message) = entry {
2049 Some((ix, message))
2050 } else {
2051 None
2052 }
2053 })
2054 }
2055
2056 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2057 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2058 if let AgentThreadEntry::UserMessage(message) = entry {
2059 if message.id.as_ref() == Some(id) {
2060 Some((ix, message))
2061 } else {
2062 None
2063 }
2064 } else {
2065 None
2066 }
2067 })
2068 }
2069
/// Reads text from the project file at `path` on behalf of the agent.
///
/// `line` is the 1-based first line to read (defaults to the start of the
/// file) and `limit` is the maximum number of lines (defaults to no limit).
/// When `reuse_shared_snapshot` is set and a snapshot of the buffer is
/// already stored in `shared_buffers`, that snapshot is read; otherwise the
/// read is recorded in the action log and the buffer's current snapshot is
/// read and stored. The agent location is moved to the start of the range.
///
/// # Errors
/// Returns a resource-not-found error when `path` is not part of the
/// project, and an invalid-params error when `line` lies beyond the end of
/// the file.
pub fn read_text_file(
    &self,
    path: PathBuf,
    line: Option<u32>,
    limit: Option<u32>,
    reuse_shared_snapshot: bool,
    cx: &mut Context<Self>,
) -> Task<Result<String, acp::Error>> {
    // Args are 1-based, move to 0-based
    let line = line.unwrap_or_default().saturating_sub(1);
    let limit = limit.unwrap_or(u32::MAX);
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project
            .update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .ok_or_else(|| {
                        acp::Error::resource_not_found(Some(path.display().to_string()))
                    })?;
                Ok(project.open_buffer(path, cx))
            })
            .map_err(|e| acp::Error::internal_error().data(e.to_string()))
            .flatten()?;

        let buffer = load.await?;

        // Look for a previously shared snapshot only when the caller asked for it.
        let snapshot = if reuse_shared_snapshot {
            this.read_with(cx, |this, _| {
                this.shared_buffers.get(&buffer.clone()).cloned()
            })
            .log_err()
            .flatten()
        } else {
            None
        };

        let snapshot = if let Some(snapshot) = snapshot {
            snapshot
        } else {
            // Fresh read: record it in the action log and remember the snapshot.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            })?;

            let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
            this.update(cx, |this, _| {
                this.shared_buffers.insert(buffer.clone(), snapshot.clone());
            })?;
            snapshot
        };

        let max_point = snapshot.max_point();
        let start_position = Point::new(line, 0);

        if start_position > max_point {
            return Err(acp::Error::invalid_params().data(format!(
                "Attempting to read beyond the end of the file, line {}:{}",
                max_point.row + 1,
                max_point.column
            )));
        }

        let start = snapshot.anchor_before(start_position);
        // Saturating add keeps the default `u32::MAX` limit from overflowing.
        let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: start,
                }),
                cx,
            );
        })?;

        Ok(snapshot.text_for_range(start..end).collect::<String>())
    })
}
2149
/// Replaces the contents of the project file at `path` with `content` on
/// behalf of the agent, then saves the buffer.
///
/// The new content is diffed against the snapshot stored in
/// `shared_buffers` for the buffer (or the buffer's current snapshot when
/// none is stored) and the resulting edits are applied to the buffer. The
/// agent location is moved to the end of the last edit, the action log is
/// told about the read and the edit, and the buffer is formatted before
/// saving unless format-on-save is off.
///
/// # Errors
/// Fails with "invalid path" when `path` is not part of the project, and
/// propagates errors from opening or saving the buffer. Formatting errors
/// are only logged.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Diff against the stored shared snapshot when there is one.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        // The diff is computed on the background executor.
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        // Point the agent location at the end of the last edit, or at the
        // buffer's minimum anchor when the diff is empty.
        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                }),
                cx,
            );
        })?;

        let format_on_save = cx.update(|cx| {
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        })?;

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            })?;
            // A formatting failure is logged but does not prevent the save.
            format_task.await.log_err();

            // Report the buffer as edited again once formatting has run.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            })?;
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
2246
/// Spawns `command` with `args` in a new terminal and registers it with this
/// thread under a freshly generated terminal id.
///
/// The environment starts from the directory environment of `cwd` (empty
/// when `cwd` is `None` or no environment is available), sets `PAGER` to an
/// empty string, and then applies `extra_env`, which therefore overrides
/// earlier values. The command is run through the remote client's default
/// system shell when one is available, otherwise through the local default
/// shell (preferring bash), with stdin redirected to the null device.
/// `output_byte_limit` is passed through to the created `Terminal`.
pub fn create_terminal(
    &self,
    command: String,
    args: Vec<String>,
    extra_env: Vec<acp::EnvVariable>,
    cwd: Option<PathBuf>,
    output_byte_limit: Option<u64>,
    cx: &mut Context<Self>,
) -> Task<Result<Entity<Terminal>>> {
    let env = match &cwd {
        Some(dir) => self.project.update(cx, |project, cx| {
            project.environment().update(cx, |env, cx| {
                env.directory_environment(dir.as_path().into(), cx)
            })
        }),
        None => Task::ready(None).shared(),
    };
    let env = cx.spawn(async move |_, _| {
        let mut env = env.await.unwrap_or_default();
        // Disables paging for `git` and hopefully other commands
        env.insert("PAGER".into(), "".into());
        // Caller-supplied variables are applied last, so they win on conflict.
        for var in extra_env {
            env.insert(var.name, var.value);
        }
        env
    });

    let project = self.project.clone();
    let language_registry = project.read(cx).languages().clone();
    let is_windows = project.read(cx).path_style(cx).is_windows();

    let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
    let terminal_task = cx.spawn({
        let terminal_id = terminal_id.clone();
        async move |_this, cx| {
            let env = env.await;
            // Prefer the remote client's default shell, falling back to the local default.
            let shell = project
                .update(cx, |project, cx| {
                    project
                        .remote_client()
                        .and_then(|r| r.read(cx).default_system_shell())
                })?
                .unwrap_or_else(|| get_default_system_shell_preferring_bash());
            let (task_command, task_args) =
                ShellBuilder::new(&Shell::Program(shell), is_windows)
                    .redirect_stdin_to_dev_null()
                    .build(Some(command.clone()), &args);
            let terminal = project
                .update(cx, |project, cx| {
                    project.create_terminal_task(
                        task::SpawnInTerminal {
                            command: Some(task_command),
                            args: task_args,
                            cwd: cwd.clone(),
                            env,
                            ..Default::default()
                        },
                        cx,
                    )
                })?
                .await?;

            cx.new(|cx| {
                Terminal::new(
                    terminal_id,
                    // The original command line, not the shell-wrapped one.
                    &format!("{} {}", command, args.join(" ")),
                    cwd,
                    output_byte_limit.map(|l| l as usize),
                    terminal,
                    language_registry,
                    cx,
                )
            })
        }
    });

    // Register the terminal with the thread once it has been created.
    cx.spawn(async move |this, cx| {
        let terminal = terminal_task.await?;
        this.update(cx, |this, _cx| {
            this.terminals.insert(terminal_id, terminal.clone());
            terminal
        })
    })
}
2331
2332 pub fn kill_terminal(
2333 &mut self,
2334 terminal_id: acp::TerminalId,
2335 cx: &mut Context<Self>,
2336 ) -> Result<()> {
2337 self.terminals
2338 .get(&terminal_id)
2339 .context("Terminal not found")?
2340 .update(cx, |terminal, cx| {
2341 terminal.kill(cx);
2342 });
2343
2344 Ok(())
2345 }
2346
2347 pub fn release_terminal(
2348 &mut self,
2349 terminal_id: acp::TerminalId,
2350 cx: &mut Context<Self>,
2351 ) -> Result<()> {
2352 self.terminals
2353 .remove(&terminal_id)
2354 .context("Terminal not found")?
2355 .update(cx, |terminal, cx| {
2356 terminal.kill(cx);
2357 });
2358
2359 Ok(())
2360 }
2361
2362 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2363 self.terminals
2364 .get(&terminal_id)
2365 .context("Terminal not found")
2366 .cloned()
2367 }
2368
2369 pub fn to_markdown(&self, cx: &App) -> String {
2370 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2371 }
2372
2373 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2374 cx.emit(AcpThreadEvent::LoadError(error));
2375 }
2376
2377 pub fn register_terminal_created(
2378 &mut self,
2379 terminal_id: acp::TerminalId,
2380 command_label: String,
2381 working_dir: Option<PathBuf>,
2382 output_byte_limit: Option<u64>,
2383 terminal: Entity<::terminal::Terminal>,
2384 cx: &mut Context<Self>,
2385 ) -> Entity<Terminal> {
2386 let language_registry = self.project.read(cx).languages().clone();
2387
2388 let entity = cx.new(|cx| {
2389 Terminal::new(
2390 terminal_id.clone(),
2391 &command_label,
2392 working_dir.clone(),
2393 output_byte_limit.map(|l| l as usize),
2394 terminal,
2395 language_registry,
2396 cx,
2397 )
2398 });
2399 self.terminals.insert(terminal_id.clone(), entity.clone());
2400 entity
2401 }
2402}
2403
2404fn markdown_for_raw_output(
2405 raw_output: &serde_json::Value,
2406 language_registry: &Arc<LanguageRegistry>,
2407 cx: &mut App,
2408) -> Option<Entity<Markdown>> {
2409 match raw_output {
2410 serde_json::Value::Null => None,
2411 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2412 Markdown::new(
2413 value.to_string().into(),
2414 Some(language_registry.clone()),
2415 None,
2416 cx,
2417 )
2418 })),
2419 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2420 Markdown::new(
2421 value.to_string().into(),
2422 Some(language_registry.clone()),
2423 None,
2424 cx,
2425 )
2426 })),
2427 serde_json::Value::String(value) => Some(cx.new(|cx| {
2428 Markdown::new(
2429 value.clone().into(),
2430 Some(language_registry.clone()),
2431 None,
2432 cx,
2433 )
2434 })),
2435 value => Some(cx.new(|cx| {
2436 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2437
2438 Markdown::new(
2439 format!("```json\n{}\n```", pretty_json).into(),
2440 Some(language_registry.clone()),
2441 None,
2442 cx,
2443 )
2444 })),
2445 }
2446}
2447
2448#[cfg(test)]
2449mod tests {
2450 use super::*;
2451 use anyhow::anyhow;
2452 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2453 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2454 use indoc::indoc;
2455 use project::{FakeFs, Fs};
2456 use rand::{distr, prelude::*};
2457 use serde_json::json;
2458 use settings::SettingsStore;
2459 use smol::stream::StreamExt as _;
2460 use std::{
2461 any::Any,
2462 cell::RefCell,
2463 path::Path,
2464 rc::Rc,
2465 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2466 time::Duration,
2467 };
2468 use util::path;
2469
2470 fn init_test(cx: &mut TestAppContext) {
2471 env_logger::try_init().ok();
2472 cx.update(|cx| {
2473 let settings_store = SettingsStore::test(cx);
2474 cx.set_global(settings_store);
2475 });
2476 }
2477
    /// Checks that terminal output delivered before the terminal's `Created` event is not
    /// lost: once `Created` arrives, the earlier bytes must appear in the terminal content.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // The id is made up here; no terminal is registered under it yet.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            // `terminal` only succeeds if `Created` registered the terminal under this id.
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2539
    /// Checks that when both `Output` and `Exit` events arrive before the terminal's
    /// `Created` event, the output is still rendered once `Created` is processed.
    ///
    /// NOTE(review): only the rendered output is asserted; the effect of the early `Exit`
    /// event is not checked here.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // The id is made up here; no terminal is registered under it yet.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            // `terminal` only succeeds if `Created` registered the terminal under this id.
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2612
    /// Exercises `push_user_content_block` in three situations: creating the first user
    /// message, appending to the trailing user message (which also takes on the newly
    /// supplied id), and starting a new user message after an assistant message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            // Still a single entry: the block was merged into the existing user message.
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            // Entries are now: user message, assistant message, new user message.
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2680
    /// Checks that two consecutive `AgentThoughtChunk` updates are rendered as one
    /// `<thinking>` section in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every user message with two thought chunks and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
                ## User

                Hello from Zed!

                ## Assistant

                <thinking>
                Thinking hard!
                </thinking>

            "#}
        );
    }
2741
    /// Checks that an agent write made after the user has edited the same buffer keeps the
    /// user's edit: the agent reads the file, the user then inserts a line at the top, and
    /// the agent's subsequent write must leave both changes in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Lets the fake agent signal the test body once it has read the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Signal that the read finished so the user edit can happen before the write.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then simulate the user's edit at the top of the buffer.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's "zero" line and the agent's appended lines must be present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2818
    /// Exercises `read_text_file` on a four-line file with every combination of the optional
    /// start line and line limit, plus a start line past the end of the file.
    ///
    /// The expectations show that the start line is 1-based: `Some(3)` yields the text from
    /// "three" onwards.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
2894
    /// Exercises `read_text_file` on an empty file: every in-range combination of start line
    /// and limit yields an empty string, while a start line past the end is rejected.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
2969 #[gpui::test]
2970 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2971 init_test(cx);
2972
2973 let fs = FakeFs::new(cx.executor());
2974 fs.insert_tree(path!("/tmp"), json!({})).await;
2975 let project = Project::test(fs.clone(), [], cx).await;
2976 project
2977 .update(cx, |project, cx| {
2978 project.find_or_create_worktree(path!("/tmp"), true, cx)
2979 })
2980 .await
2981 .unwrap();
2982
2983 let connection = Rc::new(FakeAgentConnection::new());
2984
2985 let thread = cx
2986 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2987 .await
2988 .unwrap();
2989
2990 // Out of project file
2991 let err = thread
2992 .update(cx, |thread, cx| {
2993 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2994 })
2995 .await
2996 .unwrap_err();
2997
2998 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
2999 }
3000
    /// Checks the status transitions of a tool call around cancellation: it starts as
    /// `InProgress`, becomes `Canceled` when the thread is canceled, and a later `Completed`
    /// update from the agent is still applied.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        // The fake agent reports one in-progress tool call and then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A completion reported after cancellation must still update the tool call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3090
    /// Checks that an edit tool call that arrives already `Completed` (with a diff attached)
    /// does not make `has_pending_edit_tool_calls` report true once the turn has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports a single completed edit tool call and then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3135
    /// Checks checkpoint handling across several turns in a project containing a `.git`
    /// directory:
    ///
    /// - turns during which the fake agent writes a new file are rendered with a
    ///   "(checkpoint)" marker on the user message;
    /// - a turn without file changes is rendered without the marker;
    /// - restoring the checkpoint of the second user message truncates the thread back to
    ///   the first turn and leaves only the first file in the fake file system.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // `simulate_changes` toggles whether the fake agent writes a file during a turn;
        // `next_filename` numbers the files it creates.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the upper-cased prompt text as the assistant message.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                // Entry 2 is the second user message ("ipsum").
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3313
3314 #[gpui::test]
3315 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3316 use std::sync::atomic::AtomicUsize;
3317 init_test(cx);
3318
3319 let fs = FakeFs::new(cx.executor());
3320 let project = Project::test(fs, None, cx).await;
3321
3322 // Create a connection that simulates refusal after tool result
3323 let prompt_count = Arc::new(AtomicUsize::new(0));
3324 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3325 let prompt_count = prompt_count.clone();
3326 move |_request, thread, mut cx| {
3327 let count = prompt_count.fetch_add(1, SeqCst);
3328 async move {
3329 if count == 0 {
3330 // First prompt: Generate a tool call with result
3331 thread.update(&mut cx, |thread, cx| {
3332 thread
3333 .handle_session_update(
3334 acp::SessionUpdate::ToolCall(
3335 acp::ToolCall::new("tool1", "Test Tool")
3336 .kind(acp::ToolKind::Fetch)
3337 .status(acp::ToolCallStatus::Completed)
3338 .raw_input(serde_json::json!({"query": "test"}))
3339 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3340 ),
3341 cx,
3342 )
3343 .unwrap();
3344 })?;
3345
3346 // Now return refusal because of the tool result
3347 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3348 } else {
3349 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3350 }
3351 }
3352 .boxed_local()
3353 }
3354 }));
3355
3356 let thread = cx
3357 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3358 .await
3359 .unwrap();
3360
3361 // Track if we see a Refusal event
3362 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3363 let saw_refusal_event_captured = saw_refusal_event.clone();
3364 thread.update(cx, |_thread, cx| {
3365 cx.subscribe(
3366 &thread,
3367 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3368 if matches!(event, AcpThreadEvent::Refusal) {
3369 *saw_refusal_event_captured.lock().unwrap() = true;
3370 }
3371 },
3372 )
3373 .detach();
3374 });
3375
3376 // Send a user message - this will trigger tool call and then refusal
3377 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3378 cx.background_executor.spawn(send_task).detach();
3379 cx.run_until_parked();
3380
3381 // Verify that:
3382 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3383 // 2. The user message was NOT truncated
3384 assert!(
3385 *saw_refusal_event.lock().unwrap(),
3386 "Refusal event should be emitted for tool result refusals"
3387 );
3388
3389 thread.read_with(cx, |thread, _| {
3390 let entries = thread.entries();
3391 assert!(entries.len() >= 2, "Should have user message and tool call");
3392
3393 // Verify user message is still there
3394 assert!(
3395 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3396 "User message should not be truncated"
3397 );
3398
3399 // Verify tool call is there with result
3400 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3401 assert!(
3402 tool_call.raw_output.is_some(),
3403 "Tool call should have output"
3404 );
3405 } else {
3406 panic!("Expected tool call at index 1");
3407 }
3408 });
3409 }
3410
3411 #[gpui::test]
3412 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3413 init_test(cx);
3414
3415 let fs = FakeFs::new(cx.executor());
3416 let project = Project::test(fs, None, cx).await;
3417
3418 let refuse_next = Arc::new(AtomicBool::new(false));
3419 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3420 let refuse_next = refuse_next.clone();
3421 move |_request, _thread, _cx| {
3422 if refuse_next.load(SeqCst) {
3423 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3424 .boxed_local()
3425 } else {
3426 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3427 .boxed_local()
3428 }
3429 }
3430 }));
3431
3432 let thread = cx
3433 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3434 .await
3435 .unwrap();
3436
3437 // Track if we see a Refusal event
3438 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3439 let saw_refusal_event_captured = saw_refusal_event.clone();
3440 thread.update(cx, |_thread, cx| {
3441 cx.subscribe(
3442 &thread,
3443 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3444 if matches!(event, AcpThreadEvent::Refusal) {
3445 *saw_refusal_event_captured.lock().unwrap() = true;
3446 }
3447 },
3448 )
3449 .detach();
3450 });
3451
3452 // Send a message that will be refused
3453 refuse_next.store(true, SeqCst);
3454 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3455 .await
3456 .unwrap();
3457
3458 // Verify that a Refusal event WAS emitted for user prompt refusal
3459 assert!(
3460 *saw_refusal_event.lock().unwrap(),
3461 "Refusal event should be emitted for user prompt refusals"
3462 );
3463
3464 // Verify the message was truncated (user prompt refusal)
3465 thread.read_with(cx, |thread, cx| {
3466 assert_eq!(thread.to_markdown(cx), "");
3467 });
3468 }
3469
    /// Checks that a refused user prompt is removed from the thread while earlier,
    /// successfully answered turns are kept: after one normal turn and one refused turn,
    /// the markdown still shows only the first exchange.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While `refuse_next` is set, the fake agent refuses; otherwise it echoes the prompt
        // text upper-cased as the assistant message.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
3551
3552 async fn run_until_first_tool_call(
3553 thread: &Entity<AcpThread>,
3554 cx: &mut TestAppContext,
3555 ) -> usize {
3556 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3557
3558 let subscription = cx.update(|cx| {
3559 cx.subscribe(thread, move |thread, _, cx| {
3560 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3561 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3562 return tx.try_send(ix).unwrap();
3563 }
3564 }
3565 })
3566 });
3567
3568 select! {
3569 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3570 panic!("Timeout waiting for tool call")
3571 }
3572 ix = rx.next().fuse() => {
3573 drop(subscription);
3574 ix.unwrap()
3575 }
3576 }
3577 }
3578
3579 #[derive(Clone, Default)]
3580 struct FakeAgentConnection {
3581 auth_methods: Vec<acp::AuthMethod>,
3582 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3583 on_user_message: Option<
3584 Rc<
3585 dyn Fn(
3586 acp::PromptRequest,
3587 WeakEntity<AcpThread>,
3588 AsyncApp,
3589 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3590 + 'static,
3591 >,
3592 >,
3593 }
3594
3595 impl FakeAgentConnection {
3596 fn new() -> Self {
3597 Self {
3598 auth_methods: Vec::new(),
3599 on_user_message: None,
3600 sessions: Arc::default(),
3601 }
3602 }
3603
3604 #[expect(unused)]
3605 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3606 self.auth_methods = auth_methods;
3607 self
3608 }
3609
3610 fn on_user_message(
3611 mut self,
3612 handler: impl Fn(
3613 acp::PromptRequest,
3614 WeakEntity<AcpThread>,
3615 AsyncApp,
3616 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3617 + 'static,
3618 ) -> Self {
3619 self.on_user_message.replace(Rc::new(handler));
3620 self
3621 }
3622 }
3623
3624 impl AgentConnection for FakeAgentConnection {
3625 fn telemetry_id(&self) -> SharedString {
3626 "fake".into()
3627 }
3628
3629 fn auth_methods(&self) -> &[acp::AuthMethod] {
3630 &self.auth_methods
3631 }
3632
3633 fn new_thread(
3634 self: Rc<Self>,
3635 project: Entity<Project>,
3636 _cwd: &Path,
3637 cx: &mut App,
3638 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3639 let session_id = acp::SessionId::new(
3640 rand::rng()
3641 .sample_iter(&distr::Alphanumeric)
3642 .take(7)
3643 .map(char::from)
3644 .collect::<String>(),
3645 );
3646 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3647 let thread = cx.new(|cx| {
3648 AcpThread::new(
3649 "Test",
3650 self.clone(),
3651 project,
3652 action_log,
3653 session_id.clone(),
3654 watch::Receiver::constant(
3655 acp::PromptCapabilities::new()
3656 .image(true)
3657 .audio(true)
3658 .embedded_context(true),
3659 ),
3660 cx,
3661 )
3662 });
3663 self.sessions.lock().insert(session_id, thread.downgrade());
3664 Task::ready(Ok(thread))
3665 }
3666
3667 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3668 if self.auth_methods().iter().any(|m| m.id == method) {
3669 Task::ready(Ok(()))
3670 } else {
3671 Task::ready(Err(anyhow!("Invalid Auth Method")))
3672 }
3673 }
3674
3675 fn prompt(
3676 &self,
3677 _id: Option<UserMessageId>,
3678 params: acp::PromptRequest,
3679 cx: &mut App,
3680 ) -> Task<gpui::Result<acp::PromptResponse>> {
3681 let sessions = self.sessions.lock();
3682 let thread = sessions.get(¶ms.session_id).unwrap();
3683 if let Some(handler) = &self.on_user_message {
3684 let handler = handler.clone();
3685 let thread = thread.clone();
3686 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3687 } else {
3688 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3689 }
3690 }
3691
3692 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3693 let sessions = self.sessions.lock();
3694 let thread = sessions.get(session_id).unwrap().clone();
3695
3696 cx.spawn(async move |cx| {
3697 thread
3698 .update(cx, |thread, cx| thread.cancel(cx))
3699 .unwrap()
3700 .await
3701 })
3702 .detach();
3703 }
3704
3705 fn truncate(
3706 &self,
3707 session_id: &acp::SessionId,
3708 _cx: &App,
3709 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3710 Some(Rc::new(FakeAgentSessionEditor {
3711 _session_id: session_id.clone(),
3712 }))
3713 }
3714
3715 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3716 self
3717 }
3718 }
3719
3720 struct FakeAgentSessionEditor {
3721 _session_id: acp::SessionId,
3722 }
3723
3724 impl AgentSessionTruncate for FakeAgentSessionEditor {
3725 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3726 Task::ready(Ok(()))
3727 }
3728 }
3729
3730 #[gpui::test]
3731 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3732 init_test(cx);
3733
3734 let fs = FakeFs::new(cx.executor());
3735 let project = Project::test(fs, [], cx).await;
3736 let connection = Rc::new(FakeAgentConnection::new());
3737 let thread = cx
3738 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3739 .await
3740 .unwrap();
3741
3742 // Try to update a tool call that doesn't exist
3743 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3744 thread.update(cx, |thread, cx| {
3745 let result = thread.handle_session_update(
3746 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3747 nonexistent_id.clone(),
3748 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3749 )),
3750 cx,
3751 );
3752
3753 // The update should succeed (not return an error)
3754 assert!(result.is_ok());
3755
3756 // There should now be exactly one entry in the thread
3757 assert_eq!(thread.entries.len(), 1);
3758
3759 // The entry should be a failed tool call
3760 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3761 assert_eq!(tool_call.id, nonexistent_id);
3762 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3763 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3764
3765 // Check that the content contains the error message
3766 assert_eq!(tool_call.content.len(), 1);
3767 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3768 match content_block {
3769 ContentBlock::Markdown { markdown } => {
3770 let markdown_text = markdown.read(cx).source();
3771 assert!(markdown_text.contains("Tool call not found"));
3772 }
3773 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3774 ContentBlock::ResourceLink { .. } => {
3775 panic!("Expected markdown content, got resource link")
3776 }
3777 }
3778 } else {
3779 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3780 }
3781 } else {
3782 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3783 }
3784 });
3785 }
3786
3787 /// Tests that restoring a checkpoint properly cleans up terminals that were
3788 /// created after that checkpoint, and cancels any in-progress generation.
3789 ///
3790 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3791 /// that were started after that checkpoint should be terminated, and any in-progress
3792 /// AI generation should be canceled.
3793 #[gpui::test]
3794 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3795 init_test(cx);
3796
3797 let fs = FakeFs::new(cx.executor());
3798 let project = Project::test(fs, [], cx).await;
3799 let connection = Rc::new(FakeAgentConnection::new());
3800 let thread = cx
3801 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3802 .await
3803 .unwrap();
3804
3805 // Send first user message to create a checkpoint
3806 cx.update(|cx| {
3807 thread.update(cx, |thread, cx| {
3808 thread.send(vec!["first message".into()], cx)
3809 })
3810 })
3811 .await
3812 .unwrap();
3813
3814 // Send second message (creates another checkpoint) - we'll restore to this one
3815 cx.update(|cx| {
3816 thread.update(cx, |thread, cx| {
3817 thread.send(vec!["second message".into()], cx)
3818 })
3819 })
3820 .await
3821 .unwrap();
3822
3823 // Create 2 terminals BEFORE the checkpoint that have completed running
3824 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3825 let mock_terminal_1 = cx.new(|cx| {
3826 let builder = ::terminal::TerminalBuilder::new_display_only(
3827 ::terminal::terminal_settings::CursorShape::default(),
3828 ::terminal::terminal_settings::AlternateScroll::On,
3829 None,
3830 0,
3831 )
3832 .unwrap();
3833 builder.subscribe(cx)
3834 });
3835
3836 thread.update(cx, |thread, cx| {
3837 thread.on_terminal_provider_event(
3838 TerminalProviderEvent::Created {
3839 terminal_id: terminal_id_1.clone(),
3840 label: "echo 'first'".to_string(),
3841 cwd: Some(PathBuf::from("/test")),
3842 output_byte_limit: None,
3843 terminal: mock_terminal_1.clone(),
3844 },
3845 cx,
3846 );
3847 });
3848
3849 thread.update(cx, |thread, cx| {
3850 thread.on_terminal_provider_event(
3851 TerminalProviderEvent::Output {
3852 terminal_id: terminal_id_1.clone(),
3853 data: b"first\n".to_vec(),
3854 },
3855 cx,
3856 );
3857 });
3858
3859 thread.update(cx, |thread, cx| {
3860 thread.on_terminal_provider_event(
3861 TerminalProviderEvent::Exit {
3862 terminal_id: terminal_id_1.clone(),
3863 status: acp::TerminalExitStatus::new().exit_code(0),
3864 },
3865 cx,
3866 );
3867 });
3868
3869 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3870 let mock_terminal_2 = cx.new(|cx| {
3871 let builder = ::terminal::TerminalBuilder::new_display_only(
3872 ::terminal::terminal_settings::CursorShape::default(),
3873 ::terminal::terminal_settings::AlternateScroll::On,
3874 None,
3875 0,
3876 )
3877 .unwrap();
3878 builder.subscribe(cx)
3879 });
3880
3881 thread.update(cx, |thread, cx| {
3882 thread.on_terminal_provider_event(
3883 TerminalProviderEvent::Created {
3884 terminal_id: terminal_id_2.clone(),
3885 label: "echo 'second'".to_string(),
3886 cwd: Some(PathBuf::from("/test")),
3887 output_byte_limit: None,
3888 terminal: mock_terminal_2.clone(),
3889 },
3890 cx,
3891 );
3892 });
3893
3894 thread.update(cx, |thread, cx| {
3895 thread.on_terminal_provider_event(
3896 TerminalProviderEvent::Output {
3897 terminal_id: terminal_id_2.clone(),
3898 data: b"second\n".to_vec(),
3899 },
3900 cx,
3901 );
3902 });
3903
3904 thread.update(cx, |thread, cx| {
3905 thread.on_terminal_provider_event(
3906 TerminalProviderEvent::Exit {
3907 terminal_id: terminal_id_2.clone(),
3908 status: acp::TerminalExitStatus::new().exit_code(0),
3909 },
3910 cx,
3911 );
3912 });
3913
3914 // Get the second message ID to restore to
3915 let second_message_id = thread.read_with(cx, |thread, _| {
3916 // At this point we have:
3917 // - Index 0: First user message (with checkpoint)
3918 // - Index 1: Second user message (with checkpoint)
3919 // No assistant responses because FakeAgentConnection just returns EndTurn
3920 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3921 panic!("expected user message at index 1");
3922 };
3923 message.id.clone().unwrap()
3924 });
3925
3926 // Create a terminal AFTER the checkpoint we'll restore to.
3927 // This simulates the AI agent starting a long-running terminal command.
3928 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3929 let mock_terminal = cx.new(|cx| {
3930 let builder = ::terminal::TerminalBuilder::new_display_only(
3931 ::terminal::terminal_settings::CursorShape::default(),
3932 ::terminal::terminal_settings::AlternateScroll::On,
3933 None,
3934 0,
3935 )
3936 .unwrap();
3937 builder.subscribe(cx)
3938 });
3939
3940 // Register the terminal as created
3941 thread.update(cx, |thread, cx| {
3942 thread.on_terminal_provider_event(
3943 TerminalProviderEvent::Created {
3944 terminal_id: terminal_id.clone(),
3945 label: "sleep 1000".to_string(),
3946 cwd: Some(PathBuf::from("/test")),
3947 output_byte_limit: None,
3948 terminal: mock_terminal.clone(),
3949 },
3950 cx,
3951 );
3952 });
3953
3954 // Simulate the terminal producing output (still running)
3955 thread.update(cx, |thread, cx| {
3956 thread.on_terminal_provider_event(
3957 TerminalProviderEvent::Output {
3958 terminal_id: terminal_id.clone(),
3959 data: b"terminal is running...\n".to_vec(),
3960 },
3961 cx,
3962 );
3963 });
3964
3965 // Create a tool call entry that references this terminal
3966 // This represents the agent requesting a terminal command
3967 thread.update(cx, |thread, cx| {
3968 thread
3969 .handle_session_update(
3970 acp::SessionUpdate::ToolCall(
3971 acp::ToolCall::new("terminal-tool-1", "Running command")
3972 .kind(acp::ToolKind::Execute)
3973 .status(acp::ToolCallStatus::InProgress)
3974 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
3975 terminal_id.clone(),
3976 ))])
3977 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
3978 ),
3979 cx,
3980 )
3981 .unwrap();
3982 });
3983
3984 // Verify terminal exists and is in the thread
3985 let terminal_exists_before =
3986 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
3987 assert!(
3988 terminal_exists_before,
3989 "Terminal should exist before checkpoint restore"
3990 );
3991
3992 // Verify the terminal's underlying task is still running (not completed)
3993 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
3994 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
3995 terminal_entity.read_with(cx, |term, _cx| {
3996 term.output().is_none() // output is None means it's still running
3997 })
3998 });
3999 assert!(
4000 terminal_running_before,
4001 "Terminal should be running before checkpoint restore"
4002 );
4003
4004 // Verify we have the expected entries before restore
4005 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4006 assert!(
4007 entry_count_before > 1,
4008 "Should have multiple entries before restore"
4009 );
4010
4011 // Restore the checkpoint to the second message.
4012 // This should:
4013 // 1. Cancel any in-progress generation (via the cancel() call)
4014 // 2. Remove the terminal that was created after that point
4015 thread
4016 .update(cx, |thread, cx| {
4017 thread.restore_checkpoint(second_message_id, cx)
4018 })
4019 .await
4020 .unwrap();
4021
4022 // Verify that no send_task is in progress after restore
4023 // (cancel() clears the send_task)
4024 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4025 assert!(
4026 !has_send_task_after,
4027 "Should not have a send_task after restore (cancel should have cleared it)"
4028 );
4029
4030 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4031 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4032 assert_eq!(
4033 entry_count, 1,
4034 "Should have 1 entry after restore (only the first user message)"
4035 );
4036
4037 // Verify the 2 completed terminals from before the checkpoint still exist
4038 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4039 thread.terminals.contains_key(&terminal_id_1)
4040 });
4041 assert!(
4042 terminal_1_exists,
4043 "Terminal 1 (from before checkpoint) should still exist"
4044 );
4045
4046 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4047 thread.terminals.contains_key(&terminal_id_2)
4048 });
4049 assert!(
4050 terminal_2_exists,
4051 "Terminal 2 (from before checkpoint) should still exist"
4052 );
4053
4054 // Verify they're still in completed state
4055 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4056 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4057 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4058 });
4059 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4060
4061 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4062 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4063 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4064 });
4065 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4066
4067 // Verify the running terminal (created after checkpoint) was removed
4068 let terminal_3_exists =
4069 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4070 assert!(
4071 !terminal_3_exists,
4072 "Terminal 3 (created after checkpoint) should have been removed"
4073 );
4074
4075 // Verify total count is 2 (the two from before the checkpoint)
4076 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4077 assert_eq!(
4078 terminal_count, 2,
4079 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4080 );
4081 }
4082
4083 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4084 /// even when a new user message is added while the async checkpoint comparison is in progress.
4085 ///
4086 /// This is a regression test for a bug where update_last_checkpoint would fail with
4087 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4088 /// update_last_checkpoint started and when its async closure ran.
4089 #[gpui::test]
4090 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4091 init_test(cx);
4092
4093 let fs = FakeFs::new(cx.executor());
4094 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4095 .await;
4096 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4097
4098 let handler_done = Arc::new(AtomicBool::new(false));
4099 let handler_done_clone = handler_done.clone();
4100 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4101 move |_, _thread, _cx| {
4102 handler_done_clone.store(true, SeqCst);
4103 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4104 },
4105 ));
4106
4107 let thread = cx
4108 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
4109 .await
4110 .unwrap();
4111
4112 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4113 let send_task = cx.background_executor.spawn(send_future);
4114
4115 // Tick until handler completes, then a few more to let update_last_checkpoint start
4116 while !handler_done.load(SeqCst) {
4117 cx.executor().tick();
4118 }
4119 for _ in 0..5 {
4120 cx.executor().tick();
4121 }
4122
4123 thread.update(cx, |thread, cx| {
4124 thread.push_entry(
4125 AgentThreadEntry::UserMessage(UserMessage {
4126 id: Some(UserMessageId::new()),
4127 content: ContentBlock::Empty,
4128 chunks: vec!["Injected message (no checkpoint)".into()],
4129 checkpoint: None,
4130 indented: false,
4131 }),
4132 cx,
4133 );
4134 });
4135
4136 cx.run_until_parked();
4137 let result = send_task.await;
4138
4139 assert!(
4140 result.is_ok(),
4141 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4142 result.err()
4143 );
4144 }
4145}