1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
41#[derive(Debug)]
42pub struct UserMessage {
43 pub id: Option<UserMessageId>,
44 pub content: ContentBlock,
45 pub chunks: Vec<acp::ContentBlock>,
46 pub checkpoint: Option<Checkpoint>,
47 pub indented: bool,
48}
49
50#[derive(Debug)]
51pub struct Checkpoint {
52 git_checkpoint: GitStoreCheckpoint,
53 pub show: bool,
54}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
75#[derive(Debug, PartialEq)]
76pub struct AssistantMessage {
77 pub chunks: Vec<AssistantMessageChunk>,
78 pub indented: bool,
79}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
93#[derive(Debug, PartialEq)]
94pub enum AssistantMessageChunk {
95 Message { block: ContentBlock },
96 Thought { block: ContentBlock },
97}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
121#[derive(Debug)]
122pub enum AgentThreadEntry {
123 UserMessage(UserMessage),
124 AssistantMessage(AssistantMessage),
125 ToolCall(ToolCall),
126}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
/// A tool invocation reported by the agent, with the state needed to render it.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier of the tool call.
    pub id: acp::ToolCallId,
    /// The tool call's title rendered as Markdown. `from_acp` and
    /// `update_fields` collapse multi-line titles to their first line plus `…`.
    pub label: Entity<Markdown>,
    /// The ACP tool kind.
    pub kind: acp::ToolKind,
    /// Output items (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, indexed in parallel with it by
    /// `AgentThreadEntry::location`; `None` where a location was not resolved.
    /// Starts empty in `from_acp` (presumably populated later from
    /// `resolve_locations` — confirm at the call site).
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input to the tool, if the agent provided it.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, produced by `markdown_for_raw_output`.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output of the tool, if the agent provided it.
    pub raw_output: Option<serde_json::Value>,
}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer.update(cx, |buffer, _| {
379 let snapshot = buffer.snapshot();
380 if let Some(row) = location.line {
381 let column = snapshot.indent_size_for_line(row).len;
382 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
383 snapshot.anchor_before(point)
384 } else {
385 Anchor::min_for_buffer(snapshot.remote_id())
386 }
387 });
388
389 Some(ResolvedLocation { buffer, position })
390 }
391
392 fn resolve_locations(
393 &self,
394 project: Entity<Project>,
395 cx: &mut App,
396 ) -> Task<Vec<Option<ResolvedLocation>>> {
397 let locations = self.locations.clone();
398 project.update(cx, |_, cx| {
399 cx.spawn(async move |project, cx| {
400 let mut new_locations = Vec::new();
401 for location in locations {
402 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
403 }
404 new_locations
405 })
406 })
407 }
408}
409
410// Separate so we can hold a strong reference to the buffer
411// for saving on the thread
412#[derive(Clone, Debug, PartialEq, Eq)]
413struct ResolvedLocation {
414 buffer: Entity<Buffer>,
415 position: Anchor,
416}
417
418impl From<&ResolvedLocation> for AgentLocation {
419 fn from(value: &ResolvedLocation) -> Self {
420 Self {
421 buffer: value.buffer.downgrade(),
422 position: value.position,
423 }
424 }
425}
426
427#[derive(Debug)]
428pub enum ToolCallStatus {
429 /// The tool call hasn't started running yet, but we start showing it to
430 /// the user.
431 Pending,
432 /// The tool call is waiting for confirmation from the user.
433 WaitingForConfirmation {
434 options: Vec<acp::PermissionOption>,
435 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
436 },
437 /// The tool call is currently running.
438 InProgress,
439 /// The tool call completed successfully.
440 Completed,
441 /// The tool call failed.
442 Failed,
443 /// The user rejected the tool call.
444 Rejected,
445 /// The user canceled generation so the tool call was canceled.
446 Canceled,
447}
448
449impl From<acp::ToolCallStatus> for ToolCallStatus {
450 fn from(status: acp::ToolCallStatus) -> Self {
451 match status {
452 acp::ToolCallStatus::Pending => Self::Pending,
453 acp::ToolCallStatus::InProgress => Self::InProgress,
454 acp::ToolCallStatus::Completed => Self::Completed,
455 acp::ToolCallStatus::Failed => Self::Failed,
456 _ => Self::Pending,
457 }
458 }
459}
460
461impl Display for ToolCallStatus {
462 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
463 write!(
464 f,
465 "{}",
466 match self {
467 ToolCallStatus::Pending => "Pending",
468 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
469 ToolCallStatus::InProgress => "In Progress",
470 ToolCallStatus::Completed => "Completed",
471 ToolCallStatus::Failed => "Failed",
472 ToolCallStatus::Rejected => "Rejected",
473 ToolCallStatus::Canceled => "Canceled",
474 }
475 )
476 }
477}
478
479#[derive(Debug, PartialEq, Clone)]
480pub enum ContentBlock {
481 Empty,
482 Markdown { markdown: Entity<Markdown> },
483 ResourceLink { resource_link: acp::ResourceLink },
484 Image { image: Arc<gpui::Image> },
485}
486
487impl ContentBlock {
488 pub fn new(
489 block: acp::ContentBlock,
490 language_registry: &Arc<LanguageRegistry>,
491 path_style: PathStyle,
492 cx: &mut App,
493 ) -> Self {
494 let mut this = Self::Empty;
495 this.append(block, language_registry, path_style, cx);
496 this
497 }
498
499 pub fn new_combined(
500 blocks: impl IntoIterator<Item = acp::ContentBlock>,
501 language_registry: Arc<LanguageRegistry>,
502 path_style: PathStyle,
503 cx: &mut App,
504 ) -> Self {
505 let mut this = Self::Empty;
506 for block in blocks {
507 this.append(block, &language_registry, path_style, cx);
508 }
509 this
510 }
511
512 pub fn append(
513 &mut self,
514 block: acp::ContentBlock,
515 language_registry: &Arc<LanguageRegistry>,
516 path_style: PathStyle,
517 cx: &mut App,
518 ) {
519 match (&mut *self, &block) {
520 (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
521 *self = ContentBlock::ResourceLink {
522 resource_link: resource_link.clone(),
523 };
524 }
525 (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
526 if let Some(image) = Self::decode_image(image_content) {
527 *self = ContentBlock::Image { image };
528 } else {
529 let new_content = Self::image_md(image_content);
530 *self = Self::create_markdown_block(new_content, language_registry, cx);
531 }
532 }
533 (ContentBlock::Empty, _) => {
534 let new_content = Self::block_string_contents(&block, path_style);
535 *self = Self::create_markdown_block(new_content, language_registry, cx);
536 }
537 (ContentBlock::Markdown { markdown }, _) => {
538 let new_content = Self::block_string_contents(&block, path_style);
539 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
540 }
541 (ContentBlock::ResourceLink { resource_link }, _) => {
542 let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
543 let new_content = Self::block_string_contents(&block, path_style);
544 let combined = format!("{}\n{}", existing_content, new_content);
545 *self = Self::create_markdown_block(combined, language_registry, cx);
546 }
547 (ContentBlock::Image { .. }, _) => {
548 let new_content = Self::block_string_contents(&block, path_style);
549 let combined = format!("`Image`\n{}", new_content);
550 *self = Self::create_markdown_block(combined, language_registry, cx);
551 }
552 }
553 }
554
555 fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
556 use base64::Engine as _;
557
558 let bytes = base64::engine::general_purpose::STANDARD
559 .decode(image_content.data.as_bytes())
560 .ok()?;
561 let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
562 Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
563 }
564
565 fn create_markdown_block(
566 content: String,
567 language_registry: &Arc<LanguageRegistry>,
568 cx: &mut App,
569 ) -> ContentBlock {
570 ContentBlock::Markdown {
571 markdown: cx
572 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
573 }
574 }
575
576 fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
577 match block {
578 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
579 acp::ContentBlock::ResourceLink(resource_link) => {
580 Self::resource_link_md(&resource_link.uri, path_style)
581 }
582 acp::ContentBlock::Resource(acp::EmbeddedResource {
583 resource:
584 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
585 uri,
586 ..
587 }),
588 ..
589 }) => Self::resource_link_md(uri, path_style),
590 acp::ContentBlock::Image(image) => Self::image_md(image),
591 _ => String::new(),
592 }
593 }
594
595 fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
596 if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
597 uri.as_link().to_string()
598 } else {
599 uri.to_string()
600 }
601 }
602
603 fn image_md(_image: &acp::ImageContent) -> String {
604 "`Image`".into()
605 }
606
607 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
608 match self {
609 ContentBlock::Empty => "",
610 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
611 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
612 ContentBlock::Image { .. } => "`Image`",
613 }
614 }
615
616 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
617 match self {
618 ContentBlock::Empty => None,
619 ContentBlock::Markdown { markdown } => Some(markdown),
620 ContentBlock::ResourceLink { .. } => None,
621 ContentBlock::Image { .. } => None,
622 }
623 }
624
625 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
626 match self {
627 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
628 _ => None,
629 }
630 }
631
632 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
633 match self {
634 ContentBlock::Image { image } => Some(image),
635 _ => None,
636 }
637 }
638}
639
640#[derive(Debug)]
641pub enum ToolCallContent {
642 ContentBlock(ContentBlock),
643 Diff(Entity<Diff>),
644 Terminal(Entity<Terminal>),
645}
646
647impl ToolCallContent {
648 pub fn from_acp(
649 content: acp::ToolCallContent,
650 language_registry: Arc<LanguageRegistry>,
651 path_style: PathStyle,
652 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
653 cx: &mut App,
654 ) -> Result<Option<Self>> {
655 match content {
656 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
657 Ok(Some(Self::ContentBlock(ContentBlock::new(
658 content,
659 &language_registry,
660 path_style,
661 cx,
662 ))))
663 }
664 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
665 Diff::finalized(
666 diff.path.to_string_lossy().into_owned(),
667 diff.old_text,
668 diff.new_text,
669 language_registry,
670 cx,
671 )
672 })))),
673 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
674 .get(&terminal_id)
675 .cloned()
676 .map(|terminal| Some(Self::Terminal(terminal)))
677 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
678 _ => Ok(None),
679 }
680 }
681
682 pub fn update_from_acp(
683 &mut self,
684 new: acp::ToolCallContent,
685 language_registry: Arc<LanguageRegistry>,
686 path_style: PathStyle,
687 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
688 cx: &mut App,
689 ) -> Result<bool> {
690 let needs_update = match (&self, &new) {
691 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
692 old_diff.read(cx).needs_update(
693 new_diff.old_text.as_deref().unwrap_or(""),
694 &new_diff.new_text,
695 cx,
696 )
697 }
698 _ => true,
699 };
700
701 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
702 if needs_update {
703 *self = update;
704 }
705 Ok(true)
706 } else {
707 Ok(false)
708 }
709 }
710
711 pub fn to_markdown(&self, cx: &App) -> String {
712 match self {
713 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
714 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
715 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
716 }
717 }
718
719 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
720 match self {
721 Self::ContentBlock(content) => content.image(),
722 _ => None,
723 }
724 }
725}
726
727#[derive(Debug, PartialEq)]
728pub enum ToolCallUpdate {
729 UpdateFields(acp::ToolCallUpdate),
730 UpdateDiff(ToolCallUpdateDiff),
731 UpdateTerminal(ToolCallUpdateTerminal),
732}
733
734impl ToolCallUpdate {
735 fn id(&self) -> &acp::ToolCallId {
736 match self {
737 Self::UpdateFields(update) => &update.tool_call_id,
738 Self::UpdateDiff(diff) => &diff.id,
739 Self::UpdateTerminal(terminal) => &terminal.id,
740 }
741 }
742}
743
744impl From<acp::ToolCallUpdate> for ToolCallUpdate {
745 fn from(update: acp::ToolCallUpdate) -> Self {
746 Self::UpdateFields(update)
747 }
748}
749
750impl From<ToolCallUpdateDiff> for ToolCallUpdate {
751 fn from(diff: ToolCallUpdateDiff) -> Self {
752 Self::UpdateDiff(diff)
753 }
754}
755
756#[derive(Debug, PartialEq)]
757pub struct ToolCallUpdateDiff {
758 pub id: acp::ToolCallId,
759 pub diff: Entity<Diff>,
760}
761
762impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
763 fn from(terminal: ToolCallUpdateTerminal) -> Self {
764 Self::UpdateTerminal(terminal)
765 }
766}
767
768#[derive(Debug, PartialEq)]
769pub struct ToolCallUpdateTerminal {
770 pub id: acp::ToolCallId,
771 pub terminal: Entity<Terminal>,
772}
773
774#[derive(Debug, Default)]
775pub struct Plan {
776 pub entries: Vec<PlanEntry>,
777}
778
779#[derive(Debug)]
780pub struct PlanStats<'a> {
781 pub in_progress_entry: Option<&'a PlanEntry>,
782 pub pending: u32,
783 pub completed: u32,
784}
785
786impl Plan {
787 pub fn is_empty(&self) -> bool {
788 self.entries.is_empty()
789 }
790
791 pub fn stats(&self) -> PlanStats<'_> {
792 let mut stats = PlanStats {
793 in_progress_entry: None,
794 pending: 0,
795 completed: 0,
796 };
797
798 for entry in &self.entries {
799 match &entry.status {
800 acp::PlanEntryStatus::Pending => {
801 stats.pending += 1;
802 }
803 acp::PlanEntryStatus::InProgress => {
804 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
805 }
806 acp::PlanEntryStatus::Completed => {
807 stats.completed += 1;
808 }
809 _ => {}
810 }
811 }
812
813 stats
814 }
815}
816
817#[derive(Debug)]
818pub struct PlanEntry {
819 pub content: Entity<Markdown>,
820 pub priority: acp::PlanEntryPriority,
821 pub status: acp::PlanEntryStatus,
822}
823
824impl PlanEntry {
825 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
826 Self {
827 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
828 priority: entry.priority,
829 status: entry.status,
830 }
831 }
832}
833
834#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
835pub struct TokenUsage {
836 pub max_tokens: u64,
837 pub used_tokens: u64,
838}
839
840impl TokenUsage {
841 pub fn ratio(&self) -> TokenUsageRatio {
842 #[cfg(debug_assertions)]
843 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
844 .unwrap_or("0.8".to_string())
845 .parse()
846 .unwrap();
847 #[cfg(not(debug_assertions))]
848 let warning_threshold: f32 = 0.8;
849
850 // When the maximum is unknown because there is no selected model,
851 // avoid showing the token limit warning.
852 if self.max_tokens == 0 {
853 TokenUsageRatio::Normal
854 } else if self.used_tokens >= self.max_tokens {
855 TokenUsageRatio::Exceeded
856 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
857 TokenUsageRatio::Warning
858 } else {
859 TokenUsageRatio::Normal
860 }
861 }
862}
863
864#[derive(Debug, Clone, PartialEq, Eq)]
865pub enum TokenUsageRatio {
866 Normal,
867 Warning,
868 Exceeded,
869}
870
/// Information about a retry of a failed request, carried by
/// `AcpThreadEvent::Retry`.
///
/// NOTE(review): field meanings below are inferred from names; confirm them
/// against the code that constructs this value.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number (check whether it is 0- or 1-based).
    pub attempt: usize,
    /// Presumably the maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// Presumably when this retry started.
    pub started_at: Instant,
    /// Presumably the delay before the retry — confirm.
    pub duration: Duration,
}
879
/// A conversation with an agent over an [`AgentConnection`]: the transcript
/// entries, the current plan, token usage, and the terminals the agent
/// created.
pub struct AcpThread {
    /// Display title of the thread.
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Buffer snapshots keyed by buffer (presumably the state last shared with
    /// the agent — confirm at the use sites).
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task. `Some` means the thread reports
    /// `ThreadStatus::Generating`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest token usage, if reported.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept current by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies prompt-capability updates into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals registered with this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals not yet registered; replayed on creation.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals not yet registered.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
897
898impl From<&AcpThread> for ActionLogTelemetry {
899 fn from(value: &AcpThread) -> Self {
900 Self {
901 agent_telemetry_id: value.connection().telemetry_id(),
902 session_id: value.session_id.0.clone(),
903 }
904 }
905}
906
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was added (presumably appended to `entries`).
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is refreshed from the watch channel.
    PromptCapabilitiesUpdated,
    Refusal,
    /// Emitted for `SessionUpdate::AvailableCommandsUpdate`.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Emitted for `SessionUpdate::CurrentModeUpdate`.
    ModeUpdated(acp::SessionModeId),
    /// Emitted for `SessionUpdate::ConfigOptionUpdate`.
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
927
/// Events from a terminal provider, applied by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created. The thread registers it and replays any output
    /// that was buffered for `terminal_id`.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes. Buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// New breadcrumb title. Ignored for terminals that are not registered.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited. A registered terminal is only notified
    /// to re-render; for an unregistered one, `status` is buffered.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
950
/// Commands addressed to a terminal provider.
///
/// NOTE(review): not used in the visible code; presumably the counterpart of
/// `TerminalProviderEvent` sent to the provider — confirm.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Write input bytes to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to `cols` x `rows`.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
966
967impl AcpThread {
968 pub fn on_terminal_provider_event(
969 &mut self,
970 event: TerminalProviderEvent,
971 cx: &mut Context<Self>,
972 ) {
973 match event {
974 TerminalProviderEvent::Created {
975 terminal_id,
976 label,
977 cwd,
978 output_byte_limit,
979 terminal,
980 } => {
981 let entity = self.register_terminal_created(
982 terminal_id.clone(),
983 label,
984 cwd,
985 output_byte_limit,
986 terminal,
987 cx,
988 );
989
990 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
991 for data in chunks.drain(..) {
992 entity.update(cx, |term, cx| {
993 term.inner().update(cx, |inner, cx| {
994 inner.write_output(&data, cx);
995 })
996 });
997 }
998 }
999
1000 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1001 entity.update(cx, |_term, cx| {
1002 cx.notify();
1003 });
1004 }
1005
1006 cx.notify();
1007 }
1008 TerminalProviderEvent::Output { terminal_id, data } => {
1009 if let Some(entity) = self.terminals.get(&terminal_id) {
1010 entity.update(cx, |term, cx| {
1011 term.inner().update(cx, |inner, cx| {
1012 inner.write_output(&data, cx);
1013 })
1014 });
1015 } else {
1016 self.pending_terminal_output
1017 .entry(terminal_id)
1018 .or_default()
1019 .push(data);
1020 }
1021 }
1022 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1023 if let Some(entity) = self.terminals.get(&terminal_id) {
1024 entity.update(cx, |term, cx| {
1025 term.inner().update(cx, |inner, cx| {
1026 inner.breadcrumb_text = title;
1027 cx.emit(::terminal::Event::BreadcrumbsChanged);
1028 })
1029 });
1030 }
1031 }
1032 TerminalProviderEvent::Exit {
1033 terminal_id,
1034 status,
1035 } => {
1036 if let Some(entity) = self.terminals.get(&terminal_id) {
1037 entity.update(cx, |_term, cx| {
1038 cx.notify();
1039 });
1040 } else {
1041 self.pending_terminal_exit.insert(terminal_id, status);
1042 }
1043 }
1044 }
1045 }
1046}
1047
1048#[derive(PartialEq, Eq, Debug)]
1049pub enum ThreadStatus {
1050 Idle,
1051 Generating,
1052}
1053
1054#[derive(Debug, Clone)]
1055pub enum LoadError {
1056 Unsupported {
1057 command: SharedString,
1058 current_version: SharedString,
1059 minimum_version: SharedString,
1060 },
1061 FailedToInstall(SharedString),
1062 Exited {
1063 status: ExitStatus,
1064 },
1065 Other(SharedString),
1066}
1067
1068impl Display for LoadError {
1069 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1070 match self {
1071 LoadError::Unsupported {
1072 command: path,
1073 current_version,
1074 minimum_version,
1075 } => {
1076 write!(
1077 f,
1078 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1079 )
1080 }
1081 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1082 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1083 LoadError::Other(msg) => write!(f, "{msg}"),
1084 }
1085 }
1086}
1087
1088impl Error for LoadError {}
1089
1090impl AcpThread {
1091 pub fn new(
1092 title: impl Into<SharedString>,
1093 connection: Rc<dyn AgentConnection>,
1094 project: Entity<Project>,
1095 action_log: Entity<ActionLog>,
1096 session_id: acp::SessionId,
1097 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1098 cx: &mut Context<Self>,
1099 ) -> Self {
1100 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1101 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1102 loop {
1103 let caps = prompt_capabilities_rx.recv().await?;
1104 this.update(cx, |this, cx| {
1105 this.prompt_capabilities = caps;
1106 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1107 })?;
1108 }
1109 });
1110
1111 Self {
1112 action_log,
1113 shared_buffers: Default::default(),
1114 entries: Default::default(),
1115 plan: Default::default(),
1116 title: title.into(),
1117 project,
1118 send_task: None,
1119 connection,
1120 session_id,
1121 token_usage: None,
1122 prompt_capabilities,
1123 _observe_prompt_capabilities: task,
1124 terminals: HashMap::default(),
1125 pending_terminal_output: HashMap::default(),
1126 pending_terminal_exit: HashMap::default(),
1127 }
1128 }
1129
1130 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1131 self.prompt_capabilities.clone()
1132 }
1133
1134 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1135 &self.connection
1136 }
1137
1138 pub fn action_log(&self) -> &Entity<ActionLog> {
1139 &self.action_log
1140 }
1141
1142 pub fn project(&self) -> &Entity<Project> {
1143 &self.project
1144 }
1145
1146 pub fn title(&self) -> SharedString {
1147 self.title.clone()
1148 }
1149
1150 pub fn entries(&self) -> &[AgentThreadEntry] {
1151 &self.entries
1152 }
1153
1154 pub fn session_id(&self) -> &acp::SessionId {
1155 &self.session_id
1156 }
1157
1158 pub fn status(&self) -> ThreadStatus {
1159 if self.send_task.is_some() {
1160 ThreadStatus::Generating
1161 } else {
1162 ThreadStatus::Idle
1163 }
1164 }
1165
1166 pub fn token_usage(&self) -> Option<&TokenUsage> {
1167 self.token_usage.as_ref()
1168 }
1169
1170 pub fn has_pending_edit_tool_calls(&self) -> bool {
1171 for entry in self.entries.iter().rev() {
1172 match entry {
1173 AgentThreadEntry::UserMessage(_) => return false,
1174 AgentThreadEntry::ToolCall(
1175 call @ ToolCall {
1176 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1177 ..
1178 },
1179 ) if call.diffs().next().is_some() => {
1180 return true;
1181 }
1182 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1183 }
1184 }
1185
1186 false
1187 }
1188
1189 pub fn used_tools_since_last_user_message(&self) -> bool {
1190 for entry in self.entries.iter().rev() {
1191 match entry {
1192 AgentThreadEntry::UserMessage(..) => return false,
1193 AgentThreadEntry::AssistantMessage(..) => continue,
1194 AgentThreadEntry::ToolCall(..) => return true,
1195 }
1196 }
1197
1198 false
1199 }
1200
1201 pub fn handle_session_update(
1202 &mut self,
1203 update: acp::SessionUpdate,
1204 cx: &mut Context<Self>,
1205 ) -> Result<(), acp::Error> {
1206 match update {
1207 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1208 self.push_user_content_block(None, content, cx);
1209 }
1210 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1211 self.push_assistant_content_block(content, false, cx);
1212 }
1213 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1214 self.push_assistant_content_block(content, true, cx);
1215 }
1216 acp::SessionUpdate::ToolCall(tool_call) => {
1217 self.upsert_tool_call(tool_call, cx)?;
1218 }
1219 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1220 self.update_tool_call(tool_call_update, cx)?;
1221 }
1222 acp::SessionUpdate::Plan(plan) => {
1223 self.update_plan(plan, cx);
1224 }
1225 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1226 available_commands,
1227 ..
1228 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1229 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1230 current_mode_id,
1231 ..
1232 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1233 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1234 config_options,
1235 ..
1236 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1237 _ => {}
1238 }
1239 Ok(())
1240 }
1241
    /// Appends a user content chunk at the top level (not indented).
    ///
    /// Convenience wrapper around `push_user_content_block_with_indent` with
    /// `indented = false`.
    pub fn push_user_content_block(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        cx: &mut Context<Self>,
    ) {
        self.push_user_content_block_with_indent(message_id, chunk, false, cx)
    }
