1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use serde_json::to_string_pretty;
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::{ActionLog, ActionLogTelemetry};
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
39use uuid::Uuid;
40
/// A user turn in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one is known.
    pub id: Option<UserMessageId>,
    /// Rendered content of the message; this is what `to_markdown` exports.
    pub content: ContentBlock,
    /// Raw protocol content blocks for the message (presumably the chunks `content` was
    /// built from — TODO confirm against the push helpers).
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
    /// Whether the entry is displayed indented (reported by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
49
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// Git store checkpoint captured for the message.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when set, the Markdown export titles the
    /// message `## User (checkpoint)`.
    pub show: bool,
}
55
56impl UserMessage {
57 fn to_markdown(&self, cx: &App) -> String {
58 let mut markdown = String::new();
59 if self
60 .checkpoint
61 .as_ref()
62 .is_some_and(|checkpoint| checkpoint.show)
63 {
64 writeln!(markdown, "## User (checkpoint)").unwrap();
65 } else {
66 writeln!(markdown, "## User").unwrap();
67 }
68 writeln!(markdown).unwrap();
69 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
70 writeln!(markdown).unwrap();
71 markdown
72 }
73}
74
/// An assistant turn, made up of message and thought chunks in arrival order.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The chunks of the turn; `to_markdown` joins them with blank lines.
    pub chunks: Vec<AssistantMessageChunk>,
    /// Whether the entry is displayed indented (reported by `AgentThreadEntry::is_indented`).
    pub indented: bool,
}
80
81impl AssistantMessage {
82 pub fn to_markdown(&self, cx: &App) -> String {
83 format!(
84 "## Assistant\n\n{}\n\n",
85 self.chunks
86 .iter()
87 .map(|chunk| chunk.to_markdown(cx))
88 .join("\n\n")
89 )
90 }
91}
92
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text shown to the user.
    Message { block: ContentBlock },
    /// Agent reasoning; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
98
99impl AssistantMessageChunk {
100 pub fn from_str(
101 chunk: &str,
102 language_registry: &Arc<LanguageRegistry>,
103 path_style: PathStyle,
104 cx: &mut App,
105 ) -> Self {
106 Self::Message {
107 block: ContentBlock::new(chunk.into(), language_registry, path_style, cx),
108 }
109 }
110
111 fn to_markdown(&self, cx: &App) -> String {
112 match self {
113 Self::Message { block } => block.to_markdown(cx).to_string(),
114 Self::Thought { block } => {
115 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
116 }
117 }
118 }
119}
120
/// A single entry in an [`AcpThread`]'s transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A user turn.
    UserMessage(UserMessage),
    /// An assistant turn (messages and thoughts).
    AssistantMessage(AssistantMessage),
    /// A tool invocation made by the agent.
    ToolCall(ToolCall),
}
127
128impl AgentThreadEntry {
129 pub fn is_indented(&self) -> bool {
130 match self {
131 Self::UserMessage(message) => message.indented,
132 Self::AssistantMessage(message) => message.indented,
133 Self::ToolCall(_) => false,
134 }
135 }
136
137 pub fn to_markdown(&self, cx: &App) -> String {
138 match self {
139 Self::UserMessage(message) => message.to_markdown(cx),
140 Self::AssistantMessage(message) => message.to_markdown(cx),
141 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
142 }
143 }
144
145 pub fn user_message(&self) -> Option<&UserMessage> {
146 if let AgentThreadEntry::UserMessage(message) = self {
147 Some(message)
148 } else {
149 None
150 }
151 }
152
153 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
154 if let AgentThreadEntry::ToolCall(call) = self {
155 itertools::Either::Left(call.diffs())
156 } else {
157 itertools::Either::Right(std::iter::empty())
158 }
159 }
160
161 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
162 if let AgentThreadEntry::ToolCall(call) = self {
163 itertools::Either::Left(call.terminals())
164 } else {
165 itertools::Either::Right(std::iter::empty())
166 }
167 }
168
169 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
170 if let AgentThreadEntry::ToolCall(ToolCall {
171 locations,
172 resolved_locations,
173 ..
174 }) = self
175 {
176 Some((
177 locations.get(ix)?.clone(),
178 resolved_locations.get(ix)?.clone()?,
179 ))
180 } else {
181 None
182 }
183 }
184}
185
/// A tool invocation made by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Rendered title; only the first line of a multi-line title is kept (suffixed with `…`).
    pub label: Entity<Markdown>,
    /// Kind of tool reported by the agent.
    pub kind: acp::ToolKind,
    /// Content produced by the tool (text/Markdown, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, indexed in parallel with it
    /// (`AgentThreadEntry::location` reads both at the same index).
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input passed to the tool.
    pub raw_input: Option<serde_json::Value>,
    /// Markdown rendering of `raw_input`, when one could be produced.
    pub raw_input_markdown: Option<Entity<Markdown>>,
    /// Raw JSON output returned by the tool.
    pub raw_output: Option<serde_json::Value>,
}
199
200impl ToolCall {
201 fn from_acp(
202 tool_call: acp::ToolCall,
203 status: ToolCallStatus,
204 language_registry: Arc<LanguageRegistry>,
205 path_style: PathStyle,
206 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
207 cx: &mut App,
208 ) -> Result<Self> {
209 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
210 first_line.to_owned() + "…"
211 } else {
212 tool_call.title
213 };
214 let mut content = Vec::with_capacity(tool_call.content.len());
215 for item in tool_call.content {
216 if let Some(item) = ToolCallContent::from_acp(
217 item,
218 language_registry.clone(),
219 path_style,
220 terminals,
221 cx,
222 )? {
223 content.push(item);
224 }
225 }
226
227 let raw_input_markdown = tool_call
228 .raw_input
229 .as_ref()
230 .and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
231
232 let result = Self {
233 id: tool_call.tool_call_id,
234 label: cx
235 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
236 kind: tool_call.kind,
237 content,
238 locations: tool_call.locations,
239 resolved_locations: Vec::default(),
240 status,
241 raw_input: tool_call.raw_input,
242 raw_input_markdown,
243 raw_output: tool_call.raw_output,
244 };
245 Ok(result)
246 }
247
248 fn update_fields(
249 &mut self,
250 fields: acp::ToolCallUpdateFields,
251 language_registry: Arc<LanguageRegistry>,
252 path_style: PathStyle,
253 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
254 cx: &mut App,
255 ) -> Result<()> {
256 let acp::ToolCallUpdateFields {
257 kind,
258 status,
259 title,
260 content,
261 locations,
262 raw_input,
263 raw_output,
264 ..
265 } = fields;
266
267 if let Some(kind) = kind {
268 self.kind = kind;
269 }
270
271 if let Some(status) = status {
272 self.status = status.into();
273 }
274
275 if let Some(title) = title {
276 self.label.update(cx, |label, cx| {
277 if let Some((first_line, _)) = title.split_once("\n") {
278 label.replace(first_line.to_owned() + "…", cx)
279 } else {
280 label.replace(title, cx);
281 }
282 });
283 }
284
285 if let Some(content) = content {
286 let mut new_content_len = content.len();
287 let mut content = content.into_iter();
288
289 // Reuse existing content if we can
290 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
291 let valid_content =
292 old.update_from_acp(new, language_registry.clone(), path_style, terminals, cx)?;
293 if !valid_content {
294 new_content_len -= 1;
295 }
296 }
297 for new in content {
298 if let Some(new) = ToolCallContent::from_acp(
299 new,
300 language_registry.clone(),
301 path_style,
302 terminals,
303 cx,
304 )? {
305 self.content.push(new);
306 } else {
307 new_content_len -= 1;
308 }
309 }
310 self.content.truncate(new_content_len);
311 }
312
313 if let Some(locations) = locations {
314 self.locations = locations;
315 }
316
317 if let Some(raw_input) = raw_input {
318 self.raw_input_markdown = markdown_for_raw_output(&raw_input, &language_registry, cx);
319 self.raw_input = Some(raw_input);
320 }
321
322 if let Some(raw_output) = raw_output {
323 if self.content.is_empty()
324 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
325 {
326 self.content
327 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
328 markdown,
329 }));
330 }
331 self.raw_output = Some(raw_output);
332 }
333 Ok(())
334 }
335
336 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
337 self.content.iter().filter_map(|content| match content {
338 ToolCallContent::Diff(diff) => Some(diff),
339 ToolCallContent::ContentBlock(_) => None,
340 ToolCallContent::Terminal(_) => None,
341 })
342 }
343
344 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
345 self.content.iter().filter_map(|content| match content {
346 ToolCallContent::Terminal(terminal) => Some(terminal),
347 ToolCallContent::ContentBlock(_) => None,
348 ToolCallContent::Diff(_) => None,
349 })
350 }
351
352 fn to_markdown(&self, cx: &App) -> String {
353 let mut markdown = format!(
354 "**Tool Call: {}**\nStatus: {}\n\n",
355 self.label.read(cx).source(),
356 self.status
357 );
358 for content in &self.content {
359 markdown.push_str(content.to_markdown(cx).as_str());
360 markdown.push_str("\n\n");
361 }
362 markdown
363 }
364
365 async fn resolve_location(
366 location: acp::ToolCallLocation,
367 project: WeakEntity<Project>,
368 cx: &mut AsyncApp,
369 ) -> Option<ResolvedLocation> {
370 let buffer = project
371 .update(cx, |project, cx| {
372 project
373 .project_path_for_absolute_path(&location.path, cx)
374 .map(|path| project.open_buffer(path, cx))
375 })
376 .ok()??;
377 let buffer = buffer.await.log_err()?;
378 let position = buffer
379 .update(cx, |buffer, _| {
380 let snapshot = buffer.snapshot();
381 if let Some(row) = location.line {
382 let column = snapshot.indent_size_for_line(row).len;
383 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
384 snapshot.anchor_before(point)
385 } else {
386 Anchor::min_for_buffer(snapshot.remote_id())
387 }
388 })
389 .ok()?;
390
391 Some(ResolvedLocation { buffer, position })
392 }
393
394 fn resolve_locations(
395 &self,
396 project: Entity<Project>,
397 cx: &mut App,
398 ) -> Task<Vec<Option<ResolvedLocation>>> {
399 let locations = self.locations.clone();
400 project.update(cx, |_, cx| {
401 cx.spawn(async move |project, cx| {
402 let mut new_locations = Vec::new();
403 for location in locations {
404 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
405 }
406 new_locations
407 })
408 })
409 }
410}
411
/// A tool-call location resolved to a concrete buffer position.
///
/// This is kept separate from [`AgentLocation`], which only stores a weak handle (the `From`
/// conversion downgrades `buffer`). Keeping the type separate lets the thread hold a strong
/// reference to the buffer (per the original note, so the buffer can be saved on the thread).
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer containing the location.
    buffer: Entity<Buffer>,
    /// Anchor for the location within `buffer`.
    position: Anchor,
}
419
420impl From<&ResolvedLocation> for AgentLocation {
421 fn from(value: &ResolvedLocation) -> Self {
422 Self {
423 buffer: value.buffer.downgrade(),
424 position: value.position,
425 }
426 }
427}
428
/// Lifecycle state of a [`ToolCall`] as tracked by the client.
///
/// It extends the protocol's statuses with client-side states (confirmation, rejection,
/// cancellation).
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Choices offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Channel through which the chosen option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
450
451impl From<acp::ToolCallStatus> for ToolCallStatus {
452 fn from(status: acp::ToolCallStatus) -> Self {
453 match status {
454 acp::ToolCallStatus::Pending => Self::Pending,
455 acp::ToolCallStatus::InProgress => Self::InProgress,
456 acp::ToolCallStatus::Completed => Self::Completed,
457 acp::ToolCallStatus::Failed => Self::Failed,
458 _ => Self::Pending,
459 }
460 }
461}
462
463impl Display for ToolCallStatus {
464 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465 write!(
466 f,
467 "{}",
468 match self {
469 ToolCallStatus::Pending => "Pending",
470 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
471 ToolCallStatus::InProgress => "In Progress",
472 ToolCallStatus::Completed => "Completed",
473 ToolCallStatus::Failed => "Failed",
474 ToolCallStatus::Rejected => "Rejected",
475 ToolCallStatus::Canceled => "Canceled",
476 }
477 )
478 }
479}
480
/// Displayable content assembled from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content rendered as Markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link, kept as-is.
    ResourceLink { resource_link: acp::ResourceLink },
    /// A single successfully decoded image.
    Image { image: Arc<gpui::Image> },
}
488
impl ContentBlock {
    /// Creates content from a single protocol block (see [`ContentBlock::append`]).
    pub fn new(
        block: acp::ContentBlock,
        language_registry: &Arc<LanguageRegistry>,
        path_style: PathStyle,
        cx: &mut App,
    ) -> Self {
        let mut this = Self::Empty;
        this.append(block, language_registry, path_style, cx);
        this
    }

