1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
/// A message authored by the user, stored as one entry of the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one has been assigned.
    pub id: Option<UserMessageId>,
    /// Display content rendered by `to_markdown`.
    /// NOTE(review): presumably built from `chunks`; confirm at construction sites.
    pub content: ContentBlock,
    /// Raw ACP content blocks that make up the message.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint captured with this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether the entry is displayed indented (read by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
49
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    /// The git store checkpoint itself (private to this module).
    git_checkpoint: GitStoreCheckpoint,
    /// When true, `UserMessage::to_markdown` labels the heading "## User (checkpoint)".
    pub show: bool,
}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
/// A message produced by the agent, made of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in the order they were received.
    /// NOTE(review): ordering is assumed from how chunks are appended elsewhere.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Whether the entry is displayed indented (read by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
/// One piece of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text shown to the user.
    Message { block: ContentBlock },
    /// Agent "thinking" output; rendered inside `<thinking>` tags by `to_markdown`.
    Thought { block: ContentBlock },
}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
/// One entry in a thread's list of entries (`AcpThread::entries`).
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
/// A tool invocation reported by the agent, in the form the UI displays.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Title rendered as Markdown. Multi-line titles are cut to their first
    /// line followed by "…".
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported by the agent.
    pub kind: acp::ToolKind,
    /// Displayable content; ACP items with no display form are dropped.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle state.
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, indexed the same way (see
    /// `AgentThreadEntry::location`). Starts empty in `from_acp`.
    /// NOTE(review): populated elsewhere; confirm it is kept index-aligned.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input sent to the tool, if reported.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input` (via `markdown_for_raw_output`).
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the tool, if reported.
    pub raw_output: Option<serde_json::Value>,
}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer.update(cx, |buffer, _| {
379 let snapshot = buffer.snapshot();
380 if let Some(row) = location.line {
381 let column = snapshot.indent_size_for_line(row).len;
382 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
383 snapshot.anchor_before(point)
384 } else {
385 Anchor::min_for_buffer(snapshot.remote_id())
386 }
387 });
388
389 Some(ResolvedLocation { buffer, position })
390 }
391
392 fn resolve_locations(
393 &self,
394 project: Entity<Project>,
395 cx: &mut App,
396 ) -> Task<Vec<Option<ResolvedLocation>>> {
397 let locations = self.locations.clone();
398 project.update(cx, |_, cx| {
399 cx.spawn(async move |project, cx| {
400 let mut new_locations = Vec::new();
401 for location in locations {
402 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
403 }
404 new_locations
405 })
406 })
407 }
408}
409
// Kept separate from `AgentLocation` so the thread can hold a strong
// reference to the buffer (e.g. for saving it). Converting into an
// `AgentLocation` downgrades the buffer to a weak handle.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer containing the location.
    buffer: Entity<Buffer>,
    /// Anchor within `buffer`.
    position: Anchor,
}
417
418impl From<&ResolvedLocation> for AgentLocation {
419 fn from(value: &ResolvedLocation) -> Self {
420 Self {
421 buffer: value.buffer.downgrade(),
422 position: value.position,
423 }
424 }
425}
426
/// Lifecycle state of a tool call as shown in the UI.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Presumably receives the id of the option the user picks.
        /// NOTE(review): confirm against the code that consumes the receiver.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