1250
    /// Appends a user content chunk to the transcript.
    ///
    /// If the last entry is a user message with the same `indented` flag, the
    /// chunk is merged into that message and `EntryUpdated` is emitted.
    /// Otherwise a new user message entry is pushed, which emits `NewEntry`.
    ///
    /// When merging, a provided `message_id` replaces the stored id; `None`
    /// keeps the existing one.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        // Read what we need from the project before mutably borrowing `self.entries`.
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // Prefer the newly supplied id, falling back to the one already stored.
            *id = message_id.or(id.take());
            // `content` accumulates the chunk into one ContentBlock, while
            // `chunks` keeps every raw ACP block as received.
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1291
    /// Appends an assistant message or thought chunk at the top level (not indented).
    ///
    /// Convenience wrapper around `push_assistant_content_block_with_indent`
    /// with `indented = false`.
    pub fn push_assistant_content_block(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        cx: &mut Context<Self>,
    ) {
        self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
    }
1300
1301 pub fn push_assistant_content_block_with_indent(
1302 &mut self,
1303 chunk: acp::ContentBlock,
1304 is_thought: bool,
1305 indented: bool,
1306 cx: &mut Context<Self>,
1307 ) {
1308 let language_registry = self.project.read(cx).languages().clone();
1309 let path_style = self.project.read(cx).path_style(cx);
1310 let entries_len = self.entries.len();
1311 if let Some(last_entry) = self.entries.last_mut()
1312 && let AgentThreadEntry::AssistantMessage(AssistantMessage {
1313 chunks,
1314 indented: existing_indented,
1315 }) = last_entry
1316 && *existing_indented == indented
1317 {
1318 let idx = entries_len - 1;
1319 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1320 match (chunks.last_mut(), is_thought) {
1321 (Some(AssistantMessageChunk::Message { block }), false)
1322 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1323 block.append(chunk, &language_registry, path_style, cx)
1324 }
1325 _ => {
1326 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1327 if is_thought {
1328 chunks.push(AssistantMessageChunk::Thought { block })
1329 } else {
1330 chunks.push(AssistantMessageChunk::Message { block })
1331 }
1332 }
1333 }
1334 } else {
1335 let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
1336 let chunk = if is_thought {
1337 AssistantMessageChunk::Thought { block }
1338 } else {
1339 AssistantMessageChunk::Message { block }
1340 };
1341
1342 self.push_entry(
1343 AgentThreadEntry::AssistantMessage(AssistantMessage {
1344 chunks: vec![chunk],
1345 indented,
1346 }),
1347 cx,
1348 );
1349 }
1350 }
1351
    /// Appends `entry` to the transcript and emits `AcpThreadEvent::NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