    /// Creates content by appending each of `blocks` in order, starting from `Empty`.
    pub fn new_combined(
        blocks: impl IntoIterator<Item = acp::ContentBlock>,
        language_registry: Arc<LanguageRegistry>,
        path_style: PathStyle,
        cx: &mut App,
    ) -> Self {
        let mut this = Self::Empty;
        for block in blocks {
            this.append(block, &language_registry, path_style, cx);
        }
        this
    }

    /// Appends a protocol block to this content.
    ///
    /// - `Empty` adopts a resource link as-is, and an image as-is if it decodes. Anything
    ///   else, including an undecodable image, becomes Markdown.
    /// - `Markdown` appends the text rendering of `block` directly; no separator is inserted.
    /// - `ResourceLink` and `Image` are converted to Markdown containing their own rendering,
    ///   a newline, and the rendering of `block`.
    pub fn append(
        &mut self,
        block: acp::ContentBlock,
        language_registry: &Arc<LanguageRegistry>,
        path_style: PathStyle,
        cx: &mut App,
    ) {
        match (&mut *self, &block) {
            (ContentBlock::Empty, acp::ContentBlock::ResourceLink(resource_link)) => {
                *self = ContentBlock::ResourceLink {
                    resource_link: resource_link.clone(),
                };
            }
            (ContentBlock::Empty, acp::ContentBlock::Image(image_content)) => {
                if let Some(image) = Self::decode_image(image_content) {
                    *self = ContentBlock::Image { image };
                } else {
                    // Undecodable image: fall back to a textual placeholder.
                    let new_content = Self::image_md(image_content);
                    *self = Self::create_markdown_block(new_content, language_registry, cx);
                }
            }
            (ContentBlock::Empty, _) => {
                let new_content = Self::block_string_contents(&block, path_style);
                *self = Self::create_markdown_block(new_content, language_registry, cx);
            }
            (ContentBlock::Markdown { markdown }, _) => {
                let new_content = Self::block_string_contents(&block, path_style);
                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
            }
            (ContentBlock::ResourceLink { resource_link }, _) => {
                // A lone link can't hold more content, so switch to Markdown.
                let existing_content = Self::resource_link_md(&resource_link.uri, path_style);
                let new_content = Self::block_string_contents(&block, path_style);
                let combined = format!("{}\n{}", existing_content, new_content);
                *self = Self::create_markdown_block(combined, language_registry, cx);
            }
            (ContentBlock::Image { .. }, _) => {
                // The image itself is replaced by the `Image` placeholder text.
                let new_content = Self::block_string_contents(&block, path_style);
                let combined = format!("`Image`\n{}", new_content);
                *self = Self::create_markdown_block(combined, language_registry, cx);
            }
        }
    }

    /// Decodes standard-base64 image data with a MIME type known to `gpui::ImageFormat`.
    /// Returns `None` if either step fails.
    fn decode_image(image_content: &acp::ImageContent) -> Option<Arc<gpui::Image>> {
        use base64::Engine as _;

        let bytes = base64::engine::general_purpose::STANDARD
            .decode(image_content.data.as_bytes())
            .ok()?;
        let format = gpui::ImageFormat::from_mime_type(&image_content.mime_type)?;
        Some(Arc::new(gpui::Image::from_bytes(format, bytes)))
    }

    /// Wraps `content` in a new Markdown entity.
    fn create_markdown_block(
        content: String,
        language_registry: &Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> ContentBlock {
        ContentBlock::Markdown {
            markdown: cx
                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
        }
    }

    /// Text rendering of a protocol block. Text is copied verbatim. Resource links and
    /// embedded text resources become links to their URI. Images become `` `Image` ``.
    /// Any other block renders as an empty string.
    fn block_string_contents(block: &acp::ContentBlock, path_style: PathStyle) -> String {
        match block {
            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
            acp::ContentBlock::ResourceLink(resource_link) => {
                Self::resource_link_md(&resource_link.uri, path_style)
            }
            acp::ContentBlock::Resource(acp::EmbeddedResource {
                resource:
                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
                        uri,
                        ..
                    }),
                ..
            }) => Self::resource_link_md(uri, path_style),
            acp::ContentBlock::Image(image) => Self::image_md(image),
            _ => String::new(),
        }
    }

    /// Renders `uri` as a mention link when it parses as a [`MentionUri`]. Otherwise the
    /// parse error is logged and the raw URI is returned.
    fn resource_link_md(uri: &str, path_style: PathStyle) -> String {
        if let Some(uri) = MentionUri::parse(uri, path_style).log_err() {
            uri.as_link().to_string()
        } else {
            uri.to_string()
        }
    }

    /// Placeholder text for an image; the image data is not inspected.
    fn image_md(_image: &acp::ImageContent) -> String {
        "`Image`".into()
    }

    /// Borrowed Markdown rendering of the content. Unlike `append`, a resource link yields
    /// its raw URI here rather than a mention link.
    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
        match self {
            ContentBlock::Empty => "",
            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
            ContentBlock::Image { .. } => "`Image`",
        }
    }

    /// The Markdown entity, if this is Markdown content.
    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
        match self {
            ContentBlock::Empty => None,
            ContentBlock::Markdown { markdown } => Some(markdown),
            ContentBlock::ResourceLink { .. } => None,
            ContentBlock::Image { .. } => None,
        }
    }

    /// The resource link, if this content is a single resource link.
    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
        match self {
            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
            _ => None,
        }
    }

    /// The decoded image, if this content is a single image.
    pub fn image(&self) -> Option<&Arc<gpui::Image>> {
        match self {
            ContentBlock::Image { image } => Some(image),
            _ => None,
        }
    }
}
641
/// One item of a [`ToolCall`]'s content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text/Markdown, resource link, or image content.
    ContentBlock(ContentBlock),
    /// A file diff produced by the tool.
    Diff(Entity<Diff>),
    /// A terminal registered on the thread.
    Terminal(Entity<Terminal>),
}
648
649impl ToolCallContent {
650 pub fn from_acp(
651 content: acp::ToolCallContent,
652 language_registry: Arc<LanguageRegistry>,
653 path_style: PathStyle,
654 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
655 cx: &mut App,
656 ) -> Result<Option<Self>> {
657 match content {
658 acp::ToolCallContent::Content(acp::Content { content, .. }) => {
659 Ok(Some(Self::ContentBlock(ContentBlock::new(
660 content,
661 &language_registry,
662 path_style,
663 cx,
664 ))))
665 }
666 acp::ToolCallContent::Diff(diff) => Ok(Some(Self::Diff(cx.new(|cx| {
667 Diff::finalized(
668 diff.path.to_string_lossy().into_owned(),
669 diff.old_text,
670 diff.new_text,
671 language_registry,
672 cx,
673 )
674 })))),
675 acp::ToolCallContent::Terminal(acp::Terminal { terminal_id, .. }) => terminals
676 .get(&terminal_id)
677 .cloned()
678 .map(|terminal| Some(Self::Terminal(terminal)))
679 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
680 _ => Ok(None),
681 }
682 }
683
684 pub fn update_from_acp(
685 &mut self,
686 new: acp::ToolCallContent,
687 language_registry: Arc<LanguageRegistry>,
688 path_style: PathStyle,
689 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
690 cx: &mut App,
691 ) -> Result<bool> {
692 let needs_update = match (&self, &new) {
693 (Self::Diff(old_diff), acp::ToolCallContent::Diff(new_diff)) => {
694 old_diff.read(cx).needs_update(
695 new_diff.old_text.as_deref().unwrap_or(""),
696 &new_diff.new_text,
697 cx,
698 )
699 }
700 _ => true,
701 };
702
703 if let Some(update) = Self::from_acp(new, language_registry, path_style, terminals, cx)? {
704 if needs_update {
705 *self = update;
706 }
707 Ok(true)
708 } else {
709 Ok(false)
710 }
711 }
712
713 pub fn to_markdown(&self, cx: &App) -> String {
714 match self {
715 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
716 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
717 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
718 }
719 }
720
721 pub fn image(&self) -> Option<&Arc<gpui::Image>> {
722 match self {
723 Self::ContentBlock(content) => content.image(),
724 _ => None,
725 }
726 }
727}
728
/// An update addressed to an existing tool call, identified by `ToolCallUpdate::id`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol field update.
    UpdateFields(acp::ToolCallUpdate),
    /// Attach or replace a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// Attach or replace a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