448
449impl From<acp::ToolCallStatus> for ToolCallStatus {
450 fn from(status: acp::ToolCallStatus) -> Self {
451 match status {
452 acp::ToolCallStatus::Pending => Self::Pending,
453 acp::ToolCallStatus::InProgress => Self::InProgress,
454 acp::ToolCallStatus::Completed => Self::Completed,
455 acp::ToolCallStatus::Failed => Self::Failed,
456 _ => Self::Pending,
457 }
458 }
459}
460
461impl Display for ToolCallStatus {
462 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
463 write!(
464 f,
465 "{}",
466 match self {
467 ToolCallStatus::Pending => "Pending",
468 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
469 ToolCallStatus::InProgress => "In Progress",
470 ToolCallStatus::Completed => "Completed",
471 ToolCallStatus::Failed => "Failed",
472 ToolCallStatus::Rejected => "Rejected",
473 ToolCallStatus::Canceled => "Canceled",
474 }
475 )
476 }
477}
478
/// Displayable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content; later appends are added to the same Markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link; appending anything converts the block to Markdown.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A decoded image; appending anything converts the block to Markdown.
    Image { image: Arc<gpui::Image> },
}
486
487impl ContentBlock {
488 pub fn new(
489 block: acp::ContentBlock,
490 language_registry: &Arc<LanguageRegistry>,
491 path_style: PathStyle,
492 cx: &mut App,
493 ) -> Self {
494 let mut this = Self::Empty;
495 this.append(block, language_registry, path_style, cx);
496 this
497 }
498
499 pub fn new_combined(
500 blocks: impl IntoIterator<Item = acp::ContentBlock>,
501 language_registry: Arc<LanguageRegistry>,
502 path_style: PathStyle,
503 cx: &mut App,
504 ) -> Self {
505 let mut this = Self::Empty;
506 for block in blocks {
507 this.append(block, &language_registry, path_style, cx);
508 }
509 this
510 }
511
512 pub fn append(
513 &mut self,
514 block: acp::ContentBlock,
515 language_registry: &Arc<LanguageRegistry>,
516 path_style: PathStyle,
517 cx: &mut App,
518 ) {
519 match (&mut *self, &block) {
520 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
521 *self = ContentBlock::ResourceLink {
522 resource_link: resource_link.clone(),
523 };
524 }
525 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
526 if let Some(image) = Self::decode_image(image_content) {
527 *self = ContentBlock::Image { image };
528 } else {
529 let new_content = Self::image_md(image_content);
530 *self = Self::create_markdown_block(new_content, language_registry, cx);
531 }
532 }
533 (ContentBlock::Empty, _) => {
534 let new_content = Self::block_string_contents(&block, path_style);
535 *self = Self::create_markdown_block(new_content, language_registry, cx);
536 }
537 (ContentBlock::Markdown { markdown }, _) => {
538 let new_content = Self::block_string_contents(&block, path_style);
539 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
540 }
541 (ContentBlock::ResourceLink { resource_link }, _) => {
542 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
543 let new_content = Self::block_string_contents(&block, path_style);
544 let combined = format!("{}\n{}", existing_content, new_content);
545 *self = Self::create_markdown_block(combined, language_registry, cx);
546 }
547 (ContentBlock::Image { .. }, _) => {
548 let new_content = Self::block_string_contents(&block, path_style);
549 let combined = format!("`Image`\n{}", new_content);
550 *self = Self::create_markdown_block(combined, language_registry, cx);
551 }
552 }
553 }
554
555 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
556 use base64::Engine as _;
557
558 let bytes = base64::engine::general_purpose::STANDARD
559 .decode(image_content.data.as_bytes())
560 .ok()?;
561 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
562 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
563 }
564
565 fn create_markdown_block(
566 content: String,
567 language_registry: &Arc<LanguageRegistry>,
568 cx: &mut App,
569 ) -> ContentBlock {
570 ContentBlock::Markdown {
571 markdown: cx
572 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
573 }
574 }
575
576 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
577 match block {
578 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
579 acp::ContentBlock::ResourceLink(resource_link) => {
580 Self::resource_link_md(&resource_link.uri, path_style)
581 }
582 acp::ContentBlock::Resource(acp::EmbeddedResource {
583 resource:
584 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
585 uri,
586 ..
587 }),
588 ..
589 }) => Self::resource_link_md(uri, path_style),
590 acp::ContentBlock::Image(image) => Self::image_md(image),
591 _ => String::new(),
592 }
593 }
594
595 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
596 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
597 uri.as_link().to_string()
598 } else {
599 uri.to_string()
600 }
601 }
602
603 fn image_md(_image: &acp::ImageContent) -> String {
604 "`Image`".into()
605 }
606
607 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
608 match self {
609 ContentBlock::Empty => "",
610 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
611 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
612 ContentBlock::Image { .. } => "`Image`",
613 }
614 }
615
616 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
617 match self {
618 ContentBlock::Empty => None,
619 ContentBlock::Markdown { markdown } => Some(markdown),
620 ContentBlock::ResourceLink { .. } => None,
621 ContentBlock::Image { .. } => None,
622 }
623 }
624
625 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
626 match self {
627 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
628 _ => None,
629 }
630 }
631
632 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
633 match self {
634 ContentBlock::Image { image } => Some(image),
635 _ => None,
636 }
637 }
638}
639
/// One displayable item in a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, links, or images.
    ContentBlock(ContentBlock),
    /// A file diff, built via `Diff::finalized` in `from_acp`.
    Diff(Entity<Diff>),
    /// A terminal previously registered with the thread, looked up by id.
    Terminal(Entity<Terminal>),
}
646
647impl ToolCallContent {
648 pub fn from_acp(
649 content: acp::ToolCallContent,
650 language_registry: Arc<LanguageRegistry>,
651 path_style: PathStyle,
652 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
653 cx: &mut App,
654 ) -> Result<Option<Self>> {
655 match content {
656 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
657 Ok(Some(Self::ContentBlock(ContentBlock::new(
658 content,
659 &language_registry,
660 path_style,
661 cx,
662 ))))
663 }
664 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
665 Diff::finalized(
666 diff.path.to_string_lossy().into_owned(),
667 diff.old_text,
668 diff.new_text,
669 language_registry,
670 cx,
671 )
672 })))),
673 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
674 .get(&terminal_id)
675 .cloned()
676 .map(|terminal| Some(Self::Terminal(terminal)))
677 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
678 _ => Ok(None),
679 }
680 }
681
682 pub fn update_from_acp(
683 &mut self,
684 new: acp::ToolCallContent,
685 language_registry: Arc<LanguageRegistry>,
686 path_style: PathStyle,
687 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
688 cx: &mut App,
689 ) -> Result<bool> {
690 let needs_update = match (&self, &new) {
691 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
692 old_diff.read(cx).needs_update(
693 new_diff.old_text.as_deref().unwrap_or(""),
694 &new_diff.new_text,
695 cx,
696 )
697 }
698 _ => true,
699 };
700
701 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
702 if needs_update {
703 *self = update;
704 }
705 Ok(true)
706 } else {
707 Ok(false)
708 }
709 }
710
711 pub fn to_markdown(&self, cx: &App) -> String {
712 match self {
713 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
714 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
715 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
716 }
717 }
718
719 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
720 match self {
721 Self::ContentBlock(content) => content.image(),
722 _ => None,
723 }
724 }
725}
726
/// An update targeting an existing tool call (identified by `ToolCallUpdate::id`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field updates received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Attaches or updates a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// Attaches or updates a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
733
734impl ToolCallUpdate {
735 fn id(&self) -> &acp::ToolCallId {
736 match self {
737 Self::UpdateFields(update) => &update.tool_call_id,
738 Self::UpdateDiff(diff) => &diff.id,
739 Self::UpdateTerminal(terminal) => &terminal.id,
740 }
741 }
742}
743
744impl From<acp::ToolCallUpdate> for ToolCallUpdate {
745 fn from(update: acp::ToolCallUpdate) -> Self {
746 Self::UpdateFields(update)
747 }
748}
749
750impl From<ToolCallUpdateDiff> for ToolCallUpdate {
751 fn from(diff: ToolCallUpdateDiff) -> Self {
752 Self::UpdateDiff(diff)
753 }
754}
755
/// A diff entity to attach to the tool call with id `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// Diff entity to attach.
    pub diff: Entity<Diff>,
}
761
762impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
763 fn from(terminal: ToolCallUpdateTerminal) -> Self {
764 Self::UpdateTerminal(terminal)
765 }
766}
767
/// A terminal entity to attach to the tool call with id `id`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// Terminal entity to attach.
    pub terminal: Entity<Terminal>,
}
773
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in display order.
    pub entries: Vec<PlanEntry>,
}
778
/// Summary counts computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
785
786impl Plan {
787 pub fn is_empty(&self) -> bool {
788 self.entries.is_empty()
789 }
790
791 pub fn stats(&self) -> PlanStats<'_> {
792 let mut stats = PlanStats {
793 in_progress_entry: None,
794 pending: 0,
795 completed: 0,
796 };
797
798 for entry in &self.entries {
799 match &entry.status {
800 acp::PlanEntryStatus::Pending => {
801 stats.pending += 1;
802 }
803 acp::PlanEntryStatus::InProgress => {
804 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
805 }
806 acp::PlanEntryStatus::Completed => {
807 stats.completed += 1;
808 }
809 _ => {}
810 }
811 }
812
813 stats
814 }
815}
816
/// One step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description rendered as Markdown (no language registry).
    pub content: Entity<Markdown>,
    /// Priority reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Status reported by the agent.
    pub status: acp::PlanEntryStatus,
}
823
824impl PlanEntry {
825 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
826 Self {
827 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
828 priority: entry.priority,
829 status: entry.status,
830 }
831 }
832}
833
/// Token accounting for the thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `0` means unknown (e.g. no model selected), see `ratio`.
    pub max_tokens: u64,
    /// Tokens used so far, compared against `max_tokens` by `ratio`.
    pub used_tokens: u64,
    /// Output tokens. NOTE(review): not read in this module; confirm semantics.
    pub output_tokens: u64,
}
840
841impl TokenUsage {
842 pub fn ratio(&self) -> TokenUsageRatio {
843 #[cfg(debug_assertions)]
844 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
845 .unwrap_or("0.8".to_string())
846 .parse()
847 .unwrap();
848 #[cfg(not(debug_assertions))]
849 let warning_threshold: f32 = 0.8;
850
851 // When the maximum is unknown because there is no selected model,
852 // avoid showing the token limit warning.
853 if self.max_tokens == 0 {
854 TokenUsageRatio::Normal
855 } else if self.used_tokens >= self.max_tokens {
856 TokenUsageRatio::Exceeded
857 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
858 TokenUsageRatio::Warning
859 } else {
860 TokenUsageRatio::Normal
861 }
862 }
863}
864
/// Coarse classification returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but under the limit.
    Warning,
    /// Used tokens reached or passed the limit.
    Exceeded,
}
871
/// Retry progress carried by `AcpThreadEvent::Retry`.
/// NOTE(review): field semantics below are inferred from names; confirm at the emit site.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    /// Presumably the delay before the retry — TODO confirm.
    pub duration: Duration,
}
880
/// State of one agent session: its entries, plan, terminals, and the
/// connection it talks to.
pub struct AcpThread {
    /// Thread title.
    title: SharedString,
    /// Entries of the thread (messages and tool calls).
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Buffers paired with a snapshot. NOTE(review): usage is outside this view; confirm purpose.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Outstanding send task; its presence makes `status()` report `Generating`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task applying prompt-capability updates (spawned in `new`).
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Registered terminals by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals not yet registered; replayed on `Created`.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals not yet registered; consumed on `Created`.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
898
899impl From<&AcpThread> for ActionLogTelemetry {
900 fn from(value: &AcpThread) -> Self {
901 Self {
902 agent_telemetry_id: value.connection().telemetry_id(),
903 session_id: value.session_id.0.clone(),
904 }
905 }
906}
907
/// Events emitted by `AcpThread`.
/// NOTE(review): many variants are emitted outside the visible code; only the
/// ones documented below are confirmed here.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is updated from its watch channel.
    PromptCapabilitiesUpdated,
    Refusal,
    /// Forwarded from an ACP `AvailableCommandsUpdate`.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Forwarded from an ACP `CurrentModeUpdate`.
    ModeUpdated(acp::SessionModeId),
    /// Forwarded from an ACP `ConfigOptionUpdate`.
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}
926
/// Lets `AcpThread` entities emit `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
928
/// Events from a terminal provider, handled by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; it is registered with the thread.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes; buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New title; applied as the terminal's breadcrumb text (dropped if unregistered).
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited; buffered if the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
951
/// Commands addressed to a terminal provider.
/// NOTE(review): not used in the visible code; semantics inferred from names — confirm.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Presumably writes input bytes to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Presumably resizes the terminal grid.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Presumably closes the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
967
968impl AcpThread {
969 pub fn on_terminal_provider_event(
970 &mut self,
971 event: TerminalProviderEvent,
972 cx: &mut Context<Self>,
973 ) {
974 match event {
975 TerminalProviderEvent::Created {
976 terminal_id,
977 label,
978 cwd,
979 output_byte_limit,
980 terminal,
981 } => {
982 let entity = self.register_terminal_created(
983 terminal_id.clone(),
984 label,
985 cwd,
986 output_byte_limit,
987 terminal,
988 cx,
989 );
990
991 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
992 for data in chunks.drain(..) {
993 entity.update(cx, |term, cx| {
994 term.inner().update(cx, |inner, cx| {
995 inner.write_output(&data, cx);
996 })
997 });
998 }
999 }
1000
1001 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1002 entity.update(cx, |_term, cx| {
1003 cx.notify();
1004 });
1005 }
1006
1007 cx.notify();
1008 }
1009 TerminalProviderEvent::Output { terminal_id, data } => {
1010 if let Some(entity) = self.terminals.get(&terminal_id) {
1011 entity.update(cx, |term, cx| {
1012 term.inner().update(cx, |inner, cx| {
1013 inner.write_output(&data, cx);
1014 })
1015 });
1016 } else {
1017 self.pending_terminal_output
1018 .entry(terminal_id)
1019 .or_default()
1020 .push(data);
1021 }
1022 }
1023 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1024 if let Some(entity) = self.terminals.get(&terminal_id) {
1025 entity.update(cx, |term, cx| {
1026 term.inner().update(cx, |inner, cx| {
1027 inner.breadcrumb_text = title;
1028 cx.emit(::terminal::Event::BreadcrumbsChanged);
1029 })
1030 });
1031 }
1032 }
1033 TerminalProviderEvent::Exit {
1034 terminal_id,
1035 status,
1036 } => {
1037 if let Some(entity) = self.terminals.get(&terminal_id) {
1038 entity.update(cx, |_term, cx| {
1039 cx.notify();
1040 });
1041 } else {
1042 self.pending_terminal_exit.insert(terminal_id, status);
1043 }
1044 }
1045 }
1046 }
1047}
1048
/// Whether the thread currently has a send task outstanding (see `AcpThread::status`).
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task.
    Idle,
    /// A send task is outstanding.
    Generating,
}
1054
/// Reasons an agent could not be loaded; `Display` renders a user-facing message.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed version is older than the minimum required.
    Unsupported {
        /// Command (path) of the agent.
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with the given message.
    FailedToInstall(SharedString),
    /// The server process exited with this status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, with its message.
    Other(SharedString),
}
1068
1069impl Display for LoadError {
1070 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1071 match self {
1072 LoadError::Unsupported {
1073 command: path,
1074 current_version,
1075 minimum_version,
1076 } => {
1077 write!(
1078 f,
1079 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1080 )
1081 }
1082 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1083 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1084 LoadError::Other(msg) => write!(f, "{msg}"),
1085 }
1086 }
1087}
1088
/// Lets `LoadError` be used as a standard error; its message comes from `Display`.
impl Error for LoadError {}
1090
1091impl AcpThread {
1092 pub fn new(
1093 title: impl Into<SharedString>,
1094 connection: Rc<dyn AgentConnection>,
1095 project: Entity<Project>,
1096 action_log: Entity<ActionLog>,
1097 session_id: acp::SessionId,
1098 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1099 cx: &mut Context<Self>,
1100 ) -> Self {
1101 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1102 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1103 loop {
1104 let caps = prompt_capabilities_rx.recv().await?;
1105 this.update(cx, |this, cx| {
1106 this.prompt_capabilities = caps;
1107 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1108 })?;
1109 }
1110 });
1111
1112 Self {
1113 action_log,
1114 shared_buffers: Default::default(),
1115 entries: Default::default(),
1116 plan: Default::default(),
1117 title: title.into(),
1118 project,
1119 send_task: None,
1120 connection,
1121 session_id,
1122 token_usage: None,
1123 prompt_capabilities,
1124 _observe_prompt_capabilities: task,
1125 terminals: HashMap::default(),
1126 pending_terminal_output: HashMap::default(),
1127 pending_terminal_exit: HashMap::default(),
1128 }
1129 }
1130
1131 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1132 self.prompt_capabilities.clone()
1133 }
1134
1135 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1136 &self.connection
1137 }
1138
1139 pub fn action_log(&self) -> &Entity<ActionLog> {
1140 &self.action_log
1141 }
1142
1143 pub fn project(&self) -> &Entity<Project> {
1144 &self.project
1145 }
1146
1147 pub fn title(&self) -> SharedString {
1148 self.title.clone()
1149 }
1150
1151 pub fn entries(&self) -> &[AgentThreadEntry] {
1152 &self.entries
1153 }
1154
1155 pub fn session_id(&self) -> &acp::SessionId {
1156 &self.session_id
1157 }
1158
1159 pub fn status(&self) -> ThreadStatus {
1160 if self.send_task.is_some() {
1161 ThreadStatus::Generating
1162 } else {
1163 ThreadStatus::Idle
1164 }
1165 }
1166
1167 pub fn token_usage(&self) -> Option<&TokenUsage> {
1168 self.token_usage.as_ref()
1169 }
1170
1171 pub fn has_pending_edit_tool_calls(&self) -> bool {
1172 for entry in self.entries.iter().rev() {
1173 match entry {
1174 AgentThreadEntry::UserMessage(_) => return false,
1175 AgentThreadEntry::ToolCall(
1176 call @ ToolCall {
1177 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1178 ..
1179 },
1180 ) if call.diffs().next().is_some() => {
1181 return true;
1182 }
1183 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1184 }
1185 }
1186
1187 false
1188 }
1189
1190 pub fn has_in_progress_tool_calls(&self) -> bool {
1191 for entry in self.entries.iter().rev() {
1192 match entry {
1193 AgentThreadEntry::UserMessage(_) => return false,
1194 AgentThreadEntry::ToolCall(ToolCall {
1195 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1196 ..
1197 }) => {
1198 return true;
1199 }
1200 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1201 }
1202 }
1203
1204 false
1205 }
1206
1207 pub fn used_tools_since_last_user_message(&self) -> bool {
1208 for entry in self.entries.iter().rev() {
1209 match entry {
1210 AgentThreadEntry::UserMessage(..) => return false,
1211 AgentThreadEntry::AssistantMessage(..) => continue,
1212 AgentThreadEntry::ToolCall(..) => return true,
1213 }
1214 }
1215
1216 false
1217 }
1218
1219 pub fn handle_session_update(
1220 &mut self,
1221 update: acp::SessionUpdate,
1222 cx: &mut Context<Self>,
1223 ) -> Result<(), acp::Error> {
1224 match update {
1225 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1226 self.push_user_content_block(None, content, cx);
1227 }
1228 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1229 self.push_assistant_content_block(content, false, cx);
1230 }
1231 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1232 self.push_assistant_content_block(content, true, cx);
1233 }
1234 acp::SessionUpdate::ToolCall(tool_call) => {
1235 self.upsert_tool_call(tool_call, cx)?;
1236 }
1237 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1238 self.update_tool_call(tool_call_update, cx)?;
1239 }
1240 acp::SessionUpdate::Plan(plan) => {
1241 self.update_plan(plan, cx);
1242 }
1243 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1244 available_commands,
1245 ..
1246 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1247 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1248 current_mode_id,
1249 ..
1250 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1251 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1252 config_options,
1253 ..
1254 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1255 _ => {}
1256 }
1257 Ok(())
1258 }
1259
/// Appends a user content `chunk` at the top (non-indented) level.
///
/// Delegates to `push_user_content_block_with_indent` with `indented = false`.
pub fn push_user_content_block(
    &mut self,
    message_id: Option<UserMessageId>,
    chunk: acp::ContentBlock,
    cx: &mut Context<Self>,
) {
    self.push_user_content_block_with_indent(message_id, chunk, false, cx)
}
1268
1269 pub fn push_user_content_block_with_indent(
1270 &mut self,
1271 message_id: Option<UserMessageId>,
1272 chunk: acp::ContentBlock,
1273 indented: bool,
1274 cx: &mut Context<Self>,
1275 ) {
1276 let language_registry = self.project.read(cx).languages().clone();
1277 let path_style = self.project.read(cx).path_style(cx);
1278 let entries_len = self.entries.len();
1279
1280 if let Some(last_entry) = self.entries.last_mut()
1281 && let AgentThreadEntry::UserMessage(UserMessage {
1282 id,
1283 content,
1284 chunks,
1285 indented: existing_indented,
1286 ..
1287 }) = last_entry
1288 && *existing_indented == indented
1289 {
1290 *id = message_id.or(id.take());
1291 content.append(chunk.clone(), &language_registry, path_style, cx);
1292 chunks.push(chunk);
1293 let idx = entries_len - 1;
1294 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1295 } else {
1296 let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
1297 self.push_entry(
1298 AgentThreadEntry::UserMessage(UserMessage {
1299 id: message_id,
1300 content,
1301 chunks: vec![chunk],
1302 checkpoint: None,
1303 indented,
1304 }),
1305 cx,
1306 );
1307 }
1308 }
1309
/// Appends an assistant content `chunk` at the top (non-indented) level.
///
/// `is_thought` selects whether the chunk is stored as a thought or as message text.
/// Delegates to `push_assistant_content_block_with_indent` with `indented = false`.
pub fn push_assistant_content_block(
    &mut self,
    chunk: acp::ContentBlock,
    is_thought: bool,
    cx: &mut Context<Self>,
) {
    self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
}
1318
1319 pub fn push_assistant_content_block_with_indent(
1320 &mut self,
1321 chunk: acp::ContentBlock,
1322 is_thought: bool,
1323 indented: bool,
1324 cx: &mut Context<Self>,
1325 ) {
1326 let language_registry = self.project.read(cx).languages().clone();
1327 let path_style = self.project.read(cx).path_style(cx);
1328 let entries_len = self.entries.len();
1329 if let Some(last_entry) = self.entries.last_mut()
1330 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1331 chunks,
1332 indented: existing_indented,
1333 }) = last_entry
1334 && *existing_indented == indented
1335 {
1336 let idx = entries_len - 1;
1337 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1338 match (chunks.last_mut(), is_thought) {
1339 (Some(AssistantMessageChunk::Message { block }), false)
1340 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1341 block.append(chunk, &language_registry, path_style, cx)
1342 }
1343 _ => {
1344 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1345 if is_thought {
1346 chunks.push(AssistantMessageChunk::Thought { block })
1347 } else {
1348 chunks.push(AssistantMessageChunk::Message { block })
1349 }
1350 }
1351 }
1352 } else {
1353 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1354 let chunk = if is_thought {
1355 AssistantMessageChunk::Thought { block }
1356 } else {
1357 AssistantMessageChunk::Message { block }
1358 };
1359
1360 self.push_entry(
1361 AgentThreadEntry::AssistantMessage(AssistantMessage {
1362 chunks: vec![chunk],
1363 indented,
1364 }),
1365 cx,
1366 );
1367 }
1368 }
1369
/// Appends `entry` to the end of the thread and emits `NewEntry`.
fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
    self.entries.push(entry);
    cx.emit(AcpThreadEvent::NewEntry);
}
1374
1375 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1376 self.connection.set_title(&self.session_id, cx).is_some()
1377 }
1378
1379 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1380 if title != self.title {
1381 self.title = title.clone();
1382 cx.emit(AcpThreadEvent::TitleUpdated);
1383 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1384 return set_title.run(title, cx);
1385 }
1386 }
1387 Task::ready(Ok(()))
1388 }
1389
/// Replaces the stored token usage (possibly with `None`) and emits `TokenUsageUpdated`.
pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
    self.token_usage = usage;
    cx.emit(AcpThreadEvent::TokenUsageUpdated);
}
1394
/// Forwards a retry `status` to subscribers as a `Retry` event.
///
/// No thread state is changed.
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
    cx.emit(AcpThreadEvent::Retry(status));
}
1398
1399 pub fn update_tool_call(
1400 &mut self,
1401 update: impl Into<ToolCallUpdate>,
1402 cx: &mut Context<Self>,
1403 ) -> Result<()> {
1404 let update = update.into();
1405 let languages = self.project.read(cx).languages().clone();
1406 let path_style = self.project.read(cx).path_style(cx);
1407
1408 let ix = match self.index_for_tool_call(update.id()) {
1409 Some(ix) => ix,
1410 None => {
1411 // Tool call not found - create a failed tool call entry
1412 let failed_tool_call = ToolCall {
1413 id: update.id().clone(),
1414 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1415 kind: acp::ToolKind::Fetch,
1416 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1417 "Tool call not found".into(),
1418 &languages,
1419 path_style,
1420 cx,
1421 ))],
1422 status: ToolCallStatus::Failed,
1423 locations: Vec::new(),
1424 resolved_locations: Vec::new(),
1425 raw_input: None,
1426 raw_input_markdown: None,
1427 raw_output: None,
1428 };
1429 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1430 return Ok(());
1431 }
1432 };
1433 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1434 unreachable!()
1435 };
1436
1437 match update {
1438 ToolCallUpdate::UpdateFields(update) => {
1439 let location_updated = update.fields.locations.is_some();
1440 call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
1441 if location_updated {
1442 self.resolve_locations(update.tool_call_id, cx);
1443 }
1444 }
1445 ToolCallUpdate::UpdateDiff(update) => {
1446 call.content.clear();
1447 call.content.push(ToolCallContent::Diff(update.diff));
1448 }
1449 ToolCallUpdate::UpdateTerminal(update) => {
1450 call.content.clear();
1451 call.content
1452 .push(ToolCallContent::Terminal(update.terminal));
1453 }
1454 }
1455
1456 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1457
1458 Ok(())
1459 }
1460
1461 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1462 pub fn upsert_tool_call(
1463 &mut self,
1464 tool_call: acp::ToolCall,
1465 cx: &mut Context<Self>,
1466 ) -> Result<(), acp::Error> {
1467 let status = tool_call.status.into();
1468 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1469 }
1470
1471 /// Fails if id does not match an existing entry.
1472 pub fn upsert_tool_call_inner(
1473 &mut self,
1474 update: acp::ToolCallUpdate,
1475 status: ToolCallStatus,
1476 cx: &mut Context<Self>,
1477 ) -> Result<(), acp::Error> {
1478 let language_registry = self.project.read(cx).languages().clone();
1479 let path_style = self.project.read(cx).path_style(cx);
1480 let id = update.tool_call_id.clone();
1481
1482 let agent_telemetry_id = self.connection().telemetry_id();
1483 let session = self.session_id();
1484 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1485 let status = if matches!(status, ToolCallStatus::Completed) {
1486 "completed"
1487 } else {
1488 "failed"
1489 };
1490 telemetry::event!(
1491 "Agent Tool Call Completed",
1492 agent_telemetry_id,
1493 session,
1494 status
1495 );
1496 }
1497
1498 if let Some(ix) = self.index_for_tool_call(&id) {
1499 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1500 unreachable!()
1501 };
1502
1503 call.update_fields(
1504 update.fields,
1505 language_registry,
1506 path_style,
1507 &self.terminals,
1508 cx,
1509 )?;
1510 call.status = status;
1511
1512 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1513 } else {
1514 let call = ToolCall::from_acp(
1515 update.try_into()?,
1516 status,
1517 language_registry,
1518 self.project.read(cx).path_style(cx),
1519 &self.terminals,
1520 cx,
1521 )?;
1522 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1523 };
1524
1525 self.resolve_locations(id, cx);
1526 Ok(())
1527 }
1528
1529 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1530 self.entries
1531 .iter()
1532 .enumerate()
1533 .rev()
1534 .find_map(|(index, entry)| {
1535 if let AgentThreadEntry::ToolCall(tool_call) = entry
1536 && &tool_call.id == id
1537 {
1538 Some(index)
1539 } else {
1540 None
1541 }
1542 })
1543 }
1544
1545 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1546 // The tool call we are looking for is typically the last one, or very close to the end.
1547 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1548 self.entries
1549 .iter_mut()
1550 .enumerate()
1551 .rev()
1552 .find_map(|(index, tool_call)| {
1553 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1554 && &tool_call.id == id
1555 {
1556 Some((index, tool_call))
1557 } else {
1558 None
1559 }
1560 })
1561 }
1562
1563 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1564 self.entries
1565 .iter()
1566 .enumerate()
1567 .rev()
1568 .find_map(|(index, tool_call)| {
1569 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1570 && &tool_call.id == id
1571 {
1572 Some((index, tool_call))
1573 } else {
1574 None
1575 }
1576 })
1577 }
1578
/// Resolves the locations reported by the tool call identified by `id`.
///
/// Does nothing if no tool call has `id`. Otherwise it spawns a detached task
/// that, once resolution finishes:
/// - caches a snapshot of every resolved buffer in `shared_buffers`;
/// - moves the project's agent location to the last resolved location, unless
///   that would only move it backwards within the same row of the same buffer;
/// - stores the resolved locations on the tool call and emits `EntryUpdated`
///   if they changed.
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
    let project = self.project.clone();
    let Some((_, tool_call)) = self.tool_call_mut(&id) else {
        return;
    };
    let task = tool_call.resolve_locations(project, cx);
    cx.spawn(async move |this, cx| {
        let resolved_locations = task.await;

        this.update(cx, |this, cx| {
            let project = this.project.clone();

            // Cache snapshots; `read_text_file` / `write_text_file` consult `shared_buffers`.
            for location in resolved_locations.iter().flatten() {
                this.shared_buffers
                    .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
            }
            // The tool call may have been removed (e.g. by a rewind) while resolving.
            let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                return;
            };

            if let Some(Some(location)) = resolved_locations.last() {
                project.update(cx, |project, cx| {
                    let should_ignore = if let Some(agent_location) = project
                        .agent_location()
                        .filter(|agent_location| agent_location.buffer == location.buffer)
                    {
                        let snapshot = location.buffer.read(cx).snapshot();
                        let old_position = agent_location.position.to_point(&snapshot);
                        let new_position = location.position.to_point(&snapshot);

                        // ignore this so that when we get updates from the edit tool
                        // the position doesn't reset to the start of line
                        old_position.row == new_position.row
                            && old_position.column > new_position.column
                    } else {
                        false
                    };
                    if !should_ignore {
                        project.set_agent_location(Some(location.into()), cx);
                    }
                });
            }

            let resolved_locations = resolved_locations
                .iter()
                .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                .collect::<Vec<_>>();

            // Only notify observers when the resolved locations actually changed.
            if tool_call.resolved_locations != resolved_locations {
                tool_call.resolved_locations = resolved_locations;
                cx.emit(AcpThreadEvent::EntryUpdated(ix));
            }
        })
    })
    .detach();
}
1635
1636 pub fn request_tool_call_authorization(
1637 &mut self,
1638 tool_call: acp::ToolCallUpdate,
1639 options: Vec<acp::PermissionOption>,
1640 respect_always_allow_setting: bool,
1641 cx: &mut Context<Self>,
1642 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1643 let (tx, rx) = oneshot::channel();
1644
1645 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1646 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1647 // some tools would (incorrectly) continue to auto-accept.
1648 if let Some(allow_once_option) = options.iter().find_map(|option| {
1649 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1650 Some(option.option_id.clone())
1651 } else {
1652 None
1653 }
1654 }) {
1655 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1656 return Ok(async {
1657 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1658 allow_once_option,
1659 ))
1660 }
1661 .boxed());
1662 }
1663 }
1664
1665 let status = ToolCallStatus::WaitingForConfirmation {
1666 options,
1667 respond_tx: tx,
1668 };
1669
1670 self.upsert_tool_call_inner(tool_call, status, cx)?;
1671 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1672
1673 let fut = async {
1674 match rx.await {
1675 Ok(option) => acp::RequestPermissionOutcome::Selected(
1676 acp::SelectedPermissionOutcome::new(option),
1677 ),
1678 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1679 }
1680 }
1681 .boxed();
1682
1683 Ok(fut)
1684 }
1685
1686 pub fn authorize_tool_call(
1687 &mut self,
1688 id: acp::ToolCallId,
1689 option_id: acp::PermissionOptionId,
1690 option_kind: acp::PermissionOptionKind,
1691 cx: &mut Context<Self>,
1692 ) {
1693 let Some((ix, call)) = self.tool_call_mut(&id) else {
1694 return;
1695 };
1696
1697 let new_status = match option_kind {
1698 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1699 ToolCallStatus::Rejected
1700 }
1701 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1702 ToolCallStatus::InProgress
1703 }
1704 _ => ToolCallStatus::InProgress,
1705 };
1706
1707 let curr_status = mem::replace(&mut call.status, new_status);
1708
1709 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1710 respond_tx.send(option_id).log_err();
1711 } else if cfg!(debug_assertions) {
1712 panic!("tried to authorize an already authorized tool call");
1713 }
1714
1715 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1716 }
1717
1718 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1719 let mut first_tool_call = None;
1720
1721 for entry in self.entries.iter().rev() {
1722 match &entry {
1723 AgentThreadEntry::ToolCall(call) => {
1724 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1725 first_tool_call = Some(call);
1726 } else {
1727 continue;
1728 }
1729 }
1730 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1731 // Reached the beginning of the turn.
1732 // If we had pending permission requests in the previous turn, they have been cancelled.
1733 break;
1734 }
1735 }
1736 }
1737
1738 first_tool_call
1739 }
1740
/// Returns the agent's current plan.
pub fn plan(&self) -> &Plan {
    &self.plan
}
1744
1745 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1746 let new_entries_len = request.entries.len();
1747 let mut new_entries = request.entries.into_iter();
1748
1749 // Reuse existing markdown to prevent flickering
1750 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1751 let PlanEntry {
1752 content,
1753 priority,
1754 status,
1755 } = old;
1756 content.update(cx, |old, cx| {
1757 old.replace(new.content, cx);
1758 });
1759 *priority = new.priority;
1760 *status = new.status;
1761 }
1762 for new in new_entries {
1763 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1764 }
1765 self.plan.entries.truncate(new_entries_len);
1766
1767 cx.notify();
1768 }
1769
/// Removes every plan entry whose status is `Completed`, then notifies observers.
fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
    self.plan
        .entries
        .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
    cx.notify();
}
1776
/// Test helper: sends `message` as a single-block user prompt via `send`.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    self.send(vec![message.into()], cx)
}
1785
/// Sends a user prompt made of `message` blocks to the agent as a new turn.
///
/// The turn runs via `run_turn`, which first cancels any in-flight turn. Within
/// the turn:
/// 1. The user message is pushed as a new entry. It gets a fresh id only when
///    the connection supports truncation.
/// 2. A git checkpoint is captured and attached to the last user message.
///    Checkpoint failures are logged and leave it unset.
/// 3. The prompt is forwarded to the connection.
pub fn send(
    &mut self,
    message: Vec<acp::ContentBlock>,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let block = ContentBlock::new_combined(
        message.clone(),
        self.project.read(cx).languages().clone(),
        self.project.read(cx).path_style(cx),
        cx,
    );
    let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
    let git_store = self.project.read(cx).git_store().clone();

    // Ids are only needed (for rewinding) when the connection can truncate.
    let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
        Some(UserMessageId::new())
    } else {
        None
    };

    self.run_turn(cx, async move |this, cx| {
        this.update(cx, |this, cx| {
            this.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id.clone(),
                    content: block,
                    chunks: message,
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        })
        .ok();

        let old_checkpoint = git_store
            .update(cx, |git, cx| git.checkpoint(cx))
            .await
            .context("failed to get old checkpoint")
            .log_err();
        this.update(cx, |this, cx| {
            // Hidden until `update_last_checkpoint` finds the tree actually changed.
            if let Some((_ix, message)) = this.last_user_message() {
                message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                    git_checkpoint,
                    show: false,
                });
            }
            this.connection.prompt(message_id, request, cx)
        })?
        .await
    })
}
1838
/// Returns whether the connection supports resuming this session.
pub fn can_resume(&self, cx: &App) -> bool {
    self.connection.resume(&self.session_id, cx).is_some()
}
1842
/// Starts a new turn that asks the connection to resume this session.
///
/// The turn fails with "resuming a session is not supported" when the
/// connection provides no resume support.
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
    self.run_turn(cx, async move |this, cx| {
        this.update(cx, |this, cx| {
            this.connection
                .resume(&this.session_id, cx)
                .map(|resume| resume.run(cx))
        })?
        .context("resuming a session is not supported")?
        .await
    })
}
1854
/// Runs `f` as the thread's current turn, cancelling any turn already in progress.
///
/// Completed plan entries are cleared up front. The turn's work is stored in
/// `send_task`. The returned future resolves when the turn ends, after the last
/// user message's checkpoint has been re-evaluated and the agent location cleared.
///
/// - If `f` returns an error, `Error` is emitted and the error is returned.
/// - Otherwise, a `Refusal` stop reason is handled. If a completed tool call
///   with output follows the last user message, the refusal is only reported.
///   Otherwise the thread is truncated back to before that user message and
///   the refusal is reported.
/// - `Stopped` is then emitted.
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
    f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
    self.clear_completed_plan_entries(cx);

    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);

    // The turn itself: wait for the previous turn to wind down, then run `f`.
    self.send_task = Some(cx.spawn(async move |this, cx| {
        cancel_task.await;
        tx.send(f(this, cx).await).ok();
    }));

    cx.spawn(async move |this, cx| {
        // `Err(Canceled)` here means the send task was dropped before reporting.
        let response = rx.await;

        this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
            .await?;

        this.update(cx, |this, cx| {
            this.project
                .update(cx, |project, cx| project.set_agent_location(None, cx));
            match response {
                Ok(Err(e)) => {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                result => {
                    let canceled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Cancelled,
                            ..
                        }))
                    );

                    // We only take the task if the current prompt wasn't canceled.
                    //
                    // This prompt may have been canceled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be canceled.
                    if !canceled {
                        this.send_task.take();
                    }

                    // Handle refusal - distinguish between user prompt and tool call refusals
                    if let Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Refusal,
                        ..
                    })) = result
                    {
                        if let Some((user_msg_ix, _)) = this.last_user_message() {
                            // Check if there's a completed tool call with results after the last user message
                            // This indicates the refusal is in response to tool output, not the user's prompt
                            let has_completed_tool_call_after_user_msg =
                                this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                    if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                        // Check if the tool call has completed and has output
                                        matches!(tool_call.status, ToolCallStatus::Completed)
                                            && tool_call.raw_output.is_some()
                                    } else {
                                        false
                                    }
                                });

                            if has_completed_tool_call_after_user_msg {
                                // Refusal is due to tool output - don't truncate, just notify
                                // The model refused based on what the tool returned
                                cx.emit(AcpThreadEvent::Refusal);
                            } else {
                                // User prompt was refused - truncate back to before the user message
                                let range = user_msg_ix..this.entries.len();
                                if range.start < range.end {
                                    this.entries.truncate(user_msg_ix);
                                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                }
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        } else {
                            // No user message found, treat as general refusal
                            cx.emit(AcpThreadEvent::Refusal);
                        }
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1950
1951 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1952 let Some(send_task) = self.send_task.take() else {
1953 return Task::ready(());
1954 };
1955
1956 for entry in self.entries.iter_mut() {
1957 if let AgentThreadEntry::ToolCall(call) = entry {
1958 let cancel = matches!(
1959 call.status,
1960 ToolCallStatus::Pending
1961 | ToolCallStatus::WaitingForConfirmation { .. }
1962 | ToolCallStatus::InProgress
1963 );
1964
1965 if cancel {
1966 call.status = ToolCallStatus::Canceled;
1967 }
1968 }
1969 }
1970
1971 self.connection.cancel(&self.session_id, cx);
1972
1973 // Wait for the send task to complete
1974 cx.foreground_executor().spawn(send_task)
1975 }
1976
/// Restores the git working tree to the checkpoint of user message `id` (if one exists).
///
/// Any in-progress generation is cancelled first. The thread is then rewound
/// to before that message (see `rewind`), and finally the message's git
/// checkpoint, if recorded, is restored.
///
/// Fails immediately with "message not found" if no user message has `id`.
pub fn restore_checkpoint(
    &mut self,
    id: UserMessageId,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let Some((_, message)) = self.user_message_mut(&id) else {
        return Task::ready(Err(anyhow!("message not found")));
    };

    // Captured before rewinding, which removes the message from `entries`.
    let checkpoint = message
        .checkpoint
        .as_ref()
        .map(|c| c.git_checkpoint.clone());

    // Cancel any in-progress generation before restoring
    let cancel_task = self.cancel(cx);
    let rewind = self.rewind(id.clone(), cx);
    let git_store = self.project.read(cx).git_store().clone();

    cx.spawn(async move |_, cx| {
        cancel_task.await;
        rewind.await?;
        if let Some(checkpoint) = checkpoint {
            git_store
                .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
                .await?;
        }

        Ok(())
    })
}
2009
/// Rewinds this thread to before the user message identified by `id`.
///
/// The connection is asked to truncate the session first. Then that message
/// and every later entry are removed (emitting `EntriesRemoved`), terminals
/// referenced by the removed entries are killed and dropped, and
/// `reject_all_edits` is called on the action log.
///
/// Fails with "not supported" when the connection cannot truncate. Unlike
/// `restore_checkpoint`, this method does not restore from git.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
    let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
        return Task::ready(Err(anyhow!("not supported")));
    };

    let telemetry = ActionLogTelemetry::from(&*self);
    cx.spawn(async move |this, cx| {
        cx.update(|cx| truncate.run(id.clone(), cx)).await?;
        this.update(cx, |this, cx| {
            if let Some((ix, _)) = this.user_message_mut(&id) {
                // Collect all terminals from entries that will be removed
                let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                    .iter()
                    .flat_map(|entry| entry.terminals())
                    .filter_map(|terminal| terminal.read(cx).id().clone().into())
                    .collect();

                let range = ix..this.entries.len();
                this.entries.truncate(ix);
                cx.emit(AcpThreadEvent::EntriesRemoved(range));

                // Kill and remove the terminals
                for terminal_id in terminals_to_remove {
                    if let Some(terminal) = this.terminals.remove(&terminal_id) {
                        terminal.update(cx, |terminal, cx| {
                            terminal.kill(cx);
                        });
                    }
                }
            }
            this.action_log().update(cx, |action_log, cx| {
                action_log.reject_all_edits(Some(telemetry), cx)
            })
        })?
        .await;
        Ok(())
    })
}
2051
/// Decides whether the last user message's checkpoint should be shown.
///
/// A fresh git checkpoint is taken and compared with the one stored on the last
/// user message. `show` is set when they differ, and `EntryUpdated` is emitted.
/// A failed comparison counts as "equal", so the checkpoint stays hidden.
///
/// Resolves to `Ok(())` without changes when:
/// - there is no user message;
/// - the message has no id;
/// - it has no checkpoint;
/// - the new checkpoint cannot be taken (that failure is logged).
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
    let git_store = self.project.read(cx).git_store().clone();

    let Some((_, message)) = self.last_user_message() else {
        return Task::ready(Ok(()));
    };
    let Some(user_message_id) = message.id.clone() else {
        return Task::ready(Ok(()));
    };
    let Some(checkpoint) = message.checkpoint.as_ref() else {
        return Task::ready(Ok(()));
    };
    let old_checkpoint = checkpoint.git_checkpoint.clone();

    let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
    cx.spawn(async move |this, cx| {
        let Some(new_checkpoint) = new_checkpoint
            .await
            .context("failed to get new checkpoint")
            .log_err()
        else {
            return Ok(());
        };

        let equal = git_store
            .update(cx, |git, cx| {
                git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
            })
            .await
            .unwrap_or(true);

        // Look the message up by id: entries may have shifted while awaiting.
        this.update(cx, |this, cx| {
            if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
                if let Some(checkpoint) = message.checkpoint.as_mut() {
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            }
        })?;

        Ok(())
    })
}
2095
2096 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2097 self.entries
2098 .iter_mut()
2099 .enumerate()
2100 .rev()
2101 .find_map(|(ix, entry)| {
2102 if let AgentThreadEntry::UserMessage(message) = entry {
2103 Some((ix, message))
2104 } else {
2105 None
2106 }
2107 })
2108 }
2109
2110 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2111 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2112 if let AgentThreadEntry::UserMessage(message) = entry {
2113 if message.id.as_ref() == Some(id) {
2114 Some((ix, message))
2115 } else {
2116 None
2117 }
2118 } else {
2119 None
2120 }
2121 })
2122 }
2123
2124 pub fn read_text_file(
2125 &self,
2126 path: PathBuf,
2127 line: Option<u32>,
2128 limit: Option<u32>,
2129 reuse_shared_snapshot: bool,
2130 cx: &mut Context<Self>,
2131 ) -> Task<Result<String, acp::Error>> {
2132 // Args are 1-based, move to 0-based
2133 let line = line.unwrap_or_default().saturating_sub(1);
2134 let limit = limit.unwrap_or(u32::MAX);
2135 let project = self.project.clone();
2136 let action_log = self.action_log.clone();
2137 cx.spawn(async move |this, cx| {
2138 let load = project.update(cx, |project, cx| {
2139 let path = project
2140 .project_path_for_absolute_path(&path, cx)
2141 .ok_or_else(|| {
2142 acp::Error::resource_not_found(Some(path.display().to_string()))
2143 })?;
2144 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2145 })?;
2146
2147 let buffer = load.await?;
2148
2149 let snapshot = if reuse_shared_snapshot {
2150 this.read_with(cx, |this, _| {
2151 this.shared_buffers.get(&buffer.clone()).cloned()
2152 })
2153 .log_err()
2154 .flatten()
2155 } else {
2156 None
2157 };
2158
2159 let snapshot = if let Some(snapshot) = snapshot {
2160 snapshot
2161 } else {
2162 action_log.update(cx, |action_log, cx| {
2163 action_log.buffer_read(buffer.clone(), cx);
2164 });
2165
2166 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2167 this.update(cx, |this, _| {
2168 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2169 })?;
2170 snapshot
2171 };
2172
2173 let max_point = snapshot.max_point();
2174 let start_position = Point::new(line, 0);
2175
2176 if start_position > max_point {
2177 return Err(acp::Error::invalid_params().data(format!(
2178 "Attempting to read beyond the end of the file, line {}:{}",
2179 max_point.row + 1,
2180 max_point.column
2181 )));
2182 }
2183
2184 let start = snapshot.anchor_before(start_position);
2185 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2186
2187 project.update(cx, |project, cx| {
2188 project.set_agent_location(
2189 Some(AgentLocation {
2190 buffer: buffer.downgrade(),
2191 position: start,
2192 }),
2193 cx,
2194 );
2195 });
2196
2197 Ok(snapshot.text_for_range(start..end).collect::<String>())
2198 })
2199 }
2200
/// Replaces the contents of the file at absolute `path` with `content`, then saves it.
///
/// Steps:
/// 1. Diff `content` against the old text with `text_diff`. The old text is the
///    cached snapshot in `shared_buffers` when present, otherwise the buffer's
///    current snapshot. The diff runs on the background executor.
/// 2. Move the agent location to the end of the last edit, or to the buffer
///    start when there are no edits.
/// 3. Apply the edits, recording the read and the edit with the action log.
/// 4. If the buffer's language settings have `format_on_save` enabled, format
///    the buffer (errors are logged) and record the formatting as another edit.
/// 5. Save the buffer.
///
/// Fails with "invalid path" if `path` is not in the project, or if opening or
/// saving the buffer fails.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load?.await?;
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                }),
                cx,
            );
        });

        // Apply the edits and report whether format-on-save is enabled for this buffer.
        let format_on_save = cx.update(|cx| {
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        });

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            });
            format_task.await.log_err();

            // Record the formatter's changes as part of the agent's edit.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))
            .await
    })
}
2297
2298 pub fn create_terminal(
2299 &self,
2300 command: String,
2301 args: Vec<String>,
2302 extra_env: Vec<acp::EnvVariable>,
2303 cwd: Option<PathBuf>,
2304 output_byte_limit: Option<u64>,
2305 cx: &mut Context<Self>,
2306 ) -> Task<Result<Entity<Terminal>>> {
2307 let env = match &cwd {
2308 Some(dir) => self.project.update(cx, |project, cx| {
2309 project.environment().update(cx, |env, cx| {
2310 env.directory_environment(dir.as_path().into(), cx)
2311 })
2312 }),
2313 None => Task::ready(None).shared(),
2314 };
2315 let env = cx.spawn(async move |_, _| {
2316 let mut env = env.await.unwrap_or_default();
2317 // Disables paging for `git` and hopefully other commands
2318 env.insert("PAGER".into(), "".into());
2319 for var in extra_env {
2320 env.insert(var.name, var.value);
2321 }
2322 env
2323 });
2324
2325 let project = self.project.clone();
2326 let language_registry = project.read(cx).languages().clone();
2327 let is_windows = project.read(cx).path_style(cx).is_windows();
2328
2329 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2330 let terminal_task = cx.spawn({
2331 let terminal_id = terminal_id.clone();
2332 async move |_this, cx| {
2333 let env = env.await;
2334 let shell = project
2335 .update(cx, |project, cx| {
2336 project
2337 .remote_client()
2338 .and_then(|r| r.read(cx).default_system_shell())
2339 })
2340 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2341 let (task_command, task_args) =
2342 ShellBuilder::new(&Shell::Program(shell), is_windows)
2343 .redirect_stdin_to_dev_null()
2344 .build(Some(command.clone()), &args);
2345 let terminal = project
2346 .update(cx, |project, cx| {
2347 project.create_terminal_task(
2348 task::SpawnInTerminal {
2349 command: Some(task_command),
2350 args: task_args,
2351 cwd: cwd.clone(),
2352 env,
2353 ..Default::default()
2354 },
2355 cx,
2356 )
2357 })
2358 .await?;
2359
2360 anyhow::Ok(cx.new(|cx| {
2361 Terminal::new(
2362 terminal_id,
2363 &format!("{} {}", command, args.join(" ")),
2364 cwd,
2365 output_byte_limit.map(|l| l as usize),
2366 terminal,
2367 language_registry,
2368 cx,
2369 )
2370 }))
2371 }
2372 });
2373
2374 cx.spawn(async move |this, cx| {
2375 let terminal = terminal_task.await?;
2376 this.update(cx, |this, _cx| {
2377 this.terminals.insert(terminal_id, terminal.clone());
2378 terminal
2379 })
2380 })
2381 }
2382
2383 pub fn kill_terminal(
2384 &mut self,
2385 terminal_id: acp::TerminalId,
2386 cx: &mut Context<Self>,
2387 ) -> Result<()> {
2388 self.terminals
2389 .get(&terminal_id)
2390 .context("Terminal not found")?
2391 .update(cx, |terminal, cx| {
2392 terminal.kill(cx);
2393 });
2394
2395 Ok(())
2396 }
2397
2398 pub fn release_terminal(
2399 &mut self,
2400 terminal_id: acp::TerminalId,
2401 cx: &mut Context<Self>,
2402 ) -> Result<()> {
2403 self.terminals
2404 .remove(&terminal_id)
2405 .context("Terminal not found")?
2406 .update(cx, |terminal, cx| {
2407 terminal.kill(cx);
2408 });
2409
2410 Ok(())
2411 }
2412
2413 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2414 self.terminals
2415 .get(&terminal_id)
2416 .context("Terminal not found")
2417 .cloned()
2418 }
2419
2420 pub fn to_markdown(&self, cx: &App) -> String {
2421 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2422 }
2423
2424 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2425 cx.emit(AcpThreadEvent::LoadError(error));
2426 }
2427
2428 pub fn register_terminal_created(
2429 &mut self,
2430 terminal_id: acp::TerminalId,
2431 command_label: String,
2432 working_dir: Option<PathBuf>,
2433 output_byte_limit: Option<u64>,
2434 terminal: Entity<::terminal::Terminal>,
2435 cx: &mut Context<Self>,
2436 ) -> Entity<Terminal> {
2437 let language_registry = self.project.read(cx).languages().clone();
2438
2439 let entity = cx.new(|cx| {
2440 Terminal::new(
2441 terminal_id.clone(),
2442 &command_label,
2443 working_dir.clone(),
2444 output_byte_limit.map(|l| l as usize),
2445 terminal,
2446 language_registry,
2447 cx,
2448 )
2449 });
2450 self.terminals.insert(terminal_id.clone(), entity.clone());
2451 entity
2452 }
2453}
2454
2455fn markdown_for_raw_output(
2456 raw_output: &serde_json::Value,
2457 language_registry: &Arc<LanguageRegistry>,
2458 cx: &mut App,
2459) -> Option<Entity<Markdown>> {
2460 match raw_output {
2461 serde_json::Value::Null => None,
2462 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2463 Markdown::new(
2464 value.to_string().into(),
2465 Some(language_registry.clone()),
2466 None,
2467 cx,
2468 )
2469 })),
2470 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2471 Markdown::new(
2472 value.to_string().into(),
2473 Some(language_registry.clone()),
2474 None,
2475 cx,
2476 )
2477 })),
2478 serde_json::Value::String(value) => Some(cx.new(|cx| {
2479 Markdown::new(
2480 value.clone().into(),
2481 Some(language_registry.clone()),
2482 None,
2483 cx,
2484 )
2485 })),
2486 value => Some(cx.new(|cx| {
2487 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2488
2489 Markdown::new(
2490 format!("```json\n{}\n```", pretty_json).into(),
2491 Some(language_registry.clone()),
2492 None,
2493 cx,
2494 )
2495 })),
2496 }
2497}
2498
2499#[cfg(test)]
2500mod tests {
2501 use super::*;
2502 use anyhow::anyhow;
2503 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2504 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2505 use indoc::indoc;
2506 use project::{FakeFs, Fs};
2507 use rand::{distr, prelude::*};
2508 use serde_json::json;
2509 use settings::SettingsStore;
2510 use smol::stream::StreamExt as _;
2511 use std::{
2512 any::Any,
2513 cell::RefCell,
2514 path::Path,
2515 rc::Rc,
2516 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2517 time::Duration,
2518 };
2519 use util::path;
2520
2521 fn init_test(cx: &mut TestAppContext) {
2522 env_logger::try_init().ok();
2523 cx.update(|cx| {
2524 let settings_store = SettingsStore::test(cx);
2525 cx.set_global(settings_store);
2526 });
2527 }
2528
2529 #[gpui::test]
2530 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2531 init_test(cx);
2532
2533 let fs = FakeFs::new(cx.executor());
2534 let project = Project::test(fs, [], cx).await;
2535 let connection = Rc::new(FakeAgentConnection::new());
2536 let thread = cx
2537 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2538 .await
2539 .unwrap();
2540
2541 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2542
2543 // Send Output BEFORE Created - should be buffered by acp_thread
2544 thread.update(cx, |thread, cx| {
2545 thread.on_terminal_provider_event(
2546 TerminalProviderEvent::Output {
2547 terminal_id: terminal_id.clone(),
2548 data: b"hello buffered".to_vec(),
2549 },
2550 cx,
2551 );
2552 });
2553
2554 // Create a display-only terminal and then send Created
2555 let lower = cx.new(|cx| {
2556 let builder = ::terminal::TerminalBuilder::new_display_only(
2557 ::terminal::terminal_settings::CursorShape::default(),
2558 ::terminal::terminal_settings::AlternateScroll::On,
2559 None,
2560 0,
2561 )
2562 .unwrap();
2563 builder.subscribe(cx)
2564 });
2565
2566 thread.update(cx, |thread, cx| {
2567 thread.on_terminal_provider_event(
2568 TerminalProviderEvent::Created {
2569 terminal_id: terminal_id.clone(),
2570 label: "Buffered Test".to_string(),
2571 cwd: None,
2572 output_byte_limit: None,
2573 terminal: lower.clone(),
2574 },
2575 cx,
2576 );
2577 });
2578
2579 // After Created, buffered Output should have been flushed into the renderer
2580 let content = thread.read_with(cx, |thread, cx| {
2581 let term = thread.terminal(terminal_id.clone()).unwrap();
2582 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2583 });
2584
2585 assert!(
2586 content.contains("hello buffered"),
2587 "expected buffered output to render, got: {content}"
2588 );
2589 }
2590
2591 #[gpui::test]
2592 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2593 init_test(cx);
2594
2595 let fs = FakeFs::new(cx.executor());
2596 let project = Project::test(fs, [], cx).await;
2597 let connection = Rc::new(FakeAgentConnection::new());
2598 let thread = cx
2599 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2600 .await
2601 .unwrap();
2602
2603 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2604
2605 // Send Output BEFORE Created
2606 thread.update(cx, |thread, cx| {
2607 thread.on_terminal_provider_event(
2608 TerminalProviderEvent::Output {
2609 terminal_id: terminal_id.clone(),
2610 data: b"pre-exit data".to_vec(),
2611 },
2612 cx,
2613 );
2614 });
2615
2616 // Send Exit BEFORE Created
2617 thread.update(cx, |thread, cx| {
2618 thread.on_terminal_provider_event(
2619 TerminalProviderEvent::Exit {
2620 terminal_id: terminal_id.clone(),
2621 status: acp::TerminalExitStatus::new().exit_code(0),
2622 },
2623 cx,
2624 );
2625 });
2626
2627 // Now create a display-only lower-level terminal and send Created
2628 let lower = cx.new(|cx| {
2629 let builder = ::terminal::TerminalBuilder::new_display_only(
2630 ::terminal::terminal_settings::CursorShape::default(),
2631 ::terminal::terminal_settings::AlternateScroll::On,
2632 None,
2633 0,
2634 )
2635 .unwrap();
2636 builder.subscribe(cx)
2637 });
2638
2639 thread.update(cx, |thread, cx| {
2640 thread.on_terminal_provider_event(
2641 TerminalProviderEvent::Created {
2642 terminal_id: terminal_id.clone(),
2643 label: "Buffered Exit Test".to_string(),
2644 cwd: None,
2645 output_byte_limit: None,
2646 terminal: lower.clone(),
2647 },
2648 cx,
2649 );
2650 });
2651
2652 // Output should be present after Created (flushed from buffer)
2653 let content = thread.read_with(cx, |thread, cx| {
2654 let term = thread.terminal(terminal_id.clone()).unwrap();
2655 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2656 });
2657
2658 assert!(
2659 content.contains("pre-exit data"),
2660 "expected pre-exit data to render, got: {content}"
2661 );
2662 }
2663
2664 #[gpui::test]
2665 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2666 init_test(cx);
2667
2668 let fs = FakeFs::new(cx.executor());
2669 let project = Project::test(fs, [], cx).await;
2670 let connection = Rc::new(FakeAgentConnection::new());
2671 let thread = cx
2672 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2673 .await
2674 .unwrap();
2675
2676 // Test creating a new user message
2677 thread.update(cx, |thread, cx| {
2678 thread.push_user_content_block(None, "Hello, ".into(), cx);
2679 });
2680
2681 thread.update(cx, |thread, cx| {
2682 assert_eq!(thread.entries.len(), 1);
2683 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2684 assert_eq!(user_msg.id, None);
2685 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2686 } else {
2687 panic!("Expected UserMessage");
2688 }
2689 });
2690
2691 // Test appending to existing user message
2692 let message_1_id = UserMessageId::new();
2693 thread.update(cx, |thread, cx| {
2694 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2695 });
2696
2697 thread.update(cx, |thread, cx| {
2698 assert_eq!(thread.entries.len(), 1);
2699 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2700 assert_eq!(user_msg.id, Some(message_1_id));
2701 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2702 } else {
2703 panic!("Expected UserMessage");
2704 }
2705 });
2706
2707 // Test creating new user message after assistant message
2708 thread.update(cx, |thread, cx| {
2709 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2710 });
2711
2712 let message_2_id = UserMessageId::new();
2713 thread.update(cx, |thread, cx| {
2714 thread.push_user_content_block(
2715 Some(message_2_id.clone()),
2716 "New user message".into(),
2717 cx,
2718 );
2719 });
2720
2721 thread.update(cx, |thread, cx| {
2722 assert_eq!(thread.entries.len(), 3);
2723 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2724 assert_eq!(user_msg.id, Some(message_2_id));
2725 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2726 } else {
2727 panic!("Expected UserMessage at index 2");
2728 }
2729 });
2730 }
2731
2732 #[gpui::test]
2733 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2734 init_test(cx);
2735
2736 let fs = FakeFs::new(cx.executor());
2737 let project = Project::test(fs, [], cx).await;
2738 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2739 |_, thread, mut cx| {
2740 async move {
2741 thread.update(&mut cx, |thread, cx| {
2742 thread
2743 .handle_session_update(
2744 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2745 "Thinking ".into(),
2746 )),
2747 cx,
2748 )
2749 .unwrap();
2750 thread
2751 .handle_session_update(
2752 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2753 "hard!".into(),
2754 )),
2755 cx,
2756 )
2757 .unwrap();
2758 })?;
2759 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2760 }
2761 .boxed_local()
2762 },
2763 ));
2764
2765 let thread = cx
2766 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2767 .await
2768 .unwrap();
2769
2770 thread
2771 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2772 .await
2773 .unwrap();
2774
2775 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2776 assert_eq!(
2777 output,
2778 indoc! {r#"
2779 ## User
2780
2781 Hello from Zed!
2782
2783 ## Assistant
2784
2785 <thinking>
2786 Thinking hard!
2787 </thinking>
2788
2789 "#}
2790 );
2791 }
2792
2793 #[gpui::test]
2794 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2795 init_test(cx);
2796
2797 let fs = FakeFs::new(cx.executor());
2798 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2799 .await;
2800 let project = Project::test(fs.clone(), [], cx).await;
2801 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2802 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2803 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2804 move |_, thread, mut cx| {
2805 let read_file_tx = read_file_tx.clone();
2806 async move {
2807 let content = thread
2808 .update(&mut cx, |thread, cx| {
2809 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2810 })
2811 .unwrap()
2812 .await
2813 .unwrap();
2814 assert_eq!(content, "one\ntwo\nthree\n");
2815 read_file_tx.take().unwrap().send(()).unwrap();
2816 thread
2817 .update(&mut cx, |thread, cx| {
2818 thread.write_text_file(
2819 path!("/tmp/foo").into(),
2820 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2821 cx,
2822 )
2823 })
2824 .unwrap()
2825 .await
2826 .unwrap();
2827 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2828 }
2829 .boxed_local()
2830 },
2831 ));
2832
2833 let (worktree, pathbuf) = project
2834 .update(cx, |project, cx| {
2835 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2836 })
2837 .await
2838 .unwrap();
2839 let buffer = project
2840 .update(cx, |project, cx| {
2841 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2842 })
2843 .await
2844 .unwrap();
2845
2846 let thread = cx
2847 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2848 .await
2849 .unwrap();
2850
2851 let request = thread.update(cx, |thread, cx| {
2852 thread.send_raw("Extend the count in /tmp/foo", cx)
2853 });
2854 read_file_rx.await.ok();
2855 buffer.update(cx, |buffer, cx| {
2856 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2857 });
2858 cx.run_until_parked();
2859 assert_eq!(
2860 buffer.read_with(cx, |buffer, _| buffer.text()),
2861 "zero\none\ntwo\nthree\nfour\nfive\n"
2862 );
2863 assert_eq!(
2864 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2865 "zero\none\ntwo\nthree\nfour\nfive\n"
2866 );
2867 request.await.unwrap();
2868 }
2869
2870 #[gpui::test]
2871 async fn test_reading_from_line(cx: &mut TestAppContext) {
2872 init_test(cx);
2873
2874 let fs = FakeFs::new(cx.executor());
2875 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2876 .await;
2877 let project = Project::test(fs.clone(), [], cx).await;
2878 project
2879 .update(cx, |project, cx| {
2880 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2881 })
2882 .await
2883 .unwrap();
2884
2885 let connection = Rc::new(FakeAgentConnection::new());
2886
2887 let thread = cx
2888 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2889 .await
2890 .unwrap();
2891
2892 // Whole file
2893 let content = thread
2894 .update(cx, |thread, cx| {
2895 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2896 })
2897 .await
2898 .unwrap();
2899
2900 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2901
2902 // Only start line
2903 let content = thread
2904 .update(cx, |thread, cx| {
2905 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2906 })
2907 .await
2908 .unwrap();
2909
2910 assert_eq!(content, "three\nfour\n");
2911
2912 // Only limit
2913 let content = thread
2914 .update(cx, |thread, cx| {
2915 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2916 })
2917 .await
2918 .unwrap();
2919
2920 assert_eq!(content, "one\ntwo\n");
2921
2922 // Range
2923 let content = thread
2924 .update(cx, |thread, cx| {
2925 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2926 })
2927 .await
2928 .unwrap();
2929
2930 assert_eq!(content, "two\nthree\n");
2931
2932 // Invalid
2933 let err = thread
2934 .update(cx, |thread, cx| {
2935 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2936 })
2937 .await
2938 .unwrap_err();
2939
2940 assert_eq!(
2941 err.to_string(),
2942 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2943 );
2944 }
2945
2946 #[gpui::test]
2947 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2948 init_test(cx);
2949
2950 let fs = FakeFs::new(cx.executor());
2951 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2952 let project = Project::test(fs.clone(), [], cx).await;
2953 project
2954 .update(cx, |project, cx| {
2955 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2956 })
2957 .await
2958 .unwrap();
2959
2960 let connection = Rc::new(FakeAgentConnection::new());
2961
2962 let thread = cx
2963 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2964 .await
2965 .unwrap();
2966
2967 // Whole file
2968 let content = thread
2969 .update(cx, |thread, cx| {
2970 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2971 })
2972 .await
2973 .unwrap();
2974
2975 assert_eq!(content, "");
2976
2977 // Only start line
2978 let content = thread
2979 .update(cx, |thread, cx| {
2980 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2981 })
2982 .await
2983 .unwrap();
2984
2985 assert_eq!(content, "");
2986
2987 // Only limit
2988 let content = thread
2989 .update(cx, |thread, cx| {
2990 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2991 })
2992 .await
2993 .unwrap();
2994
2995 assert_eq!(content, "");
2996
2997 // Range
2998 let content = thread
2999 .update(cx, |thread, cx| {
3000 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
3001 })
3002 .await
3003 .unwrap();
3004
3005 assert_eq!(content, "");
3006
3007 // Invalid
3008 let err = thread
3009 .update(cx, |thread, cx| {
3010 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
3011 })
3012 .await
3013 .unwrap_err();
3014
3015 assert_eq!(
3016 err.to_string(),
3017 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3018 );
3019 }
3020 #[gpui::test]
3021 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3022 init_test(cx);
3023
3024 let fs = FakeFs::new(cx.executor());
3025 fs.insert_tree(path!("/tmp"), json!({})).await;
3026 let project = Project::test(fs.clone(), [], cx).await;
3027 project
3028 .update(cx, |project, cx| {
3029 project.find_or_create_worktree(path!("/tmp"), true, cx)
3030 })
3031 .await
3032 .unwrap();
3033
3034 let connection = Rc::new(FakeAgentConnection::new());
3035
3036 let thread = cx
3037 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3038 .await
3039 .unwrap();
3040
3041 // Out of project file
3042 let err = thread
3043 .update(cx, |thread, cx| {
3044 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3045 })
3046 .await
3047 .unwrap_err();
3048
3049 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3050 }
3051
3052 #[gpui::test]
3053 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3054 init_test(cx);
3055
3056 let fs = FakeFs::new(cx.executor());
3057 let project = Project::test(fs, [], cx).await;
3058 let id = acp::ToolCallId::new("test");
3059
3060 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3061 let id = id.clone();
3062 move |_, thread, mut cx| {
3063 let id = id.clone();
3064 async move {
3065 thread
3066 .update(&mut cx, |thread, cx| {
3067 thread.handle_session_update(
3068 acp::SessionUpdate::ToolCall(
3069 acp::ToolCall::new(id.clone(), "Label")
3070 .kind(acp::ToolKind::Fetch)
3071 .status(acp::ToolCallStatus::InProgress),
3072 ),
3073 cx,
3074 )
3075 })
3076 .unwrap()
3077 .unwrap();
3078 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3079 }
3080 .boxed_local()
3081 }
3082 }));
3083
3084 let thread = cx
3085 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3086 .await
3087 .unwrap();
3088
3089 let request = thread.update(cx, |thread, cx| {
3090 thread.send_raw("Fetch https://example.com", cx)
3091 });
3092
3093 run_until_first_tool_call(&thread, cx).await;
3094
3095 thread.read_with(cx, |thread, _| {
3096 assert!(matches!(
3097 thread.entries[1],
3098 AgentThreadEntry::ToolCall(ToolCall {
3099 status: ToolCallStatus::InProgress,
3100 ..
3101 })
3102 ));
3103 });
3104
3105 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3106
3107 thread.read_with(cx, |thread, _| {
3108 assert!(matches!(
3109 &thread.entries[1],
3110 AgentThreadEntry::ToolCall(ToolCall {
3111 status: ToolCallStatus::Canceled,
3112 ..
3113 })
3114 ));
3115 });
3116
3117 thread
3118 .update(cx, |thread, cx| {
3119 thread.handle_session_update(
3120 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3121 id,
3122 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3123 )),
3124 cx,
3125 )
3126 })
3127 .unwrap();
3128
3129 request.await.unwrap();
3130
3131 thread.read_with(cx, |thread, _| {
3132 assert!(matches!(
3133 thread.entries[1],
3134 AgentThreadEntry::ToolCall(ToolCall {
3135 status: ToolCallStatus::Completed,
3136 ..
3137 })
3138 ));
3139 });
3140 }
3141
3142 #[gpui::test]
3143 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3144 init_test(cx);
3145 let fs = FakeFs::new(cx.background_executor.clone());
3146 fs.insert_tree(path!("/test"), json!({})).await;
3147 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3148
3149 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3150 move |_, thread, mut cx| {
3151 async move {
3152 thread
3153 .update(&mut cx, |thread, cx| {
3154 thread.handle_session_update(
3155 acp::SessionUpdate::ToolCall(
3156 acp::ToolCall::new("test", "Label")
3157 .kind(acp::ToolKind::Edit)
3158 .status(acp::ToolCallStatus::Completed)
3159 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3160 "/test/test.txt",
3161 "foo",
3162 ))]),
3163 ),
3164 cx,
3165 )
3166 })
3167 .unwrap()
3168 .unwrap();
3169 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3170 }
3171 .boxed_local()
3172 }
3173 }));
3174
3175 let thread = cx
3176 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3177 .await
3178 .unwrap();
3179
3180 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3181 .await
3182 .unwrap();
3183
3184 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3185 }
3186
3187 #[gpui::test(iterations = 10)]
3188 async fn test_checkpoints(cx: &mut TestAppContext) {
3189 init_test(cx);
3190 let fs = FakeFs::new(cx.background_executor.clone());
3191 fs.insert_tree(
3192 path!("/test"),
3193 json!({
3194 ".git": {}
3195 }),
3196 )
3197 .await;
3198 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3199
3200 let simulate_changes = Arc::new(AtomicBool::new(true));
3201 let next_filename = Arc::new(AtomicUsize::new(0));
3202 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3203 let simulate_changes = simulate_changes.clone();
3204 let next_filename = next_filename.clone();
3205 let fs = fs.clone();
3206 move |request, thread, mut cx| {
3207 let fs = fs.clone();
3208 let simulate_changes = simulate_changes.clone();
3209 let next_filename = next_filename.clone();
3210 async move {
3211 if simulate_changes.load(SeqCst) {
3212 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3213 fs.write(Path::new(&filename), b"").await?;
3214 }
3215
3216 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3217 panic!("expected text content block");
3218 };
3219 thread.update(&mut cx, |thread, cx| {
3220 thread
3221 .handle_session_update(
3222 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3223 content.text.to_uppercase().into(),
3224 )),
3225 cx,
3226 )
3227 .unwrap();
3228 })?;
3229 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3230 }
3231 .boxed_local()
3232 }
3233 }));
3234 let thread = cx
3235 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3236 .await
3237 .unwrap();
3238
3239 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3240 .await
3241 .unwrap();
3242 thread.read_with(cx, |thread, cx| {
3243 assert_eq!(
3244 thread.to_markdown(cx),
3245 indoc! {"
3246 ## User (checkpoint)
3247
3248 Lorem
3249
3250 ## Assistant
3251
3252 LOREM
3253
3254 "}
3255 );
3256 });
3257 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3258
3259 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3260 .await
3261 .unwrap();
3262 thread.read_with(cx, |thread, cx| {
3263 assert_eq!(
3264 thread.to_markdown(cx),
3265 indoc! {"
3266 ## User (checkpoint)
3267
3268 Lorem
3269
3270 ## Assistant
3271
3272 LOREM
3273
3274 ## User (checkpoint)
3275
3276 ipsum
3277
3278 ## Assistant
3279
3280 IPSUM
3281
3282 "}
3283 );
3284 });
3285 assert_eq!(
3286 fs.files(),
3287 vec![
3288 Path::new(path!("/test/file-0")),
3289 Path::new(path!("/test/file-1"))
3290 ]
3291 );
3292
3293 // Checkpoint isn't stored when there are no changes.
3294 simulate_changes.store(false, SeqCst);
3295 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3296 .await
3297 .unwrap();
3298 thread.read_with(cx, |thread, cx| {
3299 assert_eq!(
3300 thread.to_markdown(cx),
3301 indoc! {"
3302 ## User (checkpoint)
3303
3304 Lorem
3305
3306 ## Assistant
3307
3308 LOREM
3309
3310 ## User (checkpoint)
3311
3312 ipsum
3313
3314 ## Assistant
3315
3316 IPSUM
3317
3318 ## User
3319
3320 dolor
3321
3322 ## Assistant
3323
3324 DOLOR
3325
3326 "}
3327 );
3328 });
3329 assert_eq!(
3330 fs.files(),
3331 vec![
3332 Path::new(path!("/test/file-0")),
3333 Path::new(path!("/test/file-1"))
3334 ]
3335 );
3336
3337 // Rewinding the conversation truncates the history and restores the checkpoint.
3338 thread
3339 .update(cx, |thread, cx| {
3340 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3341 panic!("unexpected entries {:?}", thread.entries)
3342 };
3343 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3344 })
3345 .await
3346 .unwrap();
3347 thread.read_with(cx, |thread, cx| {
3348 assert_eq!(
3349 thread.to_markdown(cx),
3350 indoc! {"
3351 ## User (checkpoint)
3352
3353 Lorem
3354
3355 ## Assistant
3356
3357 LOREM
3358
3359 "}
3360 );
3361 });
3362 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3363 }
3364
3365 #[gpui::test]
3366 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3367 use std::sync::atomic::AtomicUsize;
3368 init_test(cx);
3369
3370 let fs = FakeFs::new(cx.executor());
3371 let project = Project::test(fs, None, cx).await;
3372
3373 // Create a connection that simulates refusal after tool result
3374 let prompt_count = Arc::new(AtomicUsize::new(0));
3375 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3376 let prompt_count = prompt_count.clone();
3377 move |_request, thread, mut cx| {
3378 let count = prompt_count.fetch_add(1, SeqCst);
3379 async move {
3380 if count == 0 {
3381 // First prompt: Generate a tool call with result
3382 thread.update(&mut cx, |thread, cx| {
3383 thread
3384 .handle_session_update(
3385 acp::SessionUpdate::ToolCall(
3386 acp::ToolCall::new("tool1", "Test Tool")
3387 .kind(acp::ToolKind::Fetch)
3388 .status(acp::ToolCallStatus::Completed)
3389 .raw_input(serde_json::json!({"query": "test"}))
3390 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3391 ),
3392 cx,
3393 )
3394 .unwrap();
3395 })?;
3396
3397 // Now return refusal because of the tool result
3398 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3399 } else {
3400 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3401 }
3402 }
3403 .boxed_local()
3404 }
3405 }));
3406
3407 let thread = cx
3408 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3409 .await
3410 .unwrap();
3411
3412 // Track if we see a Refusal event
3413 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3414 let saw_refusal_event_captured = saw_refusal_event.clone();
3415 thread.update(cx, |_thread, cx| {
3416 cx.subscribe(
3417 &thread,
3418 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3419 if matches!(event, AcpThreadEvent::Refusal) {
3420 *saw_refusal_event_captured.lock().unwrap() = true;
3421 }
3422 },
3423 )
3424 .detach();
3425 });
3426
3427 // Send a user message - this will trigger tool call and then refusal
3428 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3429 cx.background_executor.spawn(send_task).detach();
3430 cx.run_until_parked();
3431
3432 // Verify that:
3433 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3434 // 2. The user message was NOT truncated
3435 assert!(
3436 *saw_refusal_event.lock().unwrap(),
3437 "Refusal event should be emitted for tool result refusals"
3438 );
3439
3440 thread.read_with(cx, |thread, _| {
3441 let entries = thread.entries();
3442 assert!(entries.len() >= 2, "Should have user message and tool call");
3443
3444 // Verify user message is still there
3445 assert!(
3446 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3447 "User message should not be truncated"
3448 );
3449
3450 // Verify tool call is there with result
3451 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3452 assert!(
3453 tool_call.raw_output.is_some(),
3454 "Tool call should have output"
3455 );
3456 } else {
3457 panic!("Expected tool call at index 1");
3458 }
3459 });
3460 }
3461
3462 #[gpui::test]
3463 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3464 init_test(cx);
3465
3466 let fs = FakeFs::new(cx.executor());
3467 let project = Project::test(fs, None, cx).await;
3468
3469 let refuse_next = Arc::new(AtomicBool::new(false));
3470 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3471 let refuse_next = refuse_next.clone();
3472 move |_request, _thread, _cx| {
3473 if refuse_next.load(SeqCst) {
3474 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3475 .boxed_local()
3476 } else {
3477 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3478 .boxed_local()
3479 }
3480 }
3481 }));
3482
3483 let thread = cx
3484 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3485 .await
3486 .unwrap();
3487
3488 // Track if we see a Refusal event
3489 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3490 let saw_refusal_event_captured = saw_refusal_event.clone();
3491 thread.update(cx, |_thread, cx| {
3492 cx.subscribe(
3493 &thread,
3494 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3495 if matches!(event, AcpThreadEvent::Refusal) {
3496 *saw_refusal_event_captured.lock().unwrap() = true;
3497 }
3498 },
3499 )
3500 .detach();
3501 });
3502
3503 // Send a message that will be refused
3504 refuse_next.store(true, SeqCst);
3505 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3506 .await
3507 .unwrap();
3508
3509 // Verify that a Refusal event WAS emitted for user prompt refusal
3510 assert!(
3511 *saw_refusal_event.lock().unwrap(),
3512 "Refusal event should be emitted for user prompt refusals"
3513 );
3514
3515 // Verify the message was truncated (user prompt refusal)
3516 thread.read_with(cx, |thread, cx| {
3517 assert_eq!(thread.to_markdown(cx), "");
3518 });
3519 }
3520
3521 #[gpui::test]
3522 async fn test_refusal(cx: &mut TestAppContext) {
3523 init_test(cx);
3524 let fs = FakeFs::new(cx.background_executor.clone());
3525 fs.insert_tree(path!("/"), json!({})).await;
3526 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3527
3528 let refuse_next = Arc::new(AtomicBool::new(false));
3529 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3530 let refuse_next = refuse_next.clone();
3531 move |request, thread, mut cx| {
3532 let refuse_next = refuse_next.clone();
3533 async move {
3534 if refuse_next.load(SeqCst) {
3535 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3536 }
3537
3538 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3539 panic!("expected text content block");
3540 };
3541 thread.update(&mut cx, |thread, cx| {
3542 thread
3543 .handle_session_update(
3544 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3545 content.text.to_uppercase().into(),
3546 )),
3547 cx,
3548 )
3549 .unwrap();
3550 })?;
3551 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3552 }
3553 .boxed_local()
3554 }
3555 }));
3556 let thread = cx
3557 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3558 .await
3559 .unwrap();
3560
3561 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3562 .await
3563 .unwrap();
3564 thread.read_with(cx, |thread, cx| {
3565 assert_eq!(
3566 thread.to_markdown(cx),
3567 indoc! {"
3568 ## User
3569
3570 hello
3571
3572 ## Assistant
3573
3574 HELLO
3575
3576 "}
3577 );
3578 });
3579
3580 // Simulate refusing the second message. The message should be truncated
3581 // when a user prompt is refused.
3582 refuse_next.store(true, SeqCst);
3583 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3584 .await
3585 .unwrap();
3586 thread.read_with(cx, |thread, cx| {
3587 assert_eq!(
3588 thread.to_markdown(cx),
3589 indoc! {"
3590 ## User
3591
3592 hello
3593
3594 ## Assistant
3595
3596 HELLO
3597
3598 "}
3599 );
3600 });
3601 }
3602
3603 async fn run_until_first_tool_call(
3604 thread: &Entity<AcpThread>,
3605 cx: &mut TestAppContext,
3606 ) -> usize {
3607 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3608
3609 let subscription = cx.update(|cx| {
3610 cx.subscribe(thread, move |thread, _, cx| {
3611 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3612 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3613 return tx.try_send(ix).unwrap();
3614 }
3615 }
3616 })
3617 });
3618
3619 select! {
3620 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3621 panic!("Timeout waiting for tool call")
3622 }
3623 ix = rx.next().fuse() => {
3624 drop(subscription);
3625 ix.unwrap()
3626 }
3627 }
3628 }
3629
3630 #[derive(Clone, Default)]
3631 struct FakeAgentConnection {
3632 auth_methods: Vec<acp::AuthMethod>,
3633 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3634 on_user_message: Option<
3635 Rc<
3636 dyn Fn(
3637 acp::PromptRequest,
3638 WeakEntity<AcpThread>,
3639 AsyncApp,
3640 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3641 + 'static,
3642 >,
3643 >,
3644 }
3645
3646 impl FakeAgentConnection {
3647 fn new() -> Self {
3648 Self {
3649 auth_methods: Vec::new(),
3650 on_user_message: None,
3651 sessions: Arc::default(),
3652 }
3653 }
3654
3655 #[expect(unused)]
3656 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3657 self.auth_methods = auth_methods;
3658 self
3659 }
3660
3661 fn on_user_message(
3662 mut self,
3663 handler: impl Fn(
3664 acp::PromptRequest,
3665 WeakEntity<AcpThread>,
3666 AsyncApp,
3667 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3668 + 'static,
3669 ) -> Self {
3670 self.on_user_message.replace(Rc::new(handler));
3671 self
3672 }
3673 }
3674
3675 impl AgentConnection for FakeAgentConnection {
3676 fn telemetry_id(&self) -> SharedString {
3677 "fake".into()
3678 }
3679
3680 fn auth_methods(&self) -> &[acp::AuthMethod] {
3681 &self.auth_methods
3682 }
3683
3684 fn new_thread(
3685 self: Rc<Self>,
3686 project: Entity<Project>,
3687 _cwd: &Path,
3688 cx: &mut App,
3689 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3690 let session_id = acp::SessionId::new(
3691 rand::rng()
3692 .sample_iter(&distr::Alphanumeric)
3693 .take(7)
3694 .map(char::from)
3695 .collect::<String>(),
3696 );
3697 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3698 let thread = cx.new(|cx| {
3699 AcpThread::new(
3700 "Test",
3701 self.clone(),
3702 project,
3703 action_log,
3704 session_id.clone(),
3705 watch::Receiver::constant(
3706 acp::PromptCapabilities::new()
3707 .image(true)
3708 .audio(true)
3709 .embedded_context(true),
3710 ),
3711 cx,
3712 )
3713 });
3714 self.sessions.lock().insert(session_id, thread.downgrade());
3715 Task::ready(Ok(thread))
3716 }
3717
3718 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3719 if self.auth_methods().iter().any(|m| m.id == method) {
3720 Task::ready(Ok(()))
3721 } else {
3722 Task::ready(Err(anyhow!("Invalid Auth Method")))
3723 }
3724 }
3725
3726 fn prompt(
3727 &self,
3728 _id: Option<UserMessageId>,
3729 params: acp::PromptRequest,
3730 cx: &mut App,
3731 ) -> Task<gpui::Result<acp::PromptResponse>> {
3732 let sessions = self.sessions.lock();
3733 let thread = sessions.get(¶ms.session_id).unwrap();
3734 if let Some(handler) = &self.on_user_message {
3735 let handler = handler.clone();
3736 let thread = thread.clone();
3737 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3738 } else {
3739 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3740 }
3741 }
3742
3743 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3744 let sessions = self.sessions.lock();
3745 let thread = sessions.get(session_id).unwrap().clone();
3746
3747 cx.spawn(async move |cx| {
3748 thread
3749 .update(cx, |thread, cx| thread.cancel(cx))
3750 .unwrap()
3751 .await
3752 })
3753 .detach();
3754 }
3755
3756 fn truncate(
3757 &self,
3758 session_id: &acp::SessionId,
3759 _cx: &App,
3760 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3761 Some(Rc::new(FakeAgentSessionEditor {
3762 _session_id: session_id.clone(),
3763 }))
3764 }
3765
3766 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3767 self
3768 }
3769 }
3770
3771 struct FakeAgentSessionEditor {
3772 _session_id: acp::SessionId,
3773 }
3774
3775 impl AgentSessionTruncate for FakeAgentSessionEditor {
3776 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3777 Task::ready(Ok(()))
3778 }
3779 }
3780
3781 #[gpui::test]
3782 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3783 init_test(cx);
3784
3785 let fs = FakeFs::new(cx.executor());
3786 let project = Project::test(fs, [], cx).await;
3787 let connection = Rc::new(FakeAgentConnection::new());
3788 let thread = cx
3789 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3790 .await
3791 .unwrap();
3792
3793 // Try to update a tool call that doesn't exist
3794 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3795 thread.update(cx, |thread, cx| {
3796 let result = thread.handle_session_update(
3797 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3798 nonexistent_id.clone(),
3799 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3800 )),
3801 cx,
3802 );
3803
3804 // The update should succeed (not return an error)
3805 assert!(result.is_ok());
3806
3807 // There should now be exactly one entry in the thread
3808 assert_eq!(thread.entries.len(), 1);
3809
3810 // The entry should be a failed tool call
3811 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3812 assert_eq!(tool_call.id, nonexistent_id);
3813 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3814 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3815
3816 // Check that the content contains the error message
3817 assert_eq!(tool_call.content.len(), 1);
3818 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3819 match content_block {
3820 ContentBlock::Markdown { markdown } => {
3821 let markdown_text = markdown.read(cx).source();
3822 assert!(markdown_text.contains("Tool call not found"));
3823 }
3824 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3825 ContentBlock::ResourceLink { .. } => {
3826 panic!("Expected markdown content, got resource link")
3827 }
3828 ContentBlock::Image { .. } => {
3829 panic!("Expected markdown content, got image")
3830 }
3831 }
3832 } else {
3833 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3834 }
3835 } else {
3836 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3837 }
3838 });
3839 }
3840
3841 /// Tests that restoring a checkpoint properly cleans up terminals that were
3842 /// created after that checkpoint, and cancels any in-progress generation.
3843 ///
3844 /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
3845 /// that were started after that checkpoint should be terminated, and any in-progress
3846 /// AI generation should be canceled.
3847 #[gpui::test]
3848 async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
3849 init_test(cx);
3850
3851 let fs = FakeFs::new(cx.executor());
3852 let project = Project::test(fs, [], cx).await;
3853 let connection = Rc::new(FakeAgentConnection::new());
3854 let thread = cx
3855 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3856 .await
3857 .unwrap();
3858
3859 // Send first user message to create a checkpoint
3860 cx.update(|cx| {
3861 thread.update(cx, |thread, cx| {
3862 thread.send(vec!["first message".into()], cx)
3863 })
3864 })
3865 .await
3866 .unwrap();
3867
3868 // Send second message (creates another checkpoint) - we'll restore to this one
3869 cx.update(|cx| {
3870 thread.update(cx, |thread, cx| {
3871 thread.send(vec!["second message".into()], cx)
3872 })
3873 })
3874 .await
3875 .unwrap();
3876
3877 // Create 2 terminals BEFORE the checkpoint that have completed running
3878 let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3879 let mock_terminal_1 = cx.new(|cx| {
3880 let builder = ::terminal::TerminalBuilder::new_display_only(
3881 ::terminal::terminal_settings::CursorShape::default(),
3882 ::terminal::terminal_settings::AlternateScroll::On,
3883 None,
3884 0,
3885 )
3886 .unwrap();
3887 builder.subscribe(cx)
3888 });
3889
3890 thread.update(cx, |thread, cx| {
3891 thread.on_terminal_provider_event(
3892 TerminalProviderEvent::Created {
3893 terminal_id: terminal_id_1.clone(),
3894 label: "echo 'first'".to_string(),
3895 cwd: Some(PathBuf::from("/test")),
3896 output_byte_limit: None,
3897 terminal: mock_terminal_1.clone(),
3898 },
3899 cx,
3900 );
3901 });
3902
3903 thread.update(cx, |thread, cx| {
3904 thread.on_terminal_provider_event(
3905 TerminalProviderEvent::Output {
3906 terminal_id: terminal_id_1.clone(),
3907 data: b"first\n".to_vec(),
3908 },
3909 cx,
3910 );
3911 });
3912
3913 thread.update(cx, |thread, cx| {
3914 thread.on_terminal_provider_event(
3915 TerminalProviderEvent::Exit {
3916 terminal_id: terminal_id_1.clone(),
3917 status: acp::TerminalExitStatus::new().exit_code(0),
3918 },
3919 cx,
3920 );
3921 });
3922
3923 let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3924 let mock_terminal_2 = cx.new(|cx| {
3925 let builder = ::terminal::TerminalBuilder::new_display_only(
3926 ::terminal::terminal_settings::CursorShape::default(),
3927 ::terminal::terminal_settings::AlternateScroll::On,
3928 None,
3929 0,
3930 )
3931 .unwrap();
3932 builder.subscribe(cx)
3933 });
3934
3935 thread.update(cx, |thread, cx| {
3936 thread.on_terminal_provider_event(
3937 TerminalProviderEvent::Created {
3938 terminal_id: terminal_id_2.clone(),
3939 label: "echo 'second'".to_string(),
3940 cwd: Some(PathBuf::from("/test")),
3941 output_byte_limit: None,
3942 terminal: mock_terminal_2.clone(),
3943 },
3944 cx,
3945 );
3946 });
3947
3948 thread.update(cx, |thread, cx| {
3949 thread.on_terminal_provider_event(
3950 TerminalProviderEvent::Output {
3951 terminal_id: terminal_id_2.clone(),
3952 data: b"second\n".to_vec(),
3953 },
3954 cx,
3955 );
3956 });
3957
3958 thread.update(cx, |thread, cx| {
3959 thread.on_terminal_provider_event(
3960 TerminalProviderEvent::Exit {
3961 terminal_id: terminal_id_2.clone(),
3962 status: acp::TerminalExitStatus::new().exit_code(0),
3963 },
3964 cx,
3965 );
3966 });
3967
3968 // Get the second message ID to restore to
3969 let second_message_id = thread.read_with(cx, |thread, _| {
3970 // At this point we have:
3971 // - Index 0: First user message (with checkpoint)
3972 // - Index 1: Second user message (with checkpoint)
3973 // No assistant responses because FakeAgentConnection just returns EndTurn
3974 let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
3975 panic!("expected user message at index 1");
3976 };
3977 message.id.clone().unwrap()
3978 });
3979
3980 // Create a terminal AFTER the checkpoint we'll restore to.
3981 // This simulates the AI agent starting a long-running terminal command.
3982 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
3983 let mock_terminal = cx.new(|cx| {
3984 let builder = ::terminal::TerminalBuilder::new_display_only(
3985 ::terminal::terminal_settings::CursorShape::default(),
3986 ::terminal::terminal_settings::AlternateScroll::On,
3987 None,
3988 0,
3989 )
3990 .unwrap();
3991 builder.subscribe(cx)
3992 });
3993
3994 // Register the terminal as created
3995 thread.update(cx, |thread, cx| {
3996 thread.on_terminal_provider_event(
3997 TerminalProviderEvent::Created {
3998 terminal_id: terminal_id.clone(),
3999 label: "sleep 1000".to_string(),
4000 cwd: Some(PathBuf::from("/test")),
4001 output_byte_limit: None,
4002 terminal: mock_terminal.clone(),
4003 },
4004 cx,
4005 );
4006 });
4007
4008 // Simulate the terminal producing output (still running)
4009 thread.update(cx, |thread, cx| {
4010 thread.on_terminal_provider_event(
4011 TerminalProviderEvent::Output {
4012 terminal_id: terminal_id.clone(),
4013 data: b"terminal is running...\n".to_vec(),
4014 },
4015 cx,
4016 );
4017 });
4018
4019 // Create a tool call entry that references this terminal
4020 // This represents the agent requesting a terminal command
4021 thread.update(cx, |thread, cx| {
4022 thread
4023 .handle_session_update(
4024 acp::SessionUpdate::ToolCall(
4025 acp::ToolCall::new("terminal-tool-1", "Running command")
4026 .kind(acp::ToolKind::Execute)
4027 .status(acp::ToolCallStatus::InProgress)
4028 .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
4029 terminal_id.clone(),
4030 ))])
4031 .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
4032 ),
4033 cx,
4034 )
4035 .unwrap();
4036 });
4037
4038 // Verify terminal exists and is in the thread
4039 let terminal_exists_before =
4040 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4041 assert!(
4042 terminal_exists_before,
4043 "Terminal should exist before checkpoint restore"
4044 );
4045
4046 // Verify the terminal's underlying task is still running (not completed)
4047 let terminal_running_before = thread.read_with(cx, |thread, _cx| {
4048 let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
4049 terminal_entity.read_with(cx, |term, _cx| {
4050 term.output().is_none() // output is None means it's still running
4051 })
4052 });
4053 assert!(
4054 terminal_running_before,
4055 "Terminal should be running before checkpoint restore"
4056 );
4057
4058 // Verify we have the expected entries before restore
4059 let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
4060 assert!(
4061 entry_count_before > 1,
4062 "Should have multiple entries before restore"
4063 );
4064
4065 // Restore the checkpoint to the second message.
4066 // This should:
4067 // 1. Cancel any in-progress generation (via the cancel() call)
4068 // 2. Remove the terminal that was created after that point
4069 thread
4070 .update(cx, |thread, cx| {
4071 thread.restore_checkpoint(second_message_id, cx)
4072 })
4073 .await
4074 .unwrap();
4075
4076 // Verify that no send_task is in progress after restore
4077 // (cancel() clears the send_task)
4078 let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
4079 assert!(
4080 !has_send_task_after,
4081 "Should not have a send_task after restore (cancel should have cleared it)"
4082 );
4083
4084 // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
4085 let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
4086 assert_eq!(
4087 entry_count, 1,
4088 "Should have 1 entry after restore (only the first user message)"
4089 );
4090
4091 // Verify the 2 completed terminals from before the checkpoint still exist
4092 let terminal_1_exists = thread.read_with(cx, |thread, _| {
4093 thread.terminals.contains_key(&terminal_id_1)
4094 });
4095 assert!(
4096 terminal_1_exists,
4097 "Terminal 1 (from before checkpoint) should still exist"
4098 );
4099
4100 let terminal_2_exists = thread.read_with(cx, |thread, _| {
4101 thread.terminals.contains_key(&terminal_id_2)
4102 });
4103 assert!(
4104 terminal_2_exists,
4105 "Terminal 2 (from before checkpoint) should still exist"
4106 );
4107
4108 // Verify they're still in completed state
4109 let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
4110 let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
4111 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4112 });
4113 assert!(terminal_1_completed, "Terminal 1 should still be completed");
4114
4115 let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
4116 let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
4117 terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
4118 });
4119 assert!(terminal_2_completed, "Terminal 2 should still be completed");
4120
4121 // Verify the running terminal (created after checkpoint) was removed
4122 let terminal_3_exists =
4123 thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
4124 assert!(
4125 !terminal_3_exists,
4126 "Terminal 3 (created after checkpoint) should have been removed"
4127 );
4128
4129 // Verify total count is 2 (the two from before the checkpoint)
4130 let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
4131 assert_eq!(
4132 terminal_count, 2,
4133 "Should have exactly 2 terminals (the completed ones from before checkpoint)"
4134 );
4135 }
4136
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        // NOTE(review): the `.git` entry presumably gives the project a repository so that
        // git checkpoints can be taken for the user message — TODO confirm.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // Set to true as soon as the fake agent's prompt handler is invoked. The test uses
        // it to detect when the send has reached the agent.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Run the send on the background executor so the test can step it tick by tick
        // and interleave its own mutation of the thread.
        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        // The tick count is a heuristic. The exact ordering here is what exposes the race.
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message with no checkpoint while the send is still in flight.
        // This is the interleaving the regression described above covers.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        // Drain the remaining work. The send must still succeed despite the injected entry.
        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4199}