1356
    /// Reports whether the connection supports renaming this session.
    ///
    /// When it does, `set_title` also forwards the new title to the agent.
    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
        self.connection.set_title(&self.session_id, cx).is_some()
    }
1360
1361 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1362 if title != self.title {
1363 self.title = title.clone();
1364 cx.emit(AcpThreadEvent::TitleUpdated);
1365 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1366 return set_title.run(title, cx);
1367 }
1368 }
1369 Task::ready(Ok(()))
1370 }
1371
    /// Replaces the stored token usage (`None` clears it) and emits `TokenUsageUpdated`.
    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
        self.token_usage = usage;
        cx.emit(AcpThreadEvent::TokenUsageUpdated);
    }
1376
    /// Forwards a retry status to listeners as `AcpThreadEvent::Retry`.
    ///
    /// The thread itself does not store the status.
    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::Retry(status));
    }
1380
    /// Applies a `ToolCallUpdate` to the existing tool call with the same id,
    /// then emits `EntryUpdated` for that entry.
    ///
    /// - `UpdateFields` merges the new fields into the call and re-resolves
    ///   the call's locations if the update included locations.
    /// - `UpdateDiff` / `UpdateTerminal` replace the call's whole content with
    ///   the single diff / terminal carried by the update.
    ///
    /// If no tool call with that id exists, a placeholder entry with status
    /// `Failed` and the text "Tool call not found" is pushed instead, and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    /// Returns an error if merging the fields into the call fails.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1442
    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
    ///
    /// The status is taken from the incoming `acp::ToolCall`; see
    /// `upsert_tool_call_inner` for the full behavior and error cases.
    pub fn upsert_tool_call(
        &mut self,
        tool_call: acp::ToolCall,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        let status = tool_call.status.into();
        self.upsert_tool_call_inner(tool_call.into(), status, cx)
    }