735
736impl ToolCallUpdate {
737 fn id(&self) -> &acp::ToolCallId {
738 match self {
739 Self::UpdateFields(update) => &update.tool_call_id,
740 Self::UpdateDiff(diff) => &diff.id,
741 Self::UpdateTerminal(terminal) => &terminal.id,
742 }
743 }
744}
745
746impl From<acp::ToolCallUpdate> for ToolCallUpdate {
747 fn from(update: acp::ToolCallUpdate) -> Self {
748 Self::UpdateFields(update)
749 }
750}
751
752impl From<ToolCallUpdateDiff> for ToolCallUpdate {
753 fn from(diff: ToolCallUpdateDiff) -> Self {
754 Self::UpdateDiff(diff)
755 }
756}
757
/// A diff entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// Diff to attach.
    pub diff: Entity<Diff>,
}
763
764impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
765 fn from(terminal: ToolCallUpdateTerminal) -> Self {
766 Self::UpdateTerminal(terminal)
767 }
768}
769
/// A terminal entity to attach to the tool call with the given id.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Target tool call.
    pub id: acp::ToolCallId,
    /// Terminal to attach.
    pub terminal: Entity<Terminal>,
}
775
/// The agent's current plan, as an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in the order the agent reported them.
    pub entries: Vec<PlanEntry>,
}
780
/// Summary counts for a [`Plan`], computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
787
788impl Plan {
789 pub fn is_empty(&self) -> bool {
790 self.entries.is_empty()
791 }
792
793 pub fn stats(&self) -> PlanStats<'_> {
794 let mut stats = PlanStats {
795 in_progress_entry: None,
796 pending: 0,
797 completed: 0,
798 };
799
800 for entry in &self.entries {
801 match &entry.status {
802 acp::PlanEntryStatus::Pending => {
803 stats.pending += 1;
804 }
805 acp::PlanEntryStatus::InProgress => {
806 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
807 }
808 acp::PlanEntryStatus::Completed => {
809 stats.completed += 1;
810 }
811 _ => {}
812 }
813 }
814
815 stats
816 }
817}
818
/// One step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description rendered as Markdown.
    pub content: Entity<Markdown>,
    /// Priority reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Progress status reported by the agent.
    pub status: acp::PlanEntryStatus,
}
825
826impl PlanEntry {
827 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
828 Self {
829 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
830 priority: entry.priority,
831 status: entry.status,
832 }
833 }
834}
835
/// Token consumption for the thread relative to the model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum tokens allowed; `0` means the limit is unknown (see `ratio`).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
841
842impl TokenUsage {
843 pub fn ratio(&self) -> TokenUsageRatio {
844 #[cfg(debug_assertions)]
845 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
846 .unwrap_or("0.8".to_string())
847 .parse()
848 .unwrap();
849 #[cfg(not(debug_assertions))]
850 let warning_threshold: f32 = 0.8;
851
852 // When the maximum is unknown because there is no selected model,
853 // avoid showing the token limit warning.
854 if self.max_tokens == 0 {
855 TokenUsageRatio::Normal
856 } else if self.used_tokens >= self.max_tokens {
857 TokenUsageRatio::Exceeded
858 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
859 TokenUsageRatio::Warning
860 } else {
861 TokenUsageRatio::Normal
862 }
863 }
864}
865
/// Coarse classification of [`TokenUsage`] returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
872
/// Progress of a retried request, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number (presumably 1-based — TODO confirm with the emitter).
    pub attempt: usize,
    /// Maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// When the current attempt started.
    pub started_at: Instant,
    /// Delay/duration associated with this attempt (NOTE(review): confirm whether this is
    /// the backoff delay).
    pub duration: Duration,
}
881
/// Client-side state of a conversation with an agent over the Agent Client Protocol.
pub struct AcpThread {
    /// Thread title.
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's most recent plan.
    plan: Plan,
    /// Project the thread operates on.
    project: Entity<Project>,
    /// Log of buffer edits made by the agent.
    action_log: Entity<ActionLog>,
    /// Buffer snapshots keyed by buffer (presumably the state last shared with the agent —
    /// TODO confirm).
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send; `Some` means the thread reports `ThreadStatus::Generating`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Protocol session id.
    session_id: acp::SessionId,
    /// Latest reported token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that watches for prompt capability changes.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals registered on this thread.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals that were not registered yet; replayed on creation.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals that were not registered yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
899
900impl From<&AcpThread> for ActionLogTelemetry {
901 fn from(value: &AcpThread) -> Self {
902 Self {
903 agent_telemetry_id: value.connection().telemetry_id(),
904 session_id: value.session_id.0.clone(),
905 }
906 }
907}
908
/// Events emitted by an [`AcpThread`] to its observers.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended.
    NewEntry,
    /// The title changed.
    TitleUpdated,
    /// Token usage changed.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// The entries in this range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for user authorization.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    /// Generation stopped.
    Stopped,
    /// An error occurred.
    Error,
    /// The agent failed to load.
    LoadError(LoadError),
    /// Prompt capabilities changed (emitted by the observer task started in `new`).
    PromptCapabilitiesUpdated,
    /// The model refused the request.
    Refusal,
    /// Forwarded from an `AvailableCommandsUpdate` session update.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Forwarded from a `CurrentModeUpdate` session update.
    ModeUpdated(acp::SessionModeId),
    /// Forwarded from a `ConfigOptionUpdate` session update.
    ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

// Lets observers subscribe to `AcpThreadEvent`s emitted by the thread.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
929
/// Terminal lifecycle events delivered to `AcpThread::on_terminal_provider_event`.
///
/// Events may arrive for a terminal before its `Created` event; output and exit are buffered
/// until then.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered on the thread.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Raw output bytes for a terminal.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// The terminal's title changed; shown as its breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal's process exited.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
952
/// Commands addressed to a terminal by id (presumably sent to the terminal provider —
/// no consumer is visible here; TODO confirm).
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Write input bytes to the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Resize the terminal to the given dimensions.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Close the terminal.
    Close {
        terminal_id: acp::TerminalId,
    },
}
968
969impl AcpThread {
970 pub fn on_terminal_provider_event(
971 &mut self,
972 event: TerminalProviderEvent,
973 cx: &mut Context<Self>,
974 ) {
975 match event {
976 TerminalProviderEvent::Created {
977 terminal_id,
978 label,
979 cwd,
980 output_byte_limit,
981 terminal,
982 } => {
983 let entity = self.register_terminal_created(
984 terminal_id.clone(),
985 label,
986 cwd,
987 output_byte_limit,
988 terminal,
989 cx,
990 );
991
992 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
993 for data in chunks.drain(..) {
994 entity.update(cx, |term, cx| {
995 term.inner().update(cx, |inner, cx| {
996 inner.write_output(&data, cx);
997 })
998 });
999 }
1000 }
1001
1002 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
1003 entity.update(cx, |_term, cx| {
1004 cx.notify();
1005 });
1006 }
1007
1008 cx.notify();
1009 }
1010 TerminalProviderEvent::Output { terminal_id, data } => {
1011 if let Some(entity) = self.terminals.get(&terminal_id) {
1012 entity.update(cx, |term, cx| {
1013 term.inner().update(cx, |inner, cx| {
1014 inner.write_output(&data, cx);
1015 })
1016 });
1017 } else {
1018 self.pending_terminal_output
1019 .entry(terminal_id)
1020 .or_default()
1021 .push(data);
1022 }
1023 }
1024 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
1025 if let Some(entity) = self.terminals.get(&terminal_id) {
1026 entity.update(cx, |term, cx| {
1027 term.inner().update(cx, |inner, cx| {
1028 inner.breadcrumb_text = title;
1029 cx.emit(::terminal::Event::BreadcrumbsChanged);
1030 })
1031 });
1032 }
1033 }
1034 TerminalProviderEvent::Exit {
1035 terminal_id,
1036 status,
1037 } => {
1038 if let Some(entity) = self.terminals.get(&terminal_id) {
1039 entity.update(cx, |_term, cx| {
1040 cx.notify();
1041 });
1042 } else {
1043 self.pending_terminal_exit.insert(terminal_id, status);
1044 }
1045 }
1046 }
1047 }
1048}
1049
/// Whether an [`AcpThread`] is currently sending (see `AcpThread::status`).
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is in flight.
    Generating,
}
1055
/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent version is too old.
    Unsupported {
        /// The agent command that was checked.
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed with this message.
    FailedToInstall(SharedString),
    /// The agent server process exited.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by its message.
    Other(SharedString),
}
1069
1070impl Display for LoadError {
1071 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1072 match self {
1073 LoadError::Unsupported {
1074 command: path,
1075 current_version,
1076 minimum_version,
1077 } => {
1078 write!(
1079 f,
1080 "version {current_version} from {path} is not supported (need at least {minimum_version})"
1081 )
1082 }
1083 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
1084 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
1085 LoadError::Other(msg) => write!(f, "{msg}"),
1086 }
1087 }
1088}
1089
1090impl Error for LoadError {}
1091
1092impl AcpThread {
1093 pub fn new(
1094 title: impl Into<SharedString>,
1095 connection: Rc<dyn AgentConnection>,
1096 project: Entity<Project>,
1097 action_log: Entity<ActionLog>,
1098 session_id: acp::SessionId,
1099 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
1100 cx: &mut Context<Self>,
1101 ) -> Self {
1102 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1103 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1104 loop {
1105 let caps = prompt_capabilities_rx.recv().await?;
1106 this.update(cx, |this, cx| {
1107 this.prompt_capabilities = caps;
1108 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1109 })?;
1110 }
1111 });
1112
1113 Self {
1114 action_log,
1115 shared_buffers: Default::default(),
1116 entries: Default::default(),
1117 plan: Default::default(),
1118 title: title.into(),
1119 project,
1120 send_task: None,
1121 connection,
1122 session_id,
1123 token_usage: None,
1124 prompt_capabilities,
1125 _observe_prompt_capabilities: task,
1126 terminals: HashMap::default(),
1127 pending_terminal_output: HashMap::default(),
1128 pending_terminal_exit: HashMap::default(),
1129 }
1130 }
1131
1132 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1133 self.prompt_capabilities.clone()
1134 }
1135
1136 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1137 &self.connection
1138 }
1139
1140 pub fn action_log(&self) -> &Entity<ActionLog> {
1141 &self.action_log
1142 }
1143
1144 pub fn project(&self) -> &Entity<Project> {
1145 &self.project
1146 }
1147
1148 pub fn title(&self) -> SharedString {
1149 self.title.clone()
1150 }
1151
1152 pub fn entries(&self) -> &[AgentThreadEntry] {
1153 &self.entries
1154 }
1155
1156 pub fn session_id(&self) -> &acp::SessionId {
1157 &self.session_id
1158 }
1159
1160 pub fn status(&self) -> ThreadStatus {
1161 if self.send_task.is_some() {
1162 ThreadStatus::Generating
1163 } else {
1164 ThreadStatus::Idle
1165 }
1166 }
1167
1168 pub fn token_usage(&self) -> Option<&TokenUsage> {
1169 self.token_usage.as_ref()
1170 }
1171
1172 pub fn has_pending_edit_tool_calls(&self) -> bool {
1173 for entry in self.entries.iter().rev() {
1174 match entry {
1175 AgentThreadEntry::UserMessage(_) => return false,
1176 AgentThreadEntry::ToolCall(
1177 call @ ToolCall {
1178 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1179 ..
1180 },
1181 ) if call.diffs().next().is_some() => {
1182 return true;
1183 }
1184 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1185 }
1186 }
1187
1188 false
1189 }
1190
1191 pub fn used_tools_since_last_user_message(&self) -> bool {
1192 for entry in self.entries.iter().rev() {
1193 match entry {
1194 AgentThreadEntry::UserMessage(..) => return false,
1195 AgentThreadEntry::AssistantMessage(..) => continue,
1196 AgentThreadEntry::ToolCall(..) => return true,
1197 }
1198 }
1199
1200 false
1201 }
1202
1203 pub fn handle_session_update(
1204 &mut self,
1205 update: acp::SessionUpdate,
1206 cx: &mut Context<Self>,
1207 ) -> Result<(), acp::Error> {
1208 match update {
1209 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1210 self.push_user_content_block(None, content, cx);
1211 }
1212 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1213 self.push_assistant_content_block(content, false, cx);
1214 }
1215 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1216 self.push_assistant_content_block(content, true, cx);
1217 }
1218 acp::SessionUpdate::ToolCall(tool_call) => {
1219 self.upsert_tool_call(tool_call, cx)?;
1220 }
1221 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1222 self.update_tool_call(tool_call_update, cx)?;
1223 }
1224 acp::SessionUpdate::Plan(plan) => {
1225 self.update_plan(plan, cx);
1226 }
1227 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1228 available_commands,
1229 ..
1230 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1231 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1232 current_mode_id,
1233 ..
1234 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1235 acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
1236 config_options,
1237 ..
1238 }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
1239 _ => {}
1240 }
1241 Ok(())
1242 }
1243
1244 pub fn push_user_content_block(
1245 &mut self,
1246 message_id: Option<UserMessageId>,
1247 chunk: acp::ContentBlock,
1248 cx: &mut Context<Self>,
1249 ) {
1250 self.push_user_content_block_with_indent(message_id, chunk, false, cx)
1251 }
1252
    /// Appends a user content chunk to the thread.
    ///
    /// If the last entry is a user message with the same `indented` flag:
    /// - `chunk` is appended to that message's content and chunk list.
    /// - `message_id`, if `Some`, replaces the message's id.
    /// - `EntryUpdated` is emitted for it.
    ///
    /// Otherwise a new user message entry (without a checkpoint) is pushed.
    pub fn push_user_content_block_with_indent(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `self.entries` is mutably borrowed below.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                indented: existing_indented,
                ..
            }) = last_entry
            && *existing_indented == indented
        {
            // A newly supplied id wins; otherwise the existing id is kept.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, path_style, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, path_style, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                    indented,
                }),
                cx,
            );
        }
    }
1293
1294 pub fn push_assistant_content_block(
1295 &mut self,
1296 chunk: acp::ContentBlock,
1297 is_thought: bool,
1298 cx: &mut Context<Self>,
1299 ) {
1300 self.push_assistant_content_block_with_indent(chunk, is_thought, false, cx)
1301 }
1302
    /// Appends an assistant chunk to the thread.
    ///
    /// The chunk is a thought when `is_thought` is true and a regular message
    /// otherwise. If the last entry is an assistant message with the same
    /// `indented` flag, the chunk is merged into it:
    /// - If the message's last chunk is of the same kind, the content is appended
    ///   to that chunk.
    /// - Otherwise a new chunk is started.
    ///
    /// In every other case a new assistant message entry is pushed.
    pub fn push_assistant_content_block_with_indent(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        indented: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);
        // Captured up front because `self.entries` is mutably borrowed below.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage {
                chunks,
                indented: existing_indented,
            }) = last_entry
            && *existing_indented == indented
        {
            let idx = entries_len - 1;
            // NOTE(review): the event is emitted before the chunk is merged below;
            // presumably fine because listeners observe it after this update
            // returns — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, path_style, cx)
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, path_style, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                    indented,
                }),
                cx,
            );
        }
    }
1353
1354 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1355 self.entries.push(entry);
1356 cx.emit(AcpThreadEvent::NewEntry);
1357 }
1358
1359 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1360 self.connection.set_title(&self.session_id, cx).is_some()
1361 }
1362
1363 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1364 if title != self.title {
1365 self.title = title.clone();
1366 cx.emit(AcpThreadEvent::TitleUpdated);
1367 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1368 return set_title.run(title, cx);
1369 }
1370 }
1371 Task::ready(Ok(()))
1372 }
1373
1374 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1375 self.token_usage = usage;
1376 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1377 }
1378
1379 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1380 cx.emit(AcpThreadEvent::Retry(status));
1381 }
1382
    /// Applies an incremental update to an existing tool call entry, then emits
    /// `EntryUpdated` for it.
    ///
    /// - `UpdateFields` merges the new fields into the call and re-resolves its
    ///   locations when the update carries locations.
    /// - `UpdateDiff` replaces the call's content with the single diff.
    /// - `UpdateTerminal` replaces the call's content with the single terminal.
    ///
    /// If no tool call has the update's id, a placeholder entry with status
    /// `Failed` and the text "Tool call not found" is pushed instead, and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates failures from `ToolCall::update_fields`.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();
        let path_style = self.project.read(cx).path_style(cx);

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        "Tool call not found".into(),
                        &languages,
                        path_style,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_input_markdown: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of tool call entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.tool_call_id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1444