1452
1453 /// Fails if id does not match an existing entry.
1454 pub fn upsert_tool_call_inner(
1455 &mut self,
1456 update: acp::ToolCallUpdate,
1457 status: ToolCallStatus,
1458 cx: &mut Context<Self>,
1459 ) -> Result<(), acp::Error> {
1460 let language_registry = self.project.read(cx).languages().clone();
1461 let path_style = self.project.read(cx).path_style(cx);
1462 let id = update.tool_call_id.clone();
1463
1464 let agent_telemetry_id = self.connection().telemetry_id();
1465 let session = self.session_id();
1466 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1467 let status = if matches!(status, ToolCallStatus::Completed) {
1468 "completed"
1469 } else {
1470 "failed"
1471 };
1472 telemetry::event!(
1473 "Agent Tool Call Completed",
1474 agent_telemetry_id,
1475 session,
1476 status
1477 );
1478 }
1479
1480 if let Some(ix) = self.index_for_tool_call(&id) {
1481 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1482 unreachable!()
1483 };
1484
1485 call.update_fields(
1486 update.fields,
1487 language_registry,
1488 path_style,
1489 &self.terminals,
1490 cx,
1491 )?;
1492 call.status = status;
1493
1494 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1495 } else {
1496 let call = ToolCall::from_acp(
1497 update.try_into()?,
1498 status,
1499 language_registry,
1500 self.project.read(cx).path_style(cx),
1501 &self.terminals,
1502 cx,
1503 )?;
1504 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1505 };
1506
1507 self.resolve_locations(id, cx);
1508 Ok(())
1509 }
1510
1511 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1512 self.entries
1513 .iter()
1514 .enumerate()
1515 .rev()
1516 .find_map(|(index, entry)| {
1517 if let AgentThreadEntry::ToolCall(tool_call) = entry
1518 && &tool_call.id == id
1519 {
1520 Some(index)
1521 } else {
1522 None
1523 }
1524 })
1525 }
1526
1527 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1528 // The tool call we are looking for is typically the last one, or very close to the end.
1529 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1530 self.entries
1531 .iter_mut()
1532 .enumerate()
1533 .rev()
1534 .find_map(|(index, tool_call)| {
1535 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1536 && &tool_call.id == id
1537 {
1538 Some((index, tool_call))
1539 } else {
1540 None
1541 }
1542 })
1543 }
1544
1545 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1546 self.entries
1547 .iter()
1548 .enumerate()
1549 .rev()
1550 .find_map(|(index, tool_call)| {
1551 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1552 && &tool_call.id == id
1553 {
1554 Some((index, tool_call))
1555 } else {
1556 None
1557 }
1558 })
1559 }
1560
    /// Resolves the locations of tool call `id` to buffer positions on a
    /// detached background task.
    ///
    /// Once resolution finishes, the task:
    /// - stores a snapshot of every resolved buffer in `shared_buffers`;
    /// - moves the project's agent location to the last resolved location,
    ///   unless that would only move it backwards within the same row;
    /// - stores the resolved locations on the tool call, emitting
    ///   `EntryUpdated` only if they changed.
    ///
    /// Does nothing if no tool call with `id` exists when this is called.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call may have been removed while we were resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of the line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1617
1618 pub fn request_tool_call_authorization(
1619 &mut self,
1620 tool_call: acp::ToolCallUpdate,
1621 options: Vec<acp::PermissionOption>,
1622 respect_always_allow_setting: bool,
1623 cx: &mut Context<Self>,
1624 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1625 let (tx, rx) = oneshot::channel();
1626
1627 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1628 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1629 // some tools would (incorrectly) continue to auto-accept.
1630 if let Some(allow_once_option) = options.iter().find_map(|option| {
1631 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1632 Some(option.option_id.clone())
1633 } else {
1634 None
1635 }
1636 }) {
1637 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1638 return Ok(async {
1639 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1640 allow_once_option,
1641 ))
1642 }
1643 .boxed());
1644 }
1645 }
1646
1647 let status = ToolCallStatus::WaitingForConfirmation {
1648 options,
1649 respond_tx: tx,
1650 };
1651
1652 self.upsert_tool_call_inner(tool_call, status, cx)?;
1653 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1654
1655 let fut = async {
1656 match rx.await {
1657 Ok(option) => acp::RequestPermissionOutcome::Selected(
1658 acp::SelectedPermissionOutcome::new(option),
1659 ),
1660 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1661 }
1662 }
1663 .boxed();
1664
1665 Ok(fut)
1666 }
1667
1668 pub fn authorize_tool_call(
1669 &mut self,
1670 id: acp::ToolCallId,
1671 option_id: acp::PermissionOptionId,
1672 option_kind: acp::PermissionOptionKind,
1673 cx: &mut Context<Self>,
1674 ) {
1675 let Some((ix, call)) = self.tool_call_mut(&id) else {
1676 return;
1677 };
1678
1679 let new_status = match option_kind {
1680 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1681 ToolCallStatus::Rejected
1682 }
1683 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1684 ToolCallStatus::InProgress
1685 }
1686 _ => ToolCallStatus::InProgress,
1687 };
1688
1689 let curr_status = mem::replace(&mut call.status, new_status);
1690
1691 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1692 respond_tx.send(option_id).log_err();
1693 } else if cfg!(debug_assertions) {
1694 panic!("tried to authorize an already authorized tool call");
1695 }
1696
1697 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1698 }
1699
1700 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1701 let mut first_tool_call = None;
1702
1703 for entry in self.entries.iter().rev() {
1704 match &entry {
1705 AgentThreadEntry::ToolCall(call) => {
1706 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1707 first_tool_call = Some(call);
1708 } else {
1709 continue;
1710 }
1711 }
1712 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1713 // Reached the beginning of the turn.
1714 // If we had pending permission requests in the previous turn, they have been cancelled.
1715 break;
1716 }
1717 }
1718 }
1719
1720 first_tool_call
1721 }
1722
    /// Returns the agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1726
1727 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1728 let new_entries_len = request.entries.len();
1729 let mut new_entries = request.entries.into_iter();
1730
1731 // Reuse existing markdown to prevent flickering
1732 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1733 let PlanEntry {
1734 content,
1735 priority,
1736 status,
1737 } = old;
1738 content.update(cx, |old, cx| {
1739 old.replace(new.content, cx);
1740 });
1741 *priority = new.priority;
1742 *status = new.status;
1743 }
1744 for new in new_entries {
1745 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1746 }
1747 self.plan.entries.truncate(new_entries_len);
1748
1749 cx.notify();
1750 }
1751
    /// Drops every plan entry whose status is `Completed` and notifies observers.
    ///
    /// Called at the start of each turn (see `run_turn`).
    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
        self.plan
            .entries
            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
        cx.notify();
    }
1758
    /// Test helper: sends a single plain-text message as a user turn via `send`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn send_raw(
        &mut self,
        message: &str,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        self.send(vec![message.into()], cx)
    }
1767
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The turn runs in this order:
    /// 1. The message is pushed as a user entry.
    /// 2. A git checkpoint of the working tree is taken and attached to that
    ///    entry (hidden for now).
    /// 3. The prompt is forwarded to the connection.
    ///
    /// A `UserMessageId` is only generated when the connection supports
    /// truncation, which `rewind` relies on to find the message again.
    ///
    /// The returned future resolves when the turn finishes (see `run_turn`).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // A failure to take the checkpoint is logged and treated as "no checkpoint".
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // Hidden until `update_last_checkpoint` detects the turn changed the tree.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1820
    /// Reports whether the connection supports resuming this session.
    pub fn can_resume(&self, cx: &App) -> bool {
        self.connection.resume(&self.session_id, cx).is_some()
    }
1824
    /// Starts a new turn that asks the agent to resume the session without a new prompt.
    ///
    /// The turn fails with "resuming a session is not supported" if the
    /// connection has no resume capability.
    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .resume(&this.session_id, cx)
                    .map(|resume| resume.run(cx))
            })?
            .context("resuming a session is not supported")?
            .await
        })
    }
1836
    /// Runs one agent turn and returns a future that resolves when it is over.
    ///
    /// Before the turn: completed plan entries are cleared and any running turn
    /// is cancelled. A new `send_task` then waits for that cancellation before
    /// running `f`.
    ///
    /// Once `f` finishes, the last user message's checkpoint visibility is
    /// refreshed and the agent location is cleared. Then:
    /// - if `f` returned an error, `send_task` is cleared, `Error` is emitted,
    ///   and the error is returned;
    /// - otherwise `Stopped` is emitted and `Ok(())` is returned. If the agent
    ///   refused, `Refusal` is emitted first; when the refusal is aimed at the
    ///   user's prompt rather than at tool output, that prompt and everything
    ///   after it are removed.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn runs inside `send_task` so that `cancel` can take and await it;
        // its result is forwarded through `tx`.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the send task was dropped before sending a result.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    // NOTE(review): `user_msg_ix` indexes an existing entry, so this
                                    // range is never empty. Unlike `rewind`, terminals owned by the
                                    // removed entries are not killed here — confirm that is intended.
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1932
1933 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1934 let Some(send_task) = self.send_task.take() else {
1935 return Task::ready(());
1936 };
1937
1938 for entry in self.entries.iter_mut() {
1939 if let AgentThreadEntry::ToolCall(call) = entry {
1940 let cancel = matches!(
1941 call.status,
1942 ToolCallStatus::Pending
1943 | ToolCallStatus::WaitingForConfirmation { .. }
1944 | ToolCallStatus::InProgress
1945 );
1946
1947 if cancel {
1948 call.status = ToolCallStatus::Canceled;
1949 }
1950 }
1951 }
1952
1953 self.connection.cancel(&self.session_id, cx);
1954
1955 // Wait for the send task to complete
1956 cx.foreground_executor().spawn(send_task)
1957 }
1958
    /// Restores the git working tree to the state at the given checkpoint (if one exists)
    ///
    /// The steps are:
    /// 1. Cancel any running turn.
    /// 2. Rewind the thread to before the user message `id` (see `rewind`).
    /// 3. If that message recorded a git checkpoint, restore the working tree to it.
    ///
    /// # Errors
    /// Fails immediately with "message not found" if no user message has this
    /// id; otherwise propagates errors from the rewind or the git restore.
    pub fn restore_checkpoint(
        &mut self,
        id: UserMessageId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some((_, message)) = self.user_message_mut(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        // Cancel any in-progress generation before restoring
        let cancel_task = self.cancel(cx);
        let rewind = self.rewind(id.clone(), cx);
        let git_store = self.project.read(cx).git_store().clone();

        cx.spawn(async move |_, cx| {
            cancel_task.await;
            rewind.await?;
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))
                    .await?;
            }

            Ok(())
        })
    }