1445 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1446 pub fn upsert_tool_call(
1447 &mut self,
1448 tool_call: acp::ToolCall,
1449 cx: &mut Context<Self>,
1450 ) -> Result<(), acp::Error> {
1451 let status = tool_call.status.into();
1452 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1453 }
1454
1455 /// Fails if id does not match an existing entry.
1456 pub fn upsert_tool_call_inner(
1457 &mut self,
1458 update: acp::ToolCallUpdate,
1459 status: ToolCallStatus,
1460 cx: &mut Context<Self>,
1461 ) -> Result<(), acp::Error> {
1462 let language_registry = self.project.read(cx).languages().clone();
1463 let path_style = self.project.read(cx).path_style(cx);
1464 let id = update.tool_call_id.clone();
1465
1466 let agent_telemetry_id = self.connection().telemetry_id();
1467 let session = self.session_id();
1468 if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
1469 let status = if matches!(status, ToolCallStatus::Completed) {
1470 "completed"
1471 } else {
1472 "failed"
1473 };
1474 telemetry::event!(
1475 "Agent Tool Call Completed",
1476 agent_telemetry_id,
1477 session,
1478 status
1479 );
1480 }
1481
1482 if let Some(ix) = self.index_for_tool_call(&id) {
1483 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1484 unreachable!()
1485 };
1486
1487 call.update_fields(
1488 update.fields,
1489 language_registry,
1490 path_style,
1491 &self.terminals,
1492 cx,
1493 )?;
1494 call.status = status;
1495
1496 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1497 } else {
1498 let call = ToolCall::from_acp(
1499 update.try_into()?,
1500 status,
1501 language_registry,
1502 self.project.read(cx).path_style(cx),
1503 &self.terminals,
1504 cx,
1505 )?;
1506 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1507 };
1508
1509 self.resolve_locations(id, cx);
1510 Ok(())
1511 }
1512
1513 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1514 self.entries
1515 .iter()
1516 .enumerate()
1517 .rev()
1518 .find_map(|(index, entry)| {
1519 if let AgentThreadEntry::ToolCall(tool_call) = entry
1520 && &tool_call.id == id
1521 {
1522 Some(index)
1523 } else {
1524 None
1525 }
1526 })
1527 }
1528
1529 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1530 // The tool call we are looking for is typically the last one, or very close to the end.
1531 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1532 self.entries
1533 .iter_mut()
1534 .enumerate()
1535 .rev()
1536 .find_map(|(index, tool_call)| {
1537 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1538 && &tool_call.id == id
1539 {
1540 Some((index, tool_call))
1541 } else {
1542 None
1543 }
1544 })
1545 }
1546
1547 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1548 self.entries
1549 .iter()
1550 .enumerate()
1551 .rev()
1552 .find_map(|(index, tool_call)| {
1553 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1554 && &tool_call.id == id
1555 {
1556 Some((index, tool_call))
1557 } else {
1558 None
1559 }
1560 })
1561 }
1562
    /// Resolves the locations of the tool call with `id` asynchronously.
    ///
    /// Does nothing if no tool call has this id. Once resolution finishes:
    /// - A snapshot of every resolved buffer is stored in `shared_buffers`.
    /// - The project's agent location moves to the last resolved location, unless
    ///   the skip rule commented below applies.
    /// - The tool call's `resolved_locations` are replaced, emitting
    ///   `EntryUpdated` only if they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                // Remember the snapshots the agent was pointed at.
                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call may no longer exist (e.g. the thread was rewound)
                // by the time resolution completes.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // Skip the update when it would move the agent location
                            // backwards within the same row, so that updates from the
                            // edit tool don't reset the position to the start of the line.
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                // Avoid emitting a redundant update when nothing changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1619
1620 pub fn request_tool_call_authorization(
1621 &mut self,
1622 tool_call: acp::ToolCallUpdate,
1623 options: Vec<acp::PermissionOption>,
1624 respect_always_allow_setting: bool,
1625 cx: &mut Context<Self>,
1626 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1627 let (tx, rx) = oneshot::channel();
1628
1629 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1630 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1631 // some tools would (incorrectly) continue to auto-accept.
1632 if let Some(allow_once_option) = options.iter().find_map(|option| {
1633 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1634 Some(option.option_id.clone())
1635 } else {
1636 None
1637 }
1638 }) {
1639 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1640 return Ok(async {
1641 acp::RequestPermissionOutcome::Selected(acp::SelectedPermissionOutcome::new(
1642 allow_once_option,
1643 ))
1644 }
1645 .boxed());
1646 }
1647 }
1648
1649 let status = ToolCallStatus::WaitingForConfirmation {
1650 options,
1651 respond_tx: tx,
1652 };
1653
1654 self.upsert_tool_call_inner(tool_call, status, cx)?;
1655 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1656
1657 let fut = async {
1658 match rx.await {
1659 Ok(option) => acp::RequestPermissionOutcome::Selected(
1660 acp::SelectedPermissionOutcome::new(option),
1661 ),
1662 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1663 }
1664 }
1665 .boxed();
1666
1667 Ok(fut)
1668 }
1669
1670 pub fn authorize_tool_call(
1671 &mut self,
1672 id: acp::ToolCallId,
1673 option_id: acp::PermissionOptionId,
1674 option_kind: acp::PermissionOptionKind,
1675 cx: &mut Context<Self>,
1676 ) {
1677 let Some((ix, call)) = self.tool_call_mut(&id) else {
1678 return;
1679 };
1680
1681 let new_status = match option_kind {
1682 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1683 ToolCallStatus::Rejected
1684 }
1685 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1686 ToolCallStatus::InProgress
1687 }
1688 _ => ToolCallStatus::InProgress,
1689 };
1690
1691 let curr_status = mem::replace(&mut call.status, new_status);
1692
1693 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1694 respond_tx.send(option_id).log_err();
1695 } else if cfg!(debug_assertions) {
1696 panic!("tried to authorize an already authorized tool call");
1697 }
1698
1699 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1700 }
1701
1702 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1703 let mut first_tool_call = None;
1704
1705 for entry in self.entries.iter().rev() {
1706 match &entry {
1707 AgentThreadEntry::ToolCall(call) => {
1708 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1709 first_tool_call = Some(call);
1710 } else {
1711 continue;
1712 }
1713 }
1714 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1715 // Reached the beginning of the turn.
1716 // If we had pending permission requests in the previous turn, they have been cancelled.
1717 break;
1718 }
1719 }
1720 }
1721
1722 first_tool_call
1723 }
1724
1725 pub fn plan(&self) -> &Plan {
1726 &self.plan
1727 }
1728
1729 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1730 let new_entries_len = request.entries.len();
1731 let mut new_entries = request.entries.into_iter();
1732
1733 // Reuse existing markdown to prevent flickering
1734 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1735 let PlanEntry {
1736 content,
1737 priority,
1738 status,
1739 } = old;
1740 content.update(cx, |old, cx| {
1741 old.replace(new.content, cx);
1742 });
1743 *priority = new.priority;
1744 *status = new.status;
1745 }
1746 for new in new_entries {
1747 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1748 }
1749 self.plan.entries.truncate(new_entries_len);
1750
1751 cx.notify();
1752 }
1753
1754 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1755 self.plan
1756 .entries
1757 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1758 cx.notify();
1759 }
1760
1761 #[cfg(any(test, feature = "test-support"))]
1762 pub fn send_raw(
1763 &mut self,
1764 message: &str,
1765 cx: &mut Context<Self>,
1766 ) -> BoxFuture<'static, Result<()>> {
1767 self.send(vec![message.into()], cx)
1768 }
1769
    /// Sends a user prompt made of `message` to the agent as a new turn.
    ///
    /// Within the turn:
    /// - A user message entry is pushed.
    /// - A git checkpoint is taken and attached to the last user message.
    /// - The prompt request is forwarded to the connection.
    ///
    /// The returned future resolves when the turn has finished (see `run_turn`).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            self.project.read(cx).path_style(cx),
            cx,
        );
        let request = acp::PromptRequest::new(self.session_id.clone(), message.clone());
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only assigned when the connection supports truncation.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                        indented: false,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the working tree before prompting. A failure is logged and
            // leaves the message without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1822
1823 pub fn can_resume(&self, cx: &App) -> bool {
1824 self.connection.resume(&self.session_id, cx).is_some()
1825 }
1826
1827 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1828 self.run_turn(cx, async move |this, cx| {
1829 this.update(cx, |this, cx| {
1830 this.connection
1831 .resume(&this.session_id, cx)
1832 .map(|resume| resume.run(cx))
1833 })?
1834 .context("resuming a session is not supported")?
1835 .await
1836 })
1837 }
1838
    /// Runs one agent turn driven by `f`.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight
    /// turn is cancelled and awaited. The returned future resolves after `f`
    /// finishes and the following have happened:
    /// - The last checkpoint has been refreshed.
    /// - The agent location has been cleared.
    /// - The matching events (`Error`, or `Refusal` and/or `Stopped`) have been
    ///   emitted.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn itself runs in `send_task`; dropping that task (e.g. via
        // `cancel`) drops `tx`, which makes `rx` below resolve to `Err(Canceled)`.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Successful responses and a dropped send task (`Err(Canceled)`)
                    // both end the turn normally.
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled,
                                ..
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            ..
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1934
1935 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1936 let Some(send_task) = self.send_task.take() else {
1937 return Task::ready(());
1938 };
1939
1940 for entry in self.entries.iter_mut() {
1941 if let AgentThreadEntry::ToolCall(call) = entry {
1942 let cancel = matches!(
1943 call.status,
1944 ToolCallStatus::Pending
1945 | ToolCallStatus::WaitingForConfirmation { .. }
1946 | ToolCallStatus::InProgress
1947 );
1948
1949 if cancel {
1950 call.status = ToolCallStatus::Canceled;
1951 }
1952 }
1953 }
1954
1955 self.connection.cancel(&self.session_id, cx);
1956
1957 // Wait for the send task to complete
1958 cx.foreground_executor().spawn(send_task)
1959 }
1960
1961 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1962 pub fn restore_checkpoint(
1963 &mut self,
1964 id: UserMessageId,
1965 cx: &mut Context<Self>,
1966 ) -> Task<Result<()>> {
1967 let Some((_, message)) = self.user_message_mut(&id) else {
1968 return Task::ready(Err(anyhow!("message not found")));
1969 };
1970
1971 let checkpoint = message
1972 .checkpoint
1973 .as_ref()
1974 .map(|c| c.git_checkpoint.clone());
1975
1976 // Cancel any in-progress generation before restoring
1977 let cancel_task = self.cancel(cx);
1978 let rewind = self.rewind(id.clone(), cx);
1979 let git_store = self.project.read(cx).git_store().clone();
1980
1981 cx.spawn(async move |_, cx| {
1982 cancel_task.await;
1983 rewind.await?;
1984 if let Some(checkpoint) = checkpoint {
1985 git_store
1986 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1987 .await?;
1988 }
1989
1990 Ok(())
1991 })
1992 }
1993
    /// Rewinds this thread to just before the user message with `id`.
    ///
    /// The steps are:
    /// 1. Ask the connection to truncate the session at `id`.
    /// 2. Remove that message and every later entry, killing any terminals those
    ///    entries owned.
    /// 3. Call `ActionLog::reject_all_edits`. This runs even if the message is no
    ///    longer present locally.
    ///
    /// Unlike `restore_checkpoint`, this method does not restore from git.
    ///
    /// # Errors
    ///
    /// Fails with "not supported" when the connection cannot truncate, and
    /// propagates errors from the truncation itself.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };

        let telemetry = ActionLogTelemetry::from(&*self);
        cx.spawn(async move |this, cx| {
            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    // Collect all terminals from entries that will be removed
                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
                        .iter()
                        .flat_map(|entry| entry.terminals())
                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
                        .collect();

                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));

                    // Kill and remove the terminals
                    for terminal_id in terminals_to_remove {
                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
                            terminal.update(cx, |terminal, cx| {
                                terminal.kill(cx);
                            });
                        }
                    }
                }
                this.action_log().update(cx, |action_log, cx| {
                    action_log.reject_all_edits(Some(telemetry), cx)
                })
            })?
            .await;
            Ok(())
        })
    }
2035
    /// Decides whether the checkpoint of the last user message should be shown.
    ///
    /// A fresh git checkpoint is compared with the one stored on the last user
    /// message. `show` is set when they differ, and `EntryUpdated` is emitted.
    ///
    /// The method does nothing when:
    /// - there is no user message, or
    /// - the last user message has no id or no checkpoint, or
    /// - taking the new checkpoint fails (the failure is logged).
    ///
    /// A failed comparison is treated as "equal", so the checkpoint stays hidden.
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let Some((_, message)) = self.last_user_message() else {
            return Task::ready(Ok(()));
        };
        let Some(user_message_id) = message.id.clone() else {
            return Task::ready(Ok(()));
        };
        let Some(checkpoint) = message.checkpoint.as_ref() else {
            return Task::ready(Ok(()));
        };
        let old_checkpoint = checkpoint.git_checkpoint.clone();

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            let Some(new_checkpoint) = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err()
            else {
                return Ok(());
            };

            let equal = git_store
                .update(cx, |git, cx| {
                    git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                })?
                .await
                .unwrap_or(true);

            // Look the message up by id again: entries may have changed while the
            // checkpoint was being taken and compared.
            this.update(cx, |this, cx| {
                if let Some((ix, message)) = this.user_message_mut(&user_message_id) {
                    if let Some(checkpoint) = message.checkpoint.as_mut() {
                        checkpoint.show = !equal;
                        cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    }
                }
            })?;

            Ok(())
        })
    }
2079
2080 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
2081 self.entries
2082 .iter_mut()
2083 .enumerate()
2084 .rev()
2085 .find_map(|(ix, entry)| {
2086 if let AgentThreadEntry::UserMessage(message) = entry {
2087 Some((ix, message))
2088 } else {
2089 None
2090 }
2091 })
2092 }
2093
2094 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
2095 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
2096 if let AgentThreadEntry::UserMessage(message) = entry {
2097 if message.id.as_ref() == Some(id) {
2098 Some((ix, message))
2099 } else {
2100 None
2101 }
2102 } else {
2103 None
2104 }
2105 })
2106 }
2107
    /// Reads text from the project file at absolute `path` on behalf of the agent.
    ///
    /// Arguments:
    /// - `line` is 1-based; `None` (or `Some(0)`) starts at the first line.
    /// - `limit` is the maximum number of lines to return; `None` means no limit.
    /// - `reuse_shared_snapshot`: when true and a snapshot of the buffer is
    ///   already stored in `shared_buffers`, that snapshot is read. Otherwise the
    ///   buffer is marked as read in the action log, and its current snapshot is
    ///   read and stored in `shared_buffers`.
    ///
    /// The project's agent location is moved to the start of the returned range.
    ///
    /// # Errors
    ///
    /// - `resource_not_found` if `path` is not inside the project.
    /// - `invalid_params` if `line` lies beyond the end of the file.
    /// - Errors from opening the buffer or updating entities.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            // The range covers rows `line..line + limit` (end row exclusive);
            // `saturating_add` keeps the unlimited default from overflowing.
            let start = snapshot.anchor_before(start_position);
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2187
    /// Replaces the full contents of the project file at absolute `path` with
    /// `content`, on behalf of the agent.
    ///
    /// Steps:
    /// 1. Compute a text diff on a background thread, against the snapshot stored
    ///    in `shared_buffers` (falling back to the buffer's current snapshot).
    /// 2. Move the project's agent location to the end of the last edit (or to
    ///    the buffer start if there are no edits).
    /// 3. Apply the edits, recording the buffer as read and then edited in the
    ///    action log.
    /// 4. If the buffer's `format_on_save` setting is not `Off`, format it
    ///    (formatting errors are only logged) and record the edit again.
    /// 5. Save the buffer.
    ///
    /// # Errors
    ///
    /// Fails with "invalid path" when `path` is not inside the project, and
    /// propagates errors from opening, updating, or saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::min_for_buffer(buffer.read(cx).remote_id())),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged, not returned.
                format_task.await.log_err();

                // Record the formatter's changes as agent edits as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2284