1991
1992 /// Rewinds this thread to before the entry at `index`, removing it and all
1993 /// subsequent entries while rejecting any action_log changes made from that point.
1994 /// Unlike `restore_checkpoint`, this method does not restore from git.
1995 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1996 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1997 return Task::ready(Err(anyhow!("not supported")));
1998 };
1999
2000 let telemetry = ActionLogTelemetry::from(&*self);
2001 cx.spawn(async move |this, cx| {
2002 cx.update(|cx| truncate.run(id.clone(), cx)).await?;
2003 this.update(cx, |this, cx| {
2004 if let Some((ix, _)) = this.user_message_mut(&id) {
2005 // Collect all terminals from entries that will be removed
2006 let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
2007 .iter()
2008 .flat_map(|entry| entry.terminals())
2009 .filter_map(|terminal| terminal.read(cx).id().clone().into())
2010 .collect();
2011
2012 let range = ix..this.entries.len();
2013 this.entries.truncate(ix);
2014 cx.emit(AcpThreadEvent::EntriesRemoved(range));
2015
2016 // Kill and remove the terminals
2017 for terminal_id in terminals_to_remove {
2018 if let Some(terminal) = this.terminals.remove(&terminal_id) {
2019 terminal.update(cx, |terminal, cx| {
2020 terminal.kill(cx);
2021 });
2022 }
2023 }
2024 }
2025 this.action_log().update(cx, |action_log, cx| {
2026 action_log.reject_all_edits(Some(telemetry), cx)
2027 })
2028 })?
2029 .await;
2030 Ok(())
2031 })
2032 }
2033
2034 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
2035 let git_store = self.project.read(cx).git_store().clone();
2036
2037 let Some((_, message)) = self.last_user_message() else {
2038 return Task::ready(Ok(()));
2039 };
2040 let Some(user_message_id) = message.id.clone() else {
2041 return Task::ready(Ok(()));
2042 };
2043 let Some(checkpoint) = message.checkpoint.as_ref() else {
2044 return Task::ready(Ok(()));
2045 };
2046 let old_checkpoint = checkpoint.git_checkpoint.clone();
2047
2048 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
2049 cx.spawn(async move |this, cx| {
2050 let Some(new_checkpoint) = new_checkpoint
2051 .await
2052 .context("failed to get new checkpoint")
2053 .log_err()
2054 else {
2055 return Ok(());
2056 };
2057
2058 let equal = git_store
2059 .update(cx, |git, cx| {
2060 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
2061 })
2062 .await
2063 .unwrap_or(true);
2064
2065 this.update(cx, |this, cx| {
2066 if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
2067 if let Some(checkpoint) = message.checkpoint.as_mut() {
2068 checkpoint.show = !equal;
2069 cx.emit(AcpThreadEvent::EntryUpdated(ix));
2070 }
2071 }
2072 })?;
2073
2074 Ok(())
2075 })
2076 }
2077
2078 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2079 self.entries
2080 .iter_mut()
2081 .enumerate()
2082 .rev()
2083 .find_map(|(ix, entry)| {
2084 if let AgentThreadEntry::UserMessage(message) = entry {
2085 Some((ix, message))
2086 } else {
2087 None
2088 }
2089 })
2090 }
2091
2092 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2093 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2094 if let AgentThreadEntry::UserMessage(message) = entry {
2095 if message.id.as_ref() == Some(id) {
2096 Some((ix, message))
2097 } else {
2098 None
2099 }
2100 } else {
2101 None
2102 }
2103 })
2104 }
2105
2106 pub fn read_text_file(
2107 &self,
2108 path: PathBuf,
2109 line: Option<u32>,
2110 limit: Option<u32>,
2111 reuse_shared_snapshot: bool,
2112 cx: &mut Context<Self>,
2113 ) -> Task<Result<String, acp::Error>> {
2114 // Args are 1-based, move to 0-based
2115 let line = line.unwrap_or_default().saturating_sub(1);
2116 let limit = limit.unwrap_or(u32::MAX);
2117 let project = self.project.clone();
2118 let action_log = self.action_log.clone();
2119 cx.spawn(async move |this, cx| {
2120 let load = project.update(cx, |project, cx| {
2121 let path = project
2122 .project_path_for_absolute_path(&path, cx)
2123 .ok_or_else(|| {
2124 acp::Error::resource_not_found(Some(path.display().to_string()))
2125 })?;
2126 Ok::<_, acp::Error>(project.open_buffer(path, cx))
2127 })?;
2128
2129 let buffer = load.await?;
2130
2131 let snapshot = if reuse_shared_snapshot {
2132 this.read_with(cx, |this, _| {
2133 this.shared_buffers.get(&buffer.clone()).cloned()
2134 })
2135 .log_err()
2136 .flatten()
2137 } else {
2138 None
2139 };
2140
2141 let snapshot = if let Some(snapshot) = snapshot {
2142 snapshot
2143 } else {
2144 action_log.update(cx, |action_log, cx| {
2145 action_log.buffer_read(buffer.clone(), cx);
2146 });
2147
2148 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
2149 this.update(cx, |this, _| {
2150 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
2151 })?;
2152 snapshot
2153 };
2154
2155 let max_point = snapshot.max_point();
2156 let start_position = Point::new(line, 0);
2157
2158 if start_position > max_point {
2159 return Err(acp::Error::invalid_params().data(format!(
2160 "Attempting to read beyond the end of the file, line {}:{}",
2161 max_point.row + 1,
2162 max_point.column
2163 )));
2164 }
2165
2166 let start = snapshot.anchor_before(start_position);
2167 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
2168
2169 project.update(cx, |project, cx| {
2170 project.set_agent_location(
2171 Some(AgentLocation {
2172 buffer: buffer.downgrade(),
2173 position: start,
2174 }),
2175 cx,
2176 );
2177 });
2178
2179 Ok(snapshot.text_for_range(start..end).collect::<String>())
2180 })
2181 }
2182
    /// Replaces the full contents of the file at `path` with `content` and saves it.
    ///
    /// The new text is diffed against the snapshot previously shared with the
    /// agent (or the current buffer if none was shared), so only the changed
    /// ranges are edited. Then:
    /// - the edit is recorded in the action log;
    /// - the agent location moves to the end of the last edit;
    /// - the buffer is formatted if its language settings enable format-on-save;
    /// - the buffer is saved.
    ///
    /// # Errors
    /// Fails if `path` cannot be mapped to a project path ("invalid path"), if
    /// the buffer cannot be opened, or if saving fails. Formatting errors are
    /// only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load?.await?;
            // Diff against the text the agent last saw, if it was shared with it.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread. Ranges are converted to anchors,
            // presumably so they stay valid if the buffer changes before the edit is
            // applied — TODO confirm.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            });

            let format_on_save = cx.update(|cx| {
                // NOTE(review): the buffer is marked as read before editing, presumably so
                // the action log attributes the following edit to the agent — confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            });

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                });
                format_task.await.log_err();

                // Record the formatter's changes as part of the agent's edit too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))
                .await
        })
    }
2279
2280 pub fn create_terminal(
2281 &self,
2282 command: String,
2283 args: Vec<String>,
2284 extra_env: Vec<acp::EnvVariable>,
2285 cwd: Option<PathBuf>,
2286 output_byte_limit: Option<u64>,
2287 cx: &mut Context<Self>,
2288 ) -> Task<Result<Entity<Terminal>>> {
2289 let env = match &cwd {
2290 Some(dir) => self.project.update(cx, |project, cx| {
2291 project.environment().update(cx, |env, cx| {
2292 env.directory_environment(dir.as_path().into(), cx)
2293 })
2294 }),
2295 None => Task::ready(None).shared(),
2296 };
2297 let env = cx.spawn(async move |_, _| {
2298 let mut env = env.await.unwrap_or_default();
2299 // Disables paging for `git` and hopefully other commands
2300 env.insert("PAGER".into(), "".into());
2301 for var in extra_env {
2302 env.insert(var.name, var.value);
2303 }
2304 env
2305 });
2306
2307 let project = self.project.clone();
2308 let language_registry = project.read(cx).languages().clone();
2309 let is_windows = project.read(cx).path_style(cx).is_windows();
2310
2311 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2312 let terminal_task = cx.spawn({
2313 let terminal_id = terminal_id.clone();
2314 async move |_this, cx| {
2315 let env = env.await;
2316 let shell = project
2317 .update(cx, |project, cx| {
2318 project
2319 .remote_client()
2320 .and_then(|r| r.read(cx).default_system_shell())
2321 })
2322 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2323 let (task_command, task_args) =
2324 ShellBuilder::new(&Shell::Program(shell), is_windows)
2325 .redirect_stdin_to_dev_null()
2326 .build(Some(command.clone()), &args);
2327 let terminal = project
2328 .update(cx, |project, cx| {
2329 project.create_terminal_task(
2330 task::SpawnInTerminal {
2331 command: Some(task_command),
2332 args: task_args,
2333 cwd: cwd.clone(),
2334 env,
2335 ..Default::default()
2336 },
2337 cx,
2338 )
2339 })
2340 .await?;
2341
2342 anyhow::Ok(cx.new(|cx| {
2343 Terminal::new(
2344 terminal_id,
2345 &format!("{} {}", command, args.join(" ")),
2346 cwd,
2347 output_byte_limit.map(|l| l as usize),
2348 terminal,
2349 language_registry,
2350 cx,
2351 )
2352 }))
2353 }
2354 });
2355
2356 cx.spawn(async move |this, cx| {
2357 let terminal = terminal_task.await?;
2358 this.update(cx, |this, _cx| {
2359 this.terminals.insert(terminal_id, terminal.clone());
2360 terminal
2361 })
2362 })
2363 }
2364
2365 pub fn kill_terminal(
2366 &mut self,
2367 terminal_id: acp::TerminalId,
2368 cx: &mut Context<Self>,
2369 ) -> Result<()> {
2370 self.terminals
2371 .get(&terminal_id)
2372 .context("Terminal not found")?
2373 .update(cx, |terminal, cx| {
2374 terminal.kill(cx);
2375 });
2376
2377 Ok(())
2378 }
2379
2380 pub fn release_terminal(
2381 &mut self,
2382 terminal_id: acp::TerminalId,
2383 cx: &mut Context<Self>,
2384 ) -> Result<()> {
2385 self.terminals
2386 .remove(&terminal_id)
2387 .context("Terminal not found")?
2388 .update(cx, |terminal, cx| {
2389 terminal.kill(cx);
2390 });
2391
2392 Ok(())
2393 }
2394
2395 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2396 self.terminals
2397 .get(&terminal_id)
2398 .context("Terminal not found")
2399 .cloned()
2400 }
2401
2402 pub fn to_markdown(&self, cx: &App) -> String {
2403 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2404 }
2405
2406 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2407 cx.emit(AcpThreadEvent::LoadError(error));
2408 }
2409
2410 pub fn register_terminal_created(
2411 &mut self,
2412 terminal_id: acp::TerminalId,
2413 command_label: String,
2414 working_dir: Option<PathBuf>,
2415 output_byte_limit: Option<u64>,
2416 terminal: Entity<::terminal::Terminal>,
2417 cx: &mut Context<Self>,
2418 ) -> Entity<Terminal> {
2419 let language_registry = self.project.read(cx).languages().clone();
2420
2421 let entity = cx.new(|cx| {
2422 Terminal::new(
2423 terminal_id.clone(),
2424 &command_label,
2425 working_dir.clone(),
2426 output_byte_limit.map(|l| l as usize),
2427 terminal,
2428 language_registry,
2429 cx,
2430 )
2431 });
2432 self.terminals.insert(terminal_id.clone(), entity.clone());
2433 entity
2434 }
2435}
2436
2437fn markdown_for_raw_output(
2438 raw_output: &serde_json::Value,
2439 language_registry: &Arc<LanguageRegistry>,
2440 cx: &mut App,
2441) -> Option<Entity<Markdown>> {
2442 match raw_output {
2443 serde_json::Value::Null => None,
2444 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2445 Markdown::new(
2446 value.to_string().into(),
2447 Some(language_registry.clone()),
2448 None,
2449 cx,
2450 )
2451 })),
2452 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2453 Markdown::new(
2454 value.to_string().into(),
2455 Some(language_registry.clone()),
2456 None,
2457 cx,
2458 )
2459 })),
2460 serde_json::Value::String(value) => Some(cx.new(|cx| {
2461 Markdown::new(
2462 value.clone().into(),
2463 Some(language_registry.clone()),
2464 None,
2465 cx,
2466 )
2467 })),
2468 value => Some(cx.new(|cx| {
2469 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2470
2471 Markdown::new(
2472 format!("```json\n{}\n```", pretty_json).into(),
2473 Some(language_registry.clone()),
2474 None,
2475 cx,
2476 )
2477 })),
2478 }
2479}
2480
2481#[cfg(test)]
2482mod tests {
2483 use super::*;
2484 use anyhow::anyhow;
2485 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2486 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2487 use indoc::indoc;
2488 use project::{FakeFs, Fs};
2489 use rand::{distr, prelude::*};
2490 use serde_json::json;
2491 use settings::SettingsStore;
2492 use smol::stream::StreamExt as _;
2493 use std::{
2494 any::Any,
2495 cell::RefCell,
2496 path::Path,
2497 rc::Rc,
2498 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2499 time::Duration,
2500 };
2501 use util::path;
2502
2503 fn init_test(cx: &mut TestAppContext) {
2504 env_logger::try_init().ok();
2505 cx.update(|cx| {
2506 let settings_store = SettingsStore::test(cx);
2507 cx.set_global(settings_store);
2508 });
2509 }
2510
    // Output delivered via `TerminalProviderEvent::Output` before the matching `Created`
    // event must not be lost: once `Created` registers the terminal, the earlier bytes are
    // expected to show up in the terminal's rendered content.
    #[gpui::test]
    async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created - should be buffered by acp_thread
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"hello buffered".to_vec(),
                },
                cx,
            );
        });

        // Create a display-only terminal and then send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // After Created, buffered Output should have been flushed into the renderer
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("hello buffered"),
            "expected buffered output to render, got: {content}"
        );
    }
2572
    // Like the buffered-output test, but an `Exit` event also arrives before `Created`.
    // Checks that output sent ahead of both events is still rendered once `Created` lands.
    // NOTE(review): despite the name, the exit status is not asserted here; consider
    // checking it once the relevant accessor on `Terminal` is confirmed.
    #[gpui::test]
    async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());

        // Send Output BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"pre-exit data".to_vec(),
                },
                cx,
            );
        });

        // Send Exit BEFORE Created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Now create a display-only lower-level terminal and send Created
        let lower = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "Buffered Exit Test".to_string(),
                    cwd: None,
                    output_byte_limit: None,
                    terminal: lower.clone(),
                },
                cx,
            );
        });

        // Output should be present after Created (flushed from buffer)
        let content = thread.read_with(cx, |thread, cx| {
            let term = thread.terminal(terminal_id.clone()).unwrap();
            term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
        });

        assert!(
            content.contains("pre-exit data"),
            "expected pre-exit data to render, got: {content}"
        );
    }
2645
    // Covers how `push_user_content_block` groups chunks:
    // - consecutive user chunks are merged into one `UserMessage`, which adopts the id
    //   supplied with a later chunk;
    // - a user chunk that follows an assistant message starts a new `UserMessage` entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(None, "Hello, ".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block("Assistant response".into(), false, cx);
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                "New user message".into(),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2713
    // Two consecutive `AgentThoughtChunk` updates within one turn should be merged into a
    // single `<thinking>` section in the thread's Markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "Thinking ".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
                                    "hard!".into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2774
    // The agent reads `/tmp/foo`, the user then edits the open buffer, and only afterwards
    // does the agent write new contents. The user's inserted "zero\n" must survive, both in
    // the buffer and in the file on the (fake) filesystem.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then simulate a concurrent user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2851
    // Exercises `read_text_file`'s 1-based start line and line limit parameters:
    // - the whole file;
    // - a start line only;
    // - a limit only;
    // - a start line plus a limit;
    // - a start line past the end of the file, which must be rejected.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
        );
    }
2927
    // Same parameter combinations as `test_reading_from_line`, against an empty file.
    // In-range reads yield "", while a start line past the end is rejected.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
        );
    }
    // Reading a path outside the project's worktree (which does not exist) must fail with
    // `acp::ErrorCode::ResourceNotFound`.
    #[gpui::test]
    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Out of project file
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
    }