2285 pub fn create_terminal(
2286 &self,
2287 command: String,
2288 args: Vec<String>,
2289 extra_env: Vec<acp::EnvVariable>,
2290 cwd: Option<PathBuf>,
2291 output_byte_limit: Option<u64>,
2292 cx: &mut Context<Self>,
2293 ) -> Task<Result<Entity<Terminal>>> {
2294 let env = match &cwd {
2295 Some(dir) => self.project.update(cx, |project, cx| {
2296 project.environment().update(cx, |env, cx| {
2297 env.directory_environment(dir.as_path().into(), cx)
2298 })
2299 }),
2300 None => Task::ready(None).shared(),
2301 };
2302 let env = cx.spawn(async move |_, _| {
2303 let mut env = env.await.unwrap_or_default();
2304 // Disables paging for `git` and hopefully other commands
2305 env.insert("PAGER".into(), "".into());
2306 for var in extra_env {
2307 env.insert(var.name, var.value);
2308 }
2309 env
2310 });
2311
2312 let project = self.project.clone();
2313 let language_registry = project.read(cx).languages().clone();
2314 let is_windows = project.read(cx).path_style(cx).is_windows();
2315
2316 let terminal_id = acp::TerminalId::new(Uuid::new_v4().to_string());
2317 let terminal_task = cx.spawn({
2318 let terminal_id = terminal_id.clone();
2319 async move |_this, cx| {
2320 let env = env.await;
2321 let shell = project
2322 .update(cx, |project, cx| {
2323 project
2324 .remote_client()
2325 .and_then(|r| r.read(cx).default_system_shell())
2326 })?
2327 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2328 let (task_command, task_args) =
2329 ShellBuilder::new(&Shell::Program(shell), is_windows)
2330 .redirect_stdin_to_dev_null()
2331 .build(Some(command.clone()), &args);
2332 let terminal = project
2333 .update(cx, |project, cx| {
2334 project.create_terminal_task(
2335 task::SpawnInTerminal {
2336 command: Some(task_command),
2337 args: task_args,
2338 cwd: cwd.clone(),
2339 env,
2340 ..Default::default()
2341 },
2342 cx,
2343 )
2344 })?
2345 .await?;
2346
2347 cx.new(|cx| {
2348 Terminal::new(
2349 terminal_id,
2350 &format!("{} {}", command, args.join(" ")),
2351 cwd,
2352 output_byte_limit.map(|l| l as usize),
2353 terminal,
2354 language_registry,
2355 cx,
2356 )
2357 })
2358 }
2359 });
2360
2361 cx.spawn(async move |this, cx| {
2362 let terminal = terminal_task.await?;
2363 this.update(cx, |this, _cx| {
2364 this.terminals.insert(terminal_id, terminal.clone());
2365 terminal
2366 })
2367 })
2368 }
2369
2370 pub fn kill_terminal(
2371 &mut self,
2372 terminal_id: acp::TerminalId,
2373 cx: &mut Context<Self>,
2374 ) -> Result<()> {
2375 self.terminals
2376 .get(&terminal_id)
2377 .context("Terminal not found")?
2378 .update(cx, |terminal, cx| {
2379 terminal.kill(cx);
2380 });
2381
2382 Ok(())
2383 }
2384
2385 pub fn release_terminal(
2386 &mut self,
2387 terminal_id: acp::TerminalId,
2388 cx: &mut Context<Self>,
2389 ) -> Result<()> {
2390 self.terminals
2391 .remove(&terminal_id)
2392 .context("Terminal not found")?
2393 .update(cx, |terminal, cx| {
2394 terminal.kill(cx);
2395 });
2396
2397 Ok(())
2398 }
2399
2400 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2401 self.terminals
2402 .get(&terminal_id)
2403 .context("Terminal not found")
2404 .cloned()
2405 }
2406
2407 pub fn to_markdown(&self, cx: &App) -> String {
2408 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2409 }
2410
2411 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2412 cx.emit(AcpThreadEvent::LoadError(error));
2413 }
2414
2415 pub fn register_terminal_created(
2416 &mut self,
2417 terminal_id: acp::TerminalId,
2418 command_label: String,
2419 working_dir: Option<PathBuf>,
2420 output_byte_limit: Option<u64>,
2421 terminal: Entity<::terminal::Terminal>,
2422 cx: &mut Context<Self>,
2423 ) -> Entity<Terminal> {
2424 let language_registry = self.project.read(cx).languages().clone();
2425
2426 let entity = cx.new(|cx| {
2427 Terminal::new(
2428 terminal_id.clone(),
2429 &command_label,
2430 working_dir.clone(),
2431 output_byte_limit.map(|l| l as usize),
2432 terminal,
2433 language_registry,
2434 cx,
2435 )
2436 });
2437 self.terminals.insert(terminal_id.clone(), entity.clone());
2438 entity
2439 }
2440}
2441
2442fn markdown_for_raw_output(
2443 raw_output: &serde_json::Value,
2444 language_registry: &Arc<LanguageRegistry>,
2445 cx: &mut App,
2446) -> Option<Entity<Markdown>> {
2447 match raw_output {
2448 serde_json::Value::Null => None,
2449 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2450 Markdown::new(
2451 value.to_string().into(),
2452 Some(language_registry.clone()),
2453 None,
2454 cx,
2455 )
2456 })),
2457 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2458 Markdown::new(
2459 value.to_string().into(),
2460 Some(language_registry.clone()),
2461 None,
2462 cx,
2463 )
2464 })),
2465 serde_json::Value::String(value) => Some(cx.new(|cx| {
2466 Markdown::new(
2467 value.clone().into(),
2468 Some(language_registry.clone()),
2469 None,
2470 cx,
2471 )
2472 })),
2473 value => Some(cx.new(|cx| {
2474 let pretty_json = to_string_pretty(value).unwrap_or_else(|_| value.to_string());
2475
2476 Markdown::new(
2477 format!("```json\n{}\n```", pretty_json).into(),
2478 Some(language_registry.clone()),
2479 None,
2480 cx,
2481 )
2482 })),
2483 }
2484}
2485
2486#[cfg(test)]
2487mod tests {
2488 use super::*;
2489 use anyhow::anyhow;
2490 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2491 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2492 use indoc::indoc;
2493 use project::{FakeFs, Fs};
2494 use rand::{distr, prelude::*};
2495 use serde_json::json;
2496 use settings::SettingsStore;
2497 use smol::stream::StreamExt as _;
2498 use std::{
2499 any::Any,
2500 cell::RefCell,
2501 path::Path,
2502 rc::Rc,
2503 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2504 time::Duration,
2505 };
2506 use util::path;
2507
2508 fn init_test(cx: &mut TestAppContext) {
2509 env_logger::try_init().ok();
2510 cx.update(|cx| {
2511 let settings_store = SettingsStore::test(cx);
2512 cx.set_global(settings_store);
2513 });
2514 }
2515
2516 #[gpui::test]
2517 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2518 init_test(cx);
2519
2520 let fs = FakeFs::new(cx.executor());
2521 let project = Project::test(fs, [], cx).await;
2522 let connection = Rc::new(FakeAgentConnection::new());
2523 let thread = cx
2524 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2525 .await
2526 .unwrap();
2527
2528 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2529
2530 // Send Output BEFORE Created - should be buffered by acp_thread
2531 thread.update(cx, |thread, cx| {
2532 thread.on_terminal_provider_event(
2533 TerminalProviderEvent::Output {
2534 terminal_id: terminal_id.clone(),
2535 data: b"hello buffered".to_vec(),
2536 },
2537 cx,
2538 );
2539 });
2540
2541 // Create a display-only terminal and then send Created
2542 let lower = cx.new(|cx| {
2543 let builder = ::terminal::TerminalBuilder::new_display_only(
2544 ::terminal::terminal_settings::CursorShape::default(),
2545 ::terminal::terminal_settings::AlternateScroll::On,
2546 None,
2547 0,
2548 )
2549 .unwrap();
2550 builder.subscribe(cx)
2551 });
2552
2553 thread.update(cx, |thread, cx| {
2554 thread.on_terminal_provider_event(
2555 TerminalProviderEvent::Created {
2556 terminal_id: terminal_id.clone(),
2557 label: "Buffered Test".to_string(),
2558 cwd: None,
2559 output_byte_limit: None,
2560 terminal: lower.clone(),
2561 },
2562 cx,
2563 );
2564 });
2565
2566 // After Created, buffered Output should have been flushed into the renderer
2567 let content = thread.read_with(cx, |thread, cx| {
2568 let term = thread.terminal(terminal_id.clone()).unwrap();
2569 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2570 });
2571
2572 assert!(
2573 content.contains("hello buffered"),
2574 "expected buffered output to render, got: {content}"
2575 );
2576 }
2577
2578 #[gpui::test]
2579 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2580 init_test(cx);
2581
2582 let fs = FakeFs::new(cx.executor());
2583 let project = Project::test(fs, [], cx).await;
2584 let connection = Rc::new(FakeAgentConnection::new());
2585 let thread = cx
2586 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2587 .await
2588 .unwrap();
2589
2590 let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
2591
2592 // Send Output BEFORE Created
2593 thread.update(cx, |thread, cx| {
2594 thread.on_terminal_provider_event(
2595 TerminalProviderEvent::Output {
2596 terminal_id: terminal_id.clone(),
2597 data: b"pre-exit data".to_vec(),
2598 },
2599 cx,
2600 );
2601 });
2602
2603 // Send Exit BEFORE Created
2604 thread.update(cx, |thread, cx| {
2605 thread.on_terminal_provider_event(
2606 TerminalProviderEvent::Exit {
2607 terminal_id: terminal_id.clone(),
2608 status: acp::TerminalExitStatus::new().exit_code(0),
2609 },
2610 cx,
2611 );
2612 });
2613
2614 // Now create a display-only lower-level terminal and send Created
2615 let lower = cx.new(|cx| {
2616 let builder = ::terminal::TerminalBuilder::new_display_only(
2617 ::terminal::terminal_settings::CursorShape::default(),
2618 ::terminal::terminal_settings::AlternateScroll::On,
2619 None,
2620 0,
2621 )
2622 .unwrap();
2623 builder.subscribe(cx)
2624 });
2625
2626 thread.update(cx, |thread, cx| {
2627 thread.on_terminal_provider_event(
2628 TerminalProviderEvent::Created {
2629 terminal_id: terminal_id.clone(),
2630 label: "Buffered Exit Test".to_string(),
2631 cwd: None,
2632 output_byte_limit: None,
2633 terminal: lower.clone(),
2634 },
2635 cx,
2636 );
2637 });
2638
2639 // Output should be present after Created (flushed from buffer)
2640 let content = thread.read_with(cx, |thread, cx| {
2641 let term = thread.terminal(terminal_id.clone()).unwrap();
2642 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2643 });
2644
2645 assert!(
2646 content.contains("pre-exit data"),
2647 "expected pre-exit data to render, got: {content}"
2648 );
2649 }
2650
2651 #[gpui::test]
2652 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2653 init_test(cx);
2654
2655 let fs = FakeFs::new(cx.executor());
2656 let project = Project::test(fs, [], cx).await;
2657 let connection = Rc::new(FakeAgentConnection::new());
2658 let thread = cx
2659 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2660 .await
2661 .unwrap();
2662
2663 // Test creating a new user message
2664 thread.update(cx, |thread, cx| {
2665 thread.push_user_content_block(None, "Hello, ".into(), cx);
2666 });
2667
2668 thread.update(cx, |thread, cx| {
2669 assert_eq!(thread.entries.len(), 1);
2670 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2671 assert_eq!(user_msg.id, None);
2672 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2673 } else {
2674 panic!("Expected UserMessage");
2675 }
2676 });
2677
2678 // Test appending to existing user message
2679 let message_1_id = UserMessageId::new();
2680 thread.update(cx, |thread, cx| {
2681 thread.push_user_content_block(Some(message_1_id.clone()), "world!".into(), cx);
2682 });
2683
2684 thread.update(cx, |thread, cx| {
2685 assert_eq!(thread.entries.len(), 1);
2686 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2687 assert_eq!(user_msg.id, Some(message_1_id));
2688 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2689 } else {
2690 panic!("Expected UserMessage");
2691 }
2692 });
2693
2694 // Test creating new user message after assistant message
2695 thread.update(cx, |thread, cx| {
2696 thread.push_assistant_content_block("Assistant response".into(), false, cx);
2697 });
2698
2699 let message_2_id = UserMessageId::new();
2700 thread.update(cx, |thread, cx| {
2701 thread.push_user_content_block(
2702 Some(message_2_id.clone()),
2703 "New user message".into(),
2704 cx,
2705 );
2706 });
2707
2708 thread.update(cx, |thread, cx| {
2709 assert_eq!(thread.entries.len(), 3);
2710 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2711 assert_eq!(user_msg.id, Some(message_2_id));
2712 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2713 } else {
2714 panic!("Expected UserMessage at index 2");
2715 }
2716 });
2717 }
2718
2719 #[gpui::test]
2720 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2721 init_test(cx);
2722
2723 let fs = FakeFs::new(cx.executor());
2724 let project = Project::test(fs, [], cx).await;
2725 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2726 |_, thread, mut cx| {
2727 async move {
2728 thread.update(&mut cx, |thread, cx| {
2729 thread
2730 .handle_session_update(
2731 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2732 "Thinking ".into(),
2733 )),
2734 cx,
2735 )
2736 .unwrap();
2737 thread
2738 .handle_session_update(
2739 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new(
2740 "hard!".into(),
2741 )),
2742 cx,
2743 )
2744 .unwrap();
2745 })?;
2746 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2747 }
2748 .boxed_local()
2749 },
2750 ));
2751
2752 let thread = cx
2753 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2754 .await
2755 .unwrap();
2756
2757 thread
2758 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2759 .await
2760 .unwrap();
2761
2762 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2763 assert_eq!(
2764 output,
2765 indoc! {r#"
2766 ## User
2767
2768 Hello from Zed!
2769
2770 ## Assistant
2771
2772 <thinking>
2773 Thinking hard!
2774 </thinking>
2775
2776 "#}
2777 );
2778 }
2779
2780 #[gpui::test]
2781 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2782 init_test(cx);
2783
2784 let fs = FakeFs::new(cx.executor());
2785 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2786 .await;
2787 let project = Project::test(fs.clone(), [], cx).await;
2788 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2789 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2790 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2791 move |_, thread, mut cx| {
2792 let read_file_tx = read_file_tx.clone();
2793 async move {
2794 let content = thread
2795 .update(&mut cx, |thread, cx| {
2796 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2797 })
2798 .unwrap()
2799 .await
2800 .unwrap();
2801 assert_eq!(content, "one\ntwo\nthree\n");
2802 read_file_tx.take().unwrap().send(()).unwrap();
2803 thread
2804 .update(&mut cx, |thread, cx| {
2805 thread.write_text_file(
2806 path!("/tmp/foo").into(),
2807 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2808 cx,
2809 )
2810 })
2811 .unwrap()
2812 .await
2813 .unwrap();
2814 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
2815 }
2816 .boxed_local()
2817 },
2818 ));
2819
2820 let (worktree, pathbuf) = project
2821 .update(cx, |project, cx| {
2822 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2823 })
2824 .await
2825 .unwrap();
2826 let buffer = project
2827 .update(cx, |project, cx| {
2828 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2829 })
2830 .await
2831 .unwrap();
2832
2833 let thread = cx
2834 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2835 .await
2836 .unwrap();
2837
2838 let request = thread.update(cx, |thread, cx| {
2839 thread.send_raw("Extend the count in /tmp/foo", cx)
2840 });
2841 read_file_rx.await.ok();
2842 buffer.update(cx, |buffer, cx| {
2843 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2844 });
2845 cx.run_until_parked();
2846 assert_eq!(
2847 buffer.read_with(cx, |buffer, _| buffer.text()),
2848 "zero\none\ntwo\nthree\nfour\nfive\n"
2849 );
2850 assert_eq!(
2851 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2852 "zero\none\ntwo\nthree\nfour\nfive\n"
2853 );
2854 request.await.unwrap();
2855 }
2856
2857 #[gpui::test]
2858 async fn test_reading_from_line(cx: &mut TestAppContext) {
2859 init_test(cx);
2860
2861 let fs = FakeFs::new(cx.executor());
2862 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2863 .await;
2864 let project = Project::test(fs.clone(), [], cx).await;
2865 project
2866 .update(cx, |project, cx| {
2867 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2868 })
2869 .await
2870 .unwrap();
2871
2872 let connection = Rc::new(FakeAgentConnection::new());
2873
2874 let thread = cx
2875 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2876 .await
2877 .unwrap();
2878
2879 // Whole file
2880 let content = thread
2881 .update(cx, |thread, cx| {
2882 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2883 })
2884 .await
2885 .unwrap();
2886
2887 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2888
2889 // Only start line
2890 let content = thread
2891 .update(cx, |thread, cx| {
2892 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2893 })
2894 .await
2895 .unwrap();
2896
2897 assert_eq!(content, "three\nfour\n");
2898
2899 // Only limit
2900 let content = thread
2901 .update(cx, |thread, cx| {
2902 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2903 })
2904 .await
2905 .unwrap();
2906
2907 assert_eq!(content, "one\ntwo\n");
2908
2909 // Range
2910 let content = thread
2911 .update(cx, |thread, cx| {
2912 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2913 })
2914 .await
2915 .unwrap();
2916
2917 assert_eq!(content, "two\nthree\n");
2918
2919 // Invalid
2920 let err = thread
2921 .update(cx, |thread, cx| {
2922 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2923 })
2924 .await
2925 .unwrap_err();
2926
2927 assert_eq!(
2928 err.to_string(),
2929 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2930 );
2931 }
2932
2933 #[gpui::test]
2934 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2935 init_test(cx);
2936
2937 let fs = FakeFs::new(cx.executor());
2938 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2939 let project = Project::test(fs.clone(), [], cx).await;
2940 project
2941 .update(cx, |project, cx| {
2942 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2943 })
2944 .await
2945 .unwrap();
2946
2947 let connection = Rc::new(FakeAgentConnection::new());
2948
2949 let thread = cx
2950 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2951 .await
2952 .unwrap();
2953
2954 // Whole file
2955 let content = thread
2956 .update(cx, |thread, cx| {
2957 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2958 })
2959 .await
2960 .unwrap();
2961
2962 assert_eq!(content, "");
2963
2964 // Only start line
2965 let content = thread
2966 .update(cx, |thread, cx| {
2967 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2968 })
2969 .await
2970 .unwrap();
2971
2972 assert_eq!(content, "");
2973
2974 // Only limit
2975 let content = thread
2976 .update(cx, |thread, cx| {
2977 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2978 })
2979 .await
2980 .unwrap();
2981
2982 assert_eq!(content, "");
2983
2984 // Range
2985 let content = thread
2986 .update(cx, |thread, cx| {
2987 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2988 })
2989 .await
2990 .unwrap();
2991
2992 assert_eq!(content, "");
2993
2994 // Invalid
2995 let err = thread
2996 .update(cx, |thread, cx| {
2997 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2998 })
2999 .await
3000 .unwrap_err();
3001
3002 assert_eq!(
3003 err.to_string(),
3004 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
3005 );
3006 }
3007 #[gpui::test]
3008 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
3009 init_test(cx);
3010
3011 let fs = FakeFs::new(cx.executor());
3012 fs.insert_tree(path!("/tmp"), json!({})).await;
3013 let project = Project::test(fs.clone(), [], cx).await;
3014 project
3015 .update(cx, |project, cx| {
3016 project.find_or_create_worktree(path!("/tmp"), true, cx)
3017 })
3018 .await
3019 .unwrap();
3020
3021 let connection = Rc::new(FakeAgentConnection::new());
3022
3023 let thread = cx
3024 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
3025 .await
3026 .unwrap();
3027
3028 // Out of project file
3029 let err = thread
3030 .update(cx, |thread, cx| {
3031 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
3032 })
3033 .await
3034 .unwrap_err();
3035
3036 assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
3037 }
3038
3039 #[gpui::test]
3040 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
3041 init_test(cx);
3042
3043 let fs = FakeFs::new(cx.executor());
3044 let project = Project::test(fs, [], cx).await;
3045 let id = acp::ToolCallId::new("test");
3046
3047 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3048 let id = id.clone();
3049 move |_, thread, mut cx| {
3050 let id = id.clone();
3051 async move {
3052 thread
3053 .update(&mut cx, |thread, cx| {
3054 thread.handle_session_update(
3055 acp::SessionUpdate::ToolCall(
3056 acp::ToolCall::new(id.clone(), "Label")
3057 .kind(acp::ToolKind::Fetch)
3058 .status(acp::ToolCallStatus::InProgress),
3059 ),
3060 cx,
3061 )
3062 })
3063 .unwrap()
3064 .unwrap();
3065 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3066 }
3067 .boxed_local()
3068 }
3069 }));
3070
3071 let thread = cx
3072 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3073 .await
3074 .unwrap();
3075
3076 let request = thread.update(cx, |thread, cx| {
3077 thread.send_raw("Fetch https://example.com", cx)
3078 });
3079
3080 run_until_first_tool_call(&thread, cx).await;
3081
3082 thread.read_with(cx, |thread, _| {
3083 assert!(matches!(
3084 thread.entries[1],
3085 AgentThreadEntry::ToolCall(ToolCall {
3086 status: ToolCallStatus::InProgress,
3087 ..
3088 })
3089 ));
3090 });
3091
3092 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
3093
3094 thread.read_with(cx, |thread, _| {
3095 assert!(matches!(
3096 &thread.entries[1],
3097 AgentThreadEntry::ToolCall(ToolCall {
3098 status: ToolCallStatus::Canceled,
3099 ..
3100 })
3101 ));
3102 });
3103
3104 thread
3105 .update(cx, |thread, cx| {
3106 thread.handle_session_update(
3107 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3108 id,
3109 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3110 )),
3111 cx,
3112 )
3113 })
3114 .unwrap();
3115
3116 request.await.unwrap();
3117
3118 thread.read_with(cx, |thread, _| {
3119 assert!(matches!(
3120 thread.entries[1],
3121 AgentThreadEntry::ToolCall(ToolCall {
3122 status: ToolCallStatus::Completed,
3123 ..
3124 })
3125 ));
3126 });
3127 }
3128
3129 #[gpui::test]
3130 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3131 init_test(cx);
3132 let fs = FakeFs::new(cx.background_executor.clone());
3133 fs.insert_tree(path!("/test"), json!({})).await;
3134 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3135
3136 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3137 move |_, thread, mut cx| {
3138 async move {
3139 thread
3140 .update(&mut cx, |thread, cx| {
3141 thread.handle_session_update(
3142 acp::SessionUpdate::ToolCall(
3143 acp::ToolCall::new("test", "Label")
3144 .kind(acp::ToolKind::Edit)
3145 .status(acp::ToolCallStatus::Completed)
3146 .content(vec![acp::ToolCallContent::Diff(acp::Diff::new(
3147 "/test/test.txt",
3148 "foo",
3149 ))]),
3150 ),
3151 cx,
3152 )
3153 })
3154 .unwrap()
3155 .unwrap();
3156 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3157 }
3158 .boxed_local()
3159 }
3160 }));
3161
3162 let thread = cx
3163 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3164 .await
3165 .unwrap();
3166
3167 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3168 .await
3169 .unwrap();
3170
3171 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3172 }
3173
3174 #[gpui::test(iterations = 10)]
3175 async fn test_checkpoints(cx: &mut TestAppContext) {
3176 init_test(cx);
3177 let fs = FakeFs::new(cx.background_executor.clone());
3178 fs.insert_tree(
3179 path!("/test"),
3180 json!({
3181 ".git": {}
3182 }),
3183 )
3184 .await;
3185 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3186
3187 let simulate_changes = Arc::new(AtomicBool::new(true));
3188 let next_filename = Arc::new(AtomicUsize::new(0));
3189 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3190 let simulate_changes = simulate_changes.clone();
3191 let next_filename = next_filename.clone();
3192 let fs = fs.clone();
3193 move |request, thread, mut cx| {
3194 let fs = fs.clone();
3195 let simulate_changes = simulate_changes.clone();
3196 let next_filename = next_filename.clone();
3197 async move {
3198 if simulate_changes.load(SeqCst) {
3199 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3200 fs.write(Path::new(&filename), b"").await?;
3201 }
3202
3203 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3204 panic!("expected text content block");
3205 };
3206 thread.update(&mut cx, |thread, cx| {
3207 thread
3208 .handle_session_update(
3209 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3210 content.text.to_uppercase().into(),
3211 )),
3212 cx,
3213 )
3214 .unwrap();
3215 })?;
3216 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3217 }
3218 .boxed_local()
3219 }
3220 }));
3221 let thread = cx
3222 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3223 .await
3224 .unwrap();
3225
3226 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3227 .await
3228 .unwrap();
3229 thread.read_with(cx, |thread, cx| {
3230 assert_eq!(
3231 thread.to_markdown(cx),
3232 indoc! {"
3233 ## User (checkpoint)
3234
3235 Lorem
3236
3237 ## Assistant
3238
3239 LOREM
3240
3241 "}
3242 );
3243 });
3244 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3245
3246 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3247 .await
3248 .unwrap();
3249 thread.read_with(cx, |thread, cx| {
3250 assert_eq!(
3251 thread.to_markdown(cx),
3252 indoc! {"
3253 ## User (checkpoint)
3254
3255 Lorem
3256
3257 ## Assistant
3258
3259 LOREM
3260
3261 ## User (checkpoint)
3262
3263 ipsum
3264
3265 ## Assistant
3266
3267 IPSUM
3268
3269 "}
3270 );
3271 });
3272 assert_eq!(
3273 fs.files(),
3274 vec![
3275 Path::new(path!("/test/file-0")),
3276 Path::new(path!("/test/file-1"))
3277 ]
3278 );
3279
3280 // Checkpoint isn't stored when there are no changes.
3281 simulate_changes.store(false, SeqCst);
3282 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3283 .await
3284 .unwrap();
3285 thread.read_with(cx, |thread, cx| {
3286 assert_eq!(
3287 thread.to_markdown(cx),
3288 indoc! {"
3289 ## User (checkpoint)
3290
3291 Lorem
3292
3293 ## Assistant
3294
3295 LOREM
3296
3297 ## User (checkpoint)
3298
3299 ipsum
3300
3301 ## Assistant
3302
3303 IPSUM
3304
3305 ## User
3306
3307 dolor
3308
3309 ## Assistant
3310
3311 DOLOR
3312
3313 "}
3314 );
3315 });
3316 assert_eq!(
3317 fs.files(),
3318 vec![
3319 Path::new(path!("/test/file-0")),
3320 Path::new(path!("/test/file-1"))
3321 ]
3322 );
3323
3324 // Rewinding the conversation truncates the history and restores the checkpoint.
3325 thread
3326 .update(cx, |thread, cx| {
3327 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3328 panic!("unexpected entries {:?}", thread.entries)
3329 };
3330 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3331 })
3332 .await
3333 .unwrap();
3334 thread.read_with(cx, |thread, cx| {
3335 assert_eq!(
3336 thread.to_markdown(cx),
3337 indoc! {"
3338 ## User (checkpoint)
3339
3340 Lorem
3341
3342 ## Assistant
3343
3344 LOREM
3345
3346 "}
3347 );
3348 });
3349 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3350 }
3351
3352 #[gpui::test]
3353 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3354 use std::sync::atomic::AtomicUsize;
3355 init_test(cx);
3356
3357 let fs = FakeFs::new(cx.executor());
3358 let project = Project::test(fs, None, cx).await;
3359
3360 // Create a connection that simulates refusal after tool result
3361 let prompt_count = Arc::new(AtomicUsize::new(0));
3362 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3363 let prompt_count = prompt_count.clone();
3364 move |_request, thread, mut cx| {
3365 let count = prompt_count.fetch_add(1, SeqCst);
3366 async move {
3367 if count == 0 {
3368 // First prompt: Generate a tool call with result
3369 thread.update(&mut cx, |thread, cx| {
3370 thread
3371 .handle_session_update(
3372 acp::SessionUpdate::ToolCall(
3373 acp::ToolCall::new("tool1", "Test Tool")
3374 .kind(acp::ToolKind::Fetch)
3375 .status(acp::ToolCallStatus::Completed)
3376 .raw_input(serde_json::json!({"query": "test"}))
3377 .raw_output(serde_json::json!({"result": "inappropriate content"})),
3378 ),
3379 cx,
3380 )
3381 .unwrap();
3382 })?;
3383
3384 // Now return refusal because of the tool result
3385 Ok(acp::PromptResponse::new(acp::StopReason::Refusal))
3386 } else {
3387 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3388 }
3389 }
3390 .boxed_local()
3391 }
3392 }));
3393
3394 let thread = cx
3395 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3396 .await
3397 .unwrap();
3398
3399 // Track if we see a Refusal event
3400 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3401 let saw_refusal_event_captured = saw_refusal_event.clone();
3402 thread.update(cx, |_thread, cx| {
3403 cx.subscribe(
3404 &thread,
3405 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3406 if matches!(event, AcpThreadEvent::Refusal) {
3407 *saw_refusal_event_captured.lock().unwrap() = true;
3408 }
3409 },
3410 )
3411 .detach();
3412 });
3413
3414 // Send a user message - this will trigger tool call and then refusal
3415 let send_task = thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
3416 cx.background_executor.spawn(send_task).detach();
3417 cx.run_until_parked();
3418
3419 // Verify that:
3420 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3421 // 2. The user message was NOT truncated
3422 assert!(
3423 *saw_refusal_event.lock().unwrap(),
3424 "Refusal event should be emitted for tool result refusals"
3425 );
3426
3427 thread.read_with(cx, |thread, _| {
3428 let entries = thread.entries();
3429 assert!(entries.len() >= 2, "Should have user message and tool call");
3430
3431 // Verify user message is still there
3432 assert!(
3433 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3434 "User message should not be truncated"
3435 );
3436
3437 // Verify tool call is there with result
3438 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3439 assert!(
3440 tool_call.raw_output.is_some(),
3441 "Tool call should have output"
3442 );
3443 } else {
3444 panic!("Expected tool call at index 1");
3445 }
3446 });
3447 }
3448
3449 #[gpui::test]
3450 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3451 init_test(cx);
3452
3453 let fs = FakeFs::new(cx.executor());
3454 let project = Project::test(fs, None, cx).await;
3455
3456 let refuse_next = Arc::new(AtomicBool::new(false));
3457 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3458 let refuse_next = refuse_next.clone();
3459 move |_request, _thread, _cx| {
3460 if refuse_next.load(SeqCst) {
3461 async move { Ok(acp::PromptResponse::new(acp::StopReason::Refusal)) }
3462 .boxed_local()
3463 } else {
3464 async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }
3465 .boxed_local()
3466 }
3467 }
3468 }));
3469
3470 let thread = cx
3471 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3472 .await
3473 .unwrap();
3474
3475 // Track if we see a Refusal event
3476 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3477 let saw_refusal_event_captured = saw_refusal_event.clone();
3478 thread.update(cx, |_thread, cx| {
3479 cx.subscribe(
3480 &thread,
3481 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3482 if matches!(event, AcpThreadEvent::Refusal) {
3483 *saw_refusal_event_captured.lock().unwrap() = true;
3484 }
3485 },
3486 )
3487 .detach();
3488 });
3489
3490 // Send a message that will be refused
3491 refuse_next.store(true, SeqCst);
3492 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3493 .await
3494 .unwrap();
3495
3496 // Verify that a Refusal event WAS emitted for user prompt refusal
3497 assert!(
3498 *saw_refusal_event.lock().unwrap(),
3499 "Refusal event should be emitted for user prompt refusals"
3500 );
3501
3502 // Verify the message was truncated (user prompt refusal)
3503 thread.read_with(cx, |thread, cx| {
3504 assert_eq!(thread.to_markdown(cx), "");
3505 });
3506 }
3507
3508 #[gpui::test]
3509 async fn test_refusal(cx: &mut TestAppContext) {
3510 init_test(cx);
3511 let fs = FakeFs::new(cx.background_executor.clone());
3512 fs.insert_tree(path!("/"), json!({})).await;
3513 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3514
3515 let refuse_next = Arc::new(AtomicBool::new(false));
3516 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3517 let refuse_next = refuse_next.clone();
3518 move |request, thread, mut cx| {
3519 let refuse_next = refuse_next.clone();
3520 async move {
3521 if refuse_next.load(SeqCst) {
3522 return Ok(acp::PromptResponse::new(acp::StopReason::Refusal));
3523 }
3524
3525 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3526 panic!("expected text content block");
3527 };
3528 thread.update(&mut cx, |thread, cx| {
3529 thread
3530 .handle_session_update(
3531 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
3532 content.text.to_uppercase().into(),
3533 )),
3534 cx,
3535 )
3536 .unwrap();
3537 })?;
3538 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
3539 }
3540 .boxed_local()
3541 }
3542 }));
3543 let thread = cx
3544 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3545 .await
3546 .unwrap();
3547
3548 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3549 .await
3550 .unwrap();
3551 thread.read_with(cx, |thread, cx| {
3552 assert_eq!(
3553 thread.to_markdown(cx),
3554 indoc! {"
3555 ## User
3556
3557 hello
3558
3559 ## Assistant
3560
3561 HELLO
3562
3563 "}
3564 );
3565 });
3566
3567 // Simulate refusing the second message. The message should be truncated
3568 // when a user prompt is refused.
3569 refuse_next.store(true, SeqCst);
3570 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3571 .await
3572 .unwrap();
3573 thread.read_with(cx, |thread, cx| {
3574 assert_eq!(
3575 thread.to_markdown(cx),
3576 indoc! {"
3577 ## User
3578
3579 hello
3580
3581 ## Assistant
3582
3583 HELLO
3584
3585 "}
3586 );
3587 });
3588 }
3589
3590 async fn run_until_first_tool_call(
3591 thread: &Entity<AcpThread>,
3592 cx: &mut TestAppContext,
3593 ) -> usize {
3594 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3595
3596 let subscription = cx.update(|cx| {
3597 cx.subscribe(thread, move |thread, _, cx| {
3598 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3599 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3600 return tx.try_send(ix).unwrap();
3601 }
3602 }
3603 })
3604 });
3605
3606 select! {
3607 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3608 panic!("Timeout waiting for tool call")
3609 }
3610 ix = rx.next().fuse() => {
3611 drop(subscription);
3612 ix.unwrap()
3613 }
3614 }
3615 }
3616
3617 #[derive(Clone, Default)]
3618 struct FakeAgentConnection {
3619 auth_methods: Vec<acp::AuthMethod>,
3620 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3621 on_user_message: Option<
3622 Rc<
3623 dyn Fn(
3624 acp::PromptRequest,
3625 WeakEntity<AcpThread>,
3626 AsyncApp,
3627 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3628 + 'static,
3629 >,
3630 >,
3631 }
3632
3633 impl FakeAgentConnection {
3634 fn new() -> Self {
3635 Self {
3636 auth_methods: Vec::new(),
3637 on_user_message: None,
3638 sessions: Arc::default(),
3639 }
3640 }
3641
3642 #[expect(unused)]
3643 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3644 self.auth_methods = auth_methods;
3645 self
3646 }
3647
3648 fn on_user_message(
3649 mut self,
3650 handler: impl Fn(
3651 acp::PromptRequest,
3652 WeakEntity<AcpThread>,
3653 AsyncApp,
3654 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3655 + 'static,
3656 ) -> Self {
3657 self.on_user_message.replace(Rc::new(handler));
3658 self
3659 }
3660 }
3661
3662 impl AgentConnection for FakeAgentConnection {
3663 fn telemetry_id(&self) -> SharedString {
3664 "fake".into()
3665 }
3666
3667 fn auth_methods(&self) -> &[acp::AuthMethod] {
3668 &self.auth_methods
3669 }
3670
3671 fn new_thread(
3672 self: Rc<Self>,
3673 project: Entity<Project>,
3674 _cwd: &Path,
3675 cx: &mut App,
3676 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3677 let session_id = acp::SessionId::new(
3678 rand::rng()
3679 .sample_iter(&distr::Alphanumeric)
3680 .take(7)
3681 .map(char::from)
3682 .collect::<String>(),
3683 );
3684 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3685 let thread = cx.new(|cx| {
3686 AcpThread::new(
3687 "Test",
3688 self.clone(),
3689 project,
3690 action_log,
3691 session_id.clone(),
3692 watch::Receiver::constant(
3693 acp::PromptCapabilities::new()
3694 .image(true)
3695 .audio(true)
3696 .embedded_context(true),
3697 ),
3698 cx,
3699 )
3700 });
3701 self.sessions.lock().insert(session_id, thread.downgrade());
3702 Task::ready(Ok(thread))
3703 }
3704
3705 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3706 if self.auth_methods().iter().any(|m| m.id == method) {
3707 Task::ready(Ok(()))
3708 } else {
3709 Task::ready(Err(anyhow!("Invalid Auth Method")))
3710 }
3711 }
3712
3713 fn prompt(
3714 &self,
3715 _id: Option<UserMessageId>,
3716 params: acp::PromptRequest,
3717 cx: &mut App,
3718 ) -> Task<gpui::Result<acp::PromptResponse>> {
3719 let sessions = self.sessions.lock();
3720 let thread = sessions.get(¶ms.session_id).unwrap();
3721 if let Some(handler) = &self.on_user_message {
3722 let handler = handler.clone();
3723 let thread = thread.clone();
3724 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3725 } else {
3726 Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
3727 }
3728 }
3729
3730 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3731 let sessions = self.sessions.lock();
3732 let thread = sessions.get(session_id).unwrap().clone();
3733
3734 cx.spawn(async move |cx| {
3735 thread
3736 .update(cx, |thread, cx| thread.cancel(cx))
3737 .unwrap()
3738 .await
3739 })
3740 .detach();
3741 }
3742
3743 fn truncate(
3744 &self,
3745 session_id: &acp::SessionId,
3746 _cx: &App,
3747 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3748 Some(Rc::new(FakeAgentSessionEditor {
3749 _session_id: session_id.clone(),
3750 }))
3751 }
3752
3753 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3754 self
3755 }
3756 }
3757
3758 struct FakeAgentSessionEditor {
3759 _session_id: acp::SessionId,
3760 }
3761
3762 impl AgentSessionTruncate for FakeAgentSessionEditor {
3763 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3764 Task::ready(Ok(()))
3765 }
3766 }
3767
3768 #[gpui::test]
3769 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3770 init_test(cx);
3771
3772 let fs = FakeFs::new(cx.executor());
3773 let project = Project::test(fs, [], cx).await;
3774 let connection = Rc::new(FakeAgentConnection::new());
3775 let thread = cx
3776 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3777 .await
3778 .unwrap();
3779
3780 // Try to update a tool call that doesn't exist
3781 let nonexistent_id = acp::ToolCallId::new("nonexistent-tool-call");
3782 thread.update(cx, |thread, cx| {
3783 let result = thread.handle_session_update(
3784 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new(
3785 nonexistent_id.clone(),
3786 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed),
3787 )),
3788 cx,
3789 );
3790
3791 // The update should succeed (not return an error)
3792 assert!(result.is_ok());
3793
3794 // There should now be exactly one entry in the thread
3795 assert_eq!(thread.entries.len(), 1);
3796
3797 // The entry should be a failed tool call
3798 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3799 assert_eq!(tool_call.id, nonexistent_id);
3800 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3801 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3802
3803 // Check that the content contains the error message
3804 assert_eq!(tool_call.content.len(), 1);
3805 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3806 match content_block {
3807 ContentBlock::Markdown { markdown } => {
3808 let markdown_text = markdown.read(cx).source();
3809 assert!(markdown_text.contains("Tool call not found"));
3810 }
3811 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3812 ContentBlock::ResourceLink { .. } => {
3813 panic!("Expected markdown content, got resource link")
3814 }
3815 ContentBlock::Image { .. } => {
3816 panic!("Expected markdown content, got image")
3817 }
3818 }
3819 } else {
3820 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3821 }
3822 } else {
3823 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3824 }
3825 });
3826 }
3827
    /// Tests that restoring a checkpoint properly cleans up terminals that were
    /// created after that checkpoint, and cancels any in-progress generation.
    ///
    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
    /// that were started after that checkpoint should be terminated, and any in-progress
    /// AI generation should be canceled.
    ///
    /// Setup: two user messages are sent. Then two terminals run to completion without
    /// being attached to any entry. Finally a third, still-running terminal is attached
    /// to a tool call entry that follows the second message. Restoring to the second
    /// message must drop that tool call and its terminal while keeping the two finished,
    /// unreferenced terminals. This test checks removal from `thread.terminals`; it does
    /// not observe an OS process being killed.
    #[gpui::test]
    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Send first user message to create a checkpoint
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["first message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Send second message (creates another checkpoint) - we'll restore to this one
        cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                thread.send(vec!["second message".into()], cx)
            })
        })
        .await
        .unwrap();

        // Create 2 terminals that run to completion. Like the third terminal below, they
        // are created after both messages were sent, but unlike it they are never
        // attached to a tool call entry, and the test expects them to survive the restore.
        let terminal_id_1 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_1 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_1.clone(),
                    label: "echo 'first'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_1.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_1.clone(),
                    data: b"first\n".to_vec(),
                },
                cx,
            );
        });

        // Report the terminal as exited; the assertions at the end rely on this making
        // its `output()` return `Some` (i.e. "completed").
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_1.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Second completed terminal, set up the same way as the first.
        let terminal_id_2 = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal_2 = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id_2.clone(),
                    label: "echo 'second'".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal_2.clone(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id_2.clone(),
                    data: b"second\n".to_vec(),
                },
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Exit {
                    terminal_id: terminal_id_2.clone(),
                    status: acp::TerminalExitStatus::new().exit_code(0),
                },
                cx,
            );
        });

        // Get the second message ID to restore to
        let second_message_id = thread.read_with(cx, |thread, _| {
            // At this point we have:
            // - Index 0: First user message (with checkpoint)
            // - Index 1: Second user message (with checkpoint)
            // No assistant responses because FakeAgentConnection just returns EndTurn
            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
                panic!("expected user message at index 1");
            };
            message.id.clone().unwrap()
        });

        // Create a third terminal that is still running and that the tool call entry
        // added below (after the second user message) refers to.
        // This simulates the AI agent starting a long-running terminal command.
        let terminal_id = acp::TerminalId::new(uuid::Uuid::new_v4().to_string());
        let mock_terminal = cx.new(|cx| {
            let builder = ::terminal::TerminalBuilder::new_display_only(
                ::terminal::terminal_settings::CursorShape::default(),
                ::terminal::terminal_settings::AlternateScroll::On,
                None,
                0,
            )
            .unwrap();
            builder.subscribe(cx)
        });

        // Register the terminal as created
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Created {
                    terminal_id: terminal_id.clone(),
                    label: "sleep 1000".to_string(),
                    cwd: Some(PathBuf::from("/test")),
                    output_byte_limit: None,
                    terminal: mock_terminal.clone(),
                },
                cx,
            );
        });

        // Simulate the terminal producing output (still running: no Exit event is sent)
        thread.update(cx, |thread, cx| {
            thread.on_terminal_provider_event(
                TerminalProviderEvent::Output {
                    terminal_id: terminal_id.clone(),
                    data: b"terminal is running...\n".to_vec(),
                },
                cx,
            );
        });

        // Create a tool call entry (index 2) that references this terminal.
        // This represents the agent requesting a terminal command.
        thread.update(cx, |thread, cx| {
            thread
                .handle_session_update(
                    acp::SessionUpdate::ToolCall(
                        acp::ToolCall::new("terminal-tool-1", "Running command")
                            .kind(acp::ToolKind::Execute)
                            .status(acp::ToolCallStatus::InProgress)
                            .content(vec![acp::ToolCallContent::Terminal(acp::Terminal::new(
                                terminal_id.clone(),
                            ))])
                            .raw_input(serde_json::json!({"command": "sleep 1000", "cd": "/test"})),
                    ),
                    cx,
                )
                .unwrap();
        });

        // Verify terminal exists and is in the thread
        let terminal_exists_before =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            terminal_exists_before,
            "Terminal should exist before checkpoint restore"
        );

        // Verify the terminal's underlying task is still running (not completed)
        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
            terminal_entity.read_with(cx, |term, _cx| {
                term.output().is_none() // output is None means it's still running
            })
        });
        assert!(
            terminal_running_before,
            "Terminal should be running before checkpoint restore"
        );

        // Verify we have the expected entries before restore
        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
        assert!(
            entry_count_before > 1,
            "Should have multiple entries before restore"
        );

        // Restore the checkpoint to the second message.
        // This should:
        // 1. Cancel any in-progress generation
        // 2. Remove the tool call entry after that point, together with its terminal
        thread
            .update(cx, |thread, cx| {
                thread.restore_checkpoint(second_message_id, cx)
            })
            .await
            .unwrap();

        // Verify that no send_task is in progress after restore
        // (NOTE(review): presumably restore_checkpoint calls cancel(), which clears
        // send_task — confirm in restore_checkpoint/cancel)
        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
        assert!(
            !has_send_task_after,
            "Should not have a send_task after restore (cancel should have cleared it)"
        );

        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
        assert_eq!(
            entry_count, 1,
            "Should have 1 entry after restore (only the first user message)"
        );

        // Verify the 2 completed, unreferenced terminals still exist
        let terminal_1_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_1)
        });
        assert!(
            terminal_1_exists,
            "Terminal 1 (from before checkpoint) should still exist"
        );

        let terminal_2_exists = thread.read_with(cx, |thread, _| {
            thread.terminals.contains_key(&terminal_id_2)
        });
        assert!(
            terminal_2_exists,
            "Terminal 2 (from before checkpoint) should still exist"
        );

        // Verify they're still in completed state
        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_1_completed, "Terminal 1 should still be completed");

        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
        });
        assert!(terminal_2_completed, "Terminal 2 should still be completed");

        // Verify the running terminal (`terminal_id`, referenced by the removed tool call) was removed
        let terminal_3_exists =
            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
        assert!(
            !terminal_3_exists,
            "Terminal 3 (created after checkpoint) should have been removed"
        );

        // Verify total count is 2 (the two completed, unreferenced terminals)
        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
        assert_eq!(
            terminal_count, 2,
            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
        );
    }