3033
    // A tool call that was in progress when the thread was canceled ends up `Canceled`.
    // A later `ToolCallUpdate` from the agent can still move it to `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId::new("test");

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new(id.clone(), "Label")
                                        .kind(acp::ToolKind::Fetch)
                                        .status(acp::ToolCallStatus::InProgress),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[0] is the user message; entries[1] is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancel.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
                        id,
                        acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
                    )),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3123
    // An edit tool call that is reported as already `Completed` (with a diff) must not leave
    // the thread reporting pending edit tool calls once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(
                                    acp::ToolCall::new("test", "Label")
                                        .kind(acp::ToolKind::Edit)
                                        .status(acp::ToolCallStatus::Completed)
                                        .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
                                            "/test/test.txt",
                                            "foo",
                                        ))]),
                                ),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3168
    // Checkpoint behavior across several turns:
    // - a user message is rendered as "## User (checkpoint)" only when its turn changed files;
    // - restoring the checkpoint of a later message truncates the thread back to before that
    //   message and removes the files created from that turn onward.
    // `iterations = 10` runs the test repeatedly (presumably with varied scheduling seeds).
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // When `simulate_changes` is set, each turn writes a new `/test/file-N`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in upper case as the assistant reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
                                    content.text.to_uppercase().into(),
                                )),
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" user message.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3346
3347 #[gpui::test]
3348 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3349 use std::sync::atomic::AtomicUsize;
3350 init_test(cx);
3351
3352 let fs = FakeFs::new(cx.executor());
3353 let project = Project::test(fs, None, cx).await;
3354
3355 // Create a connection that simulates refusal after tool result
3356 let prompt_count = Arc::new(AtomicUsize::new(0));
3357 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3358 let prompt_count = prompt_count.clone();
3359 move |_request, thread, mut cx| {
3360 let count = prompt_count.fetch_add(1, SeqCst);
3361 async move {
3362 if count == 0 {
3363 // First prompt: Generate a tool call with result
3364 thread.update(&mut cx, |thread, cx| {
3365 thread
3366 .handle_session_update(
3367 acp::SessionUpdate::ToolCall(
3368 acp::ToolCall::new("tool1", "Test Tool")
3369 .kind(acp::ToolKind::Fetch)
3370 .status(acp::ToolCallStatus::Completed)
3371 .raw_input(serde_json::json!({"query": "test"}))
3372 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3373 ),
3374 cx,
3375 )
3376 .unwrap();
3377 })?;
3378
3379 // Now return refusal because of the tool result
3380 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3381 } else {
3382 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3383 }
3384 }
3385 .boxed_local()
3386 }
3387 }));
3388
3389 let thread = cx
3390 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3391 .await
3392 .unwrap();
3393
3394 // Track if we see a Refusal event
3395 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3396 let saw_refusal_event_captured = saw_refusal_event.clone();
3397 thread.update(cx, |_thread, cx| {
3398 cx.subscribe(
3399 &thread,
3400 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3401 if matches!(event, AcpThreadEvent::Refusal) {
3402 *saw_refusal_event_captured.lock().unwrap() = true;
3403 }
3404 },
3405 )
3406 .detach();
3407 });
3408
3409 // Send a user message - this will trigger tool call and then refusal
3410 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3411 cx.background_executor.spawn(send_task).detach();
3412 cx.run_until_parked();
3413
3414 // Verify that:
3415 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3416 // 2. The user message was NOT truncated
3417 assert!(
3418 *saw_refusal_event.lock().unwrap(),
3419 "Refusal event should be emitted for tool result refusals"
3420 );
3421
3422 thread.read_with(cx, |thread, _| {
3423 let entries = thread.entries();
3424 assert!(entries.len() >= 2, "Should have user message and tool call");
3425
3426 // Verify user message is still there
3427 assert!(
3428 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3429 "User message should not be truncated"
3430 );
3431
3432 // Verify tool call is there with result
3433 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3434 assert!(
3435 tool_call.raw_output.is_some(),
3436 "Tool call should have output"
3437 );
3438 } else {
3439 panic!("Expected tool call at index 1");
3440 }
3441 });
3442 }
3443
3444 #[gpui::test]
3445 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3446 init_test(cx);
3447
3448 let fs = FakeFs::new(cx.executor());
3449 let project = Project::test(fs, None, cx).await;
3450
3451 let refuse_next = Arc::new(AtomicBool::new(false));
3452 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3453 let refuse_next = refuse_next.clone();
3454 move |_request, _thread, _cx| {
3455 if refuse_next.load(SeqCst) {
3456 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3457 .boxed_local()
3458 } else {
3459 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3460 .boxed_local()
3461 }
3462 }
3463 }));
3464
3465 let thread = cx
3466 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3467 .await
3468 .unwrap();
3469
3470 // Track if we see a Refusal event
3471 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3472 let saw_refusal_event_captured = saw_refusal_event.clone();
3473 thread.update(cx, |_thread, cx| {
3474 cx.subscribe(
3475 &thread,
3476 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3477 if matches!(event, AcpThreadEvent::Refusal) {
3478 *saw_refusal_event_captured.lock().unwrap() = true;
3479 }
3480 },
3481 )
3482 .detach();
3483 });
3484
3485 // Send a message that will be refused
3486 refuse_next.store(true, SeqCst);
3487 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3488 .await
3489 .unwrap();
3490
3491 // Verify that a Refusal event WAS emitted for user prompt refusal
3492 assert!(
3493 *saw_refusal_event.lock().unwrap(),
3494 "Refusal event should be emitted for user prompt refusals"
3495 );
3496
3497 // Verify the message was truncated (user prompt refusal)
3498 thread.read_with(cx, |thread, cx| {
3499 assert_eq!(thread.to_markdown(cx), "");
3500 });
3501 }
3502
3503 #[gpui::test]
3504 async fn test_refusal(cx: &mut TestAppContext) {
3505 init_test(cx);
3506 let fs = FakeFs::new(cx.background_executor.clone());
3507 fs.insert_tree(path!("/"), json!({})).await;
3508 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3509
3510 let refuse_next = Arc::new(AtomicBool::new(false));
3511 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3512 let refuse_next = refuse_next.clone();
3513 move |request, thread, mut cx| {
3514 let refuse_next = refuse_next.clone();
3515 async move {
3516 if refuse_next.load(SeqCst) {
3517 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3518 }
3519
3520 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3521 panic!("expected text content block");
3522 };
3523 thread.update(&mut cx, |thread, cx| {
3524 thread
3525 .handle_session_update(
3526 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3527 content.text.to_uppercase().into(),
3528 )),
3529 cx,
3530 )
3531 .unwrap();
3532 })?;
3533 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3534 }
3535 .boxed_local()
3536 }
3537 }));
3538 let thread = cx
3539 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3540 .await
3541 .unwrap();
3542
3543 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3544 .await
3545 .unwrap();
3546 thread.read_with(cx, |thread, cx| {
3547 assert_eq!(
3548 thread.to_markdown(cx),
3549 indoc! {"
3550 ## User
3551
3552 hello
3553
3554 ## Assistant
3555
3556 HELLO
3557
3558 "}
3559 );
3560 });
3561
3562 // Simulate refusing the second message. The message should be truncated
3563 // when a user prompt is refused.
3564 refuse_next.store(true, SeqCst);
3565 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3566 .await
3567 .unwrap();
3568 thread.read_with(cx, |thread, cx| {
3569 assert_eq!(
3570 thread.to_markdown(cx),
3571 indoc! {"
3572 ## User
3573
3574 hello
3575
3576 ## Assistant
3577
3578 HELLO
3579
3580 "}
3581 );
3582 });
3583 }
3584
3585 async fn run_until_first_tool_call(
3586 thread: &Entity<AcpThread>,
3587 cx: &mut TestAppContext,
3588 ) -> usize {
3589 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3590
3591 let subscription = cx.update(|cx| {
3592 cx.subscribe(thread, move |thread, _, cx| {
3593 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3594 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3595 return tx.try_send(ix).unwrap();
3596 }
3597 }
3598 })
3599 });
3600
3601 select! {
3602 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3603 panic!("Timeout waiting for tool call")
3604 }
3605 ix = rx.next().fuse() => {
3606 drop(subscription);
3607 ix.unwrap()
3608 }
3609 }
3610 }
3611
3612 #[derive(Clone, Default)]
3613 struct FakeAgentConnection {
3614 auth_methods: Vec<acp::AuthMethod>,
3615 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3616 on_user_message: Option<
3617 Rc<
3618 dyn Fn(
3619 acp::PromptRequest,
3620 WeakEntity<AcpThread>,
3621 AsyncApp,
3622 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3623 + 'static,
3624 >,
3625 >,
3626 }
3627
3628 impl FakeAgentConnection {
3629 fn new() -> Self {
3630 Self {
3631 auth_methods: Vec::new(),
3632 on_user_message: None,
3633 sessions: Arc::default(),
3634 }
3635 }
3636
3637 #[expect(unused)]
3638 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3639 self.auth_methods = auth_methods;
3640 self
3641 }
3642
3643 fn on_user_message(
3644 mut self,
3645 handler: impl Fn(
3646 acp::PromptRequest,
3647 WeakEntity<AcpThread>,
3648 AsyncApp,
3649 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3650 + 'static,
3651 ) -> Self {
3652 self.on_user_message.replace(Rc::new(handler));
3653 self
3654 }
3655 }
3656
3657 impl AgentConnection for FakeAgentConnection {
3658 fn telemetry_id(&self) -> SharedString {
3659 "fake".into()
3660 }
3661
3662 fn auth_methods(&self) -> &[acp::AuthMethod] {
3663 &self.auth_methods
3664 }
3665
3666 fn new_thread(
3667 self: Rc<Self>,
3668 project: Entity<Project>,
3669 _cwd: &Path,
3670 cx: &mut App,
3671 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3672 let session_id = acp::SessionId::new(
3673 rand::rng()
3674 .sample_iter(&distr::Alphanumeric)
3675 .take(7)
3676 .map(char::from)
3677 .collect::<String>(),
3678 );
3679 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3680 let thread = cx.new(|cx| {
3681 AcpThread::new(
3682 "Test",
3683 self.clone(),
3684 project,
3685 action_log,
3686 session_id.clone(),
3687 watch::Receiver::constant(
3688 acp::PromptCapabilities::new()
3689 .image(true)
3690 .audio(true)
3691 .embedded_context(true),
3692 ),
3693 cx,
3694 )
3695 });
3696 self.sessions.lock().insert(session_id, thread.downgrade());
3697 Task::ready(Ok(thread))
3698 }
3699
3700 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3701 if self.auth_methods().iter().any(|m| m.id == method) {
3702 Task::ready(Ok(()))
3703 } else {
3704 Task::ready(Err(anyhow!("Invalid Auth Method")))
3705 }
3706 }
3707
3708 fn prompt(
3709 &self,
3710 _id: Option<UserMessageId>,
3711 params: acp::PromptRequest,
3712 cx: &mut App,
3713 ) -> Task<gpui::Result<acp::PromptResponse>> {
3714 let sessions = self.sessions.lock();
3715 let thread = sessions.get(¶ms.session_id).unwrap();
3716 if let Some(handler) = &self.on_user_message {
3717 let handler = handler.clone();
3718 let thread = thread.clone();
3719 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3720 } else {
3721 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3722 }
3723 }
3724
3725 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3726 let sessions = self.sessions.lock();
3727 let thread = sessions.get(session_id).unwrap().clone();
3728
3729 cx.spawn(async move |cx| {
3730 thread
3731 .update(cx, |thread, cx| thread.cancel(cx))
3732 .unwrap()
3733 .await
3734 })
3735 .detach();
3736 }
3737
3738 fn truncate(
3739 &self,
3740 session_id: &acp::SessionId,
3741 _cx: &App,
3742 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3743 Some(Rc::new(FakeAgentSessionEditor {
3744 _session_id: session_id.clone(),
3745 }))
3746 }
3747
3748 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3749 self
3750 }
3751 }
3752
3753 struct FakeAgentSessionEditor {
3754 _session_id: acp::SessionId,
3755 }
3756
3757 impl AgentSessionTruncate for FakeAgentSessionEditor {
3758 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3759 Task::ready(Ok(()))
3760 }
3761 }
3762
3763 #[gpui::test]
3764 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3765 init_test(cx);
3766
3767 let fs = FakeFs::new(cx.executor());
3768 let project = Project::test(fs, [], cx).await;
3769 let connection = Rc::new(FakeAgentConnection::new());
3770 let thread = cx
3771 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3772 .await
3773 .unwrap();
3774
3775 // Try to update a tool call that doesn't exist
3776 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3777 thread.update(cx, |thread, cx| {
3778 let result = thread.handle_session_update(
3779 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3780 nonexistent_id.clone(),
3781 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3782 )),
3783 cx,
3784 );
3785
3786 // The update should succeed (not return an error)
3787 assert!(result.is_ok());
3788
3789 // There should now be exactly one entry in the thread
3790 assert_eq!(thread.entries.len(), 1);
3791
3792 // The entry should be a failed tool call
3793 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3794 assert_eq!(tool_call.id, nonexistent_id);
3795 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3796 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3797
3798 // Check that the content contains the error message
3799 assert_eq!(tool_call.content.len(), 1);
3800 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3801 match content_block {
3802 ContentBlock::Markdown { markdown } => {
3803 let markdown_text = markdown.read(cx).source();
3804 assert!(markdown_text.contains("Tool call not found"));
3805 }
3806 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3807 ContentBlock::ResourceLink { .. } => {
3808 panic!("Expected markdown content, got resource link")
3809 }
3810 ContentBlock::Image { .. } => {
3811 panic!("Expected markdown content, got image")
3812 }
3813 }
3814 } else {
3815 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3816 }
3817 } else {
3818 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3819 }
3820 });
3821 }
3822
/// Tests that restoring a checkpoint properly cleans up terminals that were
/// created after that checkpoint, and cancels any in-progress generation.
///
/// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
/// that were started after that checkpoint should be terminated, and any in-progress
/// AI generation should be canceled.
///
/// Scenario:
/// 1. Two user messages are sent.
/// 2. Two terminals run to completion without being attached to any thread entry.
/// 3. A third terminal is left running and attached to an in-progress tool call.
///
/// Restoring to the second message is expected to truncate the thread to the
/// first message, leave no `send_task`, and drop only the third (running) terminal.
#[gpui::test]
async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let connection = Rc::new(FakeAgentConnection::new());
    let thread = cx
        .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // Send first user message to create a checkpoint
    cx.update(|cx| {
        thread.update(cx, |thread, cx| {
            thread.send(vec!["first message".into()], cx)
        })
    })
    .await
    .unwrap();

    // Send second message (creates another checkpoint) - we'll restore to this one
    cx.update(|cx| {
        thread.update(cx, |thread, cx| {
            thread.send(vec!["second message".into()], cx)
        })
    })
    .await
    .unwrap();

    // Create two terminals that run to completion (Created -> Output -> Exit).
    // Neither is referenced by any thread entry. The assertions after the
    // restore expect both to survive it, still in the completed state.
    let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal_1 = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
        )
        .unwrap();
        builder.subscribe(cx)
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id_1.clone(),
                label: "echo 'first'".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal_1.clone(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id_1.clone(),
                data: b"first\n".to_vec(),
            },
            cx,
        );
    });

    // Mark terminal 1 as finished with exit code 0.
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Exit {
                terminal_id: terminal_id_1.clone(),
                status: acp::TerminalExitStatus::new().exit_code(0),
            },
            cx,
        );
    });

    let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal_2 = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
        )
        .unwrap();
        builder.subscribe(cx)
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id_2.clone(),
                label: "echo 'second'".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal_2.clone(),
            },
            cx,
        );
    });

    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id_2.clone(),
                data: b"second\n".to_vec(),
            },
            cx,
        );
    });

    // Mark terminal 2 as finished with exit code 0.
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Exit {
                terminal_id: terminal_id_2.clone(),
                status: acp::TerminalExitStatus::new().exit_code(0),
            },
            cx,
        );
    });

    // Get the second message ID to restore to
    let second_message_id = thread.read_with(cx, |thread, _| {
        // At this point we have:
        // - Index 0: First user message (with checkpoint)
        // - Index 1: Second user message (with checkpoint)
        // No assistant responses because FakeAgentConnection just returns EndTurn
        let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
            panic!("expected user message at index 1");
        };
        message.id.clone().unwrap()
    });

    // Create a terminal that will be attached to a tool call after the second
    // message. This simulates the AI agent starting a long-running terminal command.
    let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
    let mock_terminal = cx.new(|cx| {
        let builder = ::terminal::TerminalBuilder::new_display_only(
            ::terminal::terminal_settings::CursorShape::default(),
            ::terminal::terminal_settings::AlternateScroll::On,
            None,
            0,
        )
        .unwrap();
        builder.subscribe(cx)
    });

    // Register the terminal as created
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Created {
                terminal_id: terminal_id.clone(),
                label: "sleep 1000".to_string(),
                cwd: Some(PathBuf::from("/test")),
                output_byte_limit: None,
                terminal: mock_terminal.clone(),
            },
            cx,
        );
    });

    // Simulate the terminal producing output. No Exit event is sent for it,
    // so it is still running.
    thread.update(cx, |thread, cx| {
        thread.on_terminal_provider_event(
            TerminalProviderEvent::Output {
                terminal_id: terminal_id.clone(),
                data: b"terminal is running...\n".to_vec(),
            },
            cx,
        );
    });

    // Create an in-progress tool call entry (index 2, after the second user
    // message) that references this terminal.
    // This represents the agent requesting a terminal command
    thread.update(cx, |thread, cx| {
        thread
            .handle_session_update(
                acp::SessionUpdate::ToolCall(
                    acp::ToolCall::new("terminal-tool-1", "Running command")
                        .kind(acp::ToolKind::Execute)
                        .status(acp::ToolCallStatus::InProgress)
                        .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                            terminal_id.clone(),
                        ))])
                        .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                ),
                cx,
            )
            .unwrap();
    });

    // Verify terminal exists and is in the thread
    let terminal_exists_before =
        thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
    assert!(
        terminal_exists_before,
        "Terminal should exist before checkpoint restore"
    );

    // Verify the terminal has not completed yet: `output()` is still `None`,
    // whereas the completed terminals are checked for `Some` further down.
    let terminal_running_before = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
        terminal_entity.read_with(cx, |term, _cx| {
            term.output().is_none() // output is None means it's still running
        })
    });
    assert!(
        terminal_running_before,
        "Terminal should be running before checkpoint restore"
    );

    // Verify we have the expected entries before restore
    let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
    assert!(
        entry_count_before > 1,
        "Should have multiple entries before restore"
    );

    // Restore the checkpoint to the second message.
    // This should:
    // 1. Cancel any in-progress generation (via the cancel() call)
    // 2. Remove the terminal that was created after that point
    thread
        .update(cx, |thread, cx| {
            thread.restore_checkpoint(second_message_id, cx)
        })
        .await
        .unwrap();

    // Verify that no send_task is in progress after restore
    // (cancel() clears the send_task)
    let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
    assert!(
        !has_send_task_after,
        "Should not have a send_task after restore (cancel should have cleared it)"
    );

    // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
    let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
    assert_eq!(
        entry_count, 1,
        "Should have 1 entry after restore (only the first user message)"
    );

    // Verify the 2 completed, unreferenced terminals still exist
    let terminal_1_exists = thread.read_with(cx, |thread, _| {
        thread.terminals.contains_key(&terminal_id_1)
    });
    assert!(
        terminal_1_exists,
        "Terminal 1 (from before checkpoint) should still exist"
    );

    let terminal_2_exists = thread.read_with(cx, |thread, _| {
        thread.terminals.contains_key(&terminal_id_2)
    });
    assert!(
        terminal_2_exists,
        "Terminal 2 (from before checkpoint) should still exist"
    );

    // Verify they're still in completed state
    let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
        terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
    });
    assert!(terminal_1_completed, "Terminal 1 should still be completed");

    let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
        let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
        terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
    });
    assert!(terminal_2_completed, "Terminal 2 should still be completed");

    // Verify the running terminal (referenced by the truncated tool call) was removed
    let terminal_3_exists =
        thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
    assert!(
        !terminal_3_exists,
        "Terminal 3 (created after checkpoint) should have been removed"
    );

    // Verify total count is 2 (the two completed, unreferenced terminals)
    let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
    assert_eq!(
        terminal_count, 2,
        "Should have exactly 2 terminals (the completed ones from before checkpoint)"
    );
}
4118
4119 /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
4120 /// even when a new user message is added while the async checkpoint comparison is in progress.
4121 ///
4122 /// This is a regression test for a bug where update_last_checkpoint would fail with
4123 /// "no checkpoint" if a new user message (without a checkpoint) was added between when
4124 /// update_last_checkpoint started and when its async closure ran.
4125 #[gpui::test]
4126 async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
4127 init_test(cx);
4128
4129 let fs = FakeFs::new(cx.executor());
4130 fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
4131 .await;
4132 let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;
4133
4134 let handler_done = Arc::new(AtomicBool::new(false));
4135 let handler_done_clone = handler_done.clone();
4136 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
4137 move |_, _thread, _cx| {
4138 handler_done_clone.store(true, SeqCst);
4139 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
4140 },
4141 ));
4142
4143 let thread = cx
4144 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
4145 .await
4146 .unwrap();
4147
4148 let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
4149 let send_task = cx.background_executor.spawn(send_future);
4150
4151 // Tick until handler completes, then a few more to let update_last_checkpoint start
4152 while !handler_done.load(SeqCst) {
4153 cx.executor().tick();
4154 }
4155 for _ in 0..5 {
4156 cx.executor().tick();
4157 }
4158
4159 thread.update(cx, |thread, cx| {
4160 thread.push_entry(
4161 AgentThreadEntry::UserMessage(UserMessage {
4162 id: Some(UserMessageId::new()),
4163 content: ContentBlock::Empty,
4164 chunks: vec!["Injected message (no checkpoint)".into()],
4165 checkpoint: None,
4166 indented: false,
4167 }),
4168 cx,
4169 );
4170 });
4171
4172 cx.run_until_parked();
4173 let result = send_task.await;
4174
4175 assert!(
4176 result.is_ok(),
4177 "send should succeed even when new message added during update_last_checkpoint: {:?}",
4178 result.err()
4179 );
4180 }
4181}