4123
    /// Tests that update_last_checkpoint correctly updates the original message's checkpoint
    /// even when a new user message is added while the async checkpoint comparison is in progress.
    ///
    /// This is a regression test for a bug where update_last_checkpoint would fail with
    /// "no checkpoint" if a new user message (without a checkpoint) was added between when
    /// update_last_checkpoint started and when its async closure ran.
    ///
    /// The test steps the executor by hand, so the exact ordering of ticks below matters.
    #[gpui::test]
    async fn test_update_last_checkpoint_with_new_message_added(cx: &mut TestAppContext) {
        init_test(cx);

        // The project root contains a `.git` directory.
        // NOTE(review): presumably this is what lets git checkpoints be taken — confirm.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), json!({".git": {}, "file.txt": "content"}))
            .await;
        let project = Project::test(fs.clone(), [Path::new(path!("/test"))], cx).await;

        // Set synchronously when the fake agent's prompt handler is invoked, so the
        // test can tell when the send has reached the agent.
        let handler_done = Arc::new(AtomicBool::new(false));
        let handler_done_clone = handler_done.clone();
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, _thread, _cx| {
                handler_done_clone.store(true, SeqCst);
                async move { Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }.boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Run the send on the background executor so the test can interleave its own
        // work between executor ticks.
        let send_future = thread.update(cx, |thread, cx| thread.send_raw("First message", cx));
        let send_task = cx.background_executor.spawn(send_future);

        // Tick until handler completes, then a few more to let update_last_checkpoint start
        // (NOTE(review): the count of 5 appears to be empirical — it is not derived from
        // anything visible here.)
        while !handler_done.load(SeqCst) {
            cx.executor().tick();
        }
        for _ in 0..5 {
            cx.executor().tick();
        }

        // Inject a user message with no checkpoint while the checkpoint update may still
        // be in flight; this is the interleaving the regression test targets.
        thread.update(cx, |thread, cx| {
            thread.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: Some(UserMessageId::new()),
                    content: ContentBlock::Empty,
                    chunks: vec!["Injected message (no checkpoint)".into()],
                    checkpoint: None,
                    indented: false,
                }),
                cx,
            );
        });

        // Let everything settle, then the send must still have succeeded.
        cx.run_until_parked();
        let result = send_task.await;

        assert!(
            result.is_ok(),
            "send should succeed even when new message added during update_last_checkpoint: {:?}",
            result.err()
        );
    }
4186}