1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use ::terminal::terminal_settings::TerminalSettings;
7use agent_settings::AgentSettings;
8use collections::HashSet;
9pub use connection::*;
10pub use diff::*;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::ActionLog;
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell};
39use uuid::Uuid;
40
/// A user-authored entry in an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one has been assigned.
    pub id: Option<UserMessageId>,
    /// Accumulated content of every chunk appended to this message.
    pub content: ContentBlock,
    /// The individual protocol chunks, in the order they were pushed.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
48
/// A git-store checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint handle.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the owning message is rendered with a "## User (checkpoint)"
    /// heading in its markdown form.
    pub show: bool,
}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
/// An assistant-authored entry in an agent thread.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in the order they were received.
    pub chunks: Vec<AssistantMessageChunk>,
}
78
79impl AssistantMessage {
80 pub fn to_markdown(&self, cx: &App) -> String {
81 format!(
82 "## Assistant\n\n{}\n\n",
83 self.chunks
84 .iter()
85 .map(|chunk| chunk.to_markdown(cx))
86 .join("\n\n")
87 )
88 }
89}
90
/// One contiguous piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered inside `<thinking>` tags in markdown.
    Thought { block: ContentBlock },
}
96
97impl AssistantMessageChunk {
98 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Message {
100 block: ContentBlock::new(chunk.into(), language_registry, cx),
101 }
102 }
103
104 fn to_markdown(&self, cx: &App) -> String {
105 match self {
106 Self::Message { block } => block.to_markdown(cx).to_string(),
107 Self::Thought { block } => {
108 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
109 }
110 }
111 }
112}
113
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message authored by the user.
    UserMessage(UserMessage),
    /// A message authored by the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool call.
    ToolCall(ToolCall),
}
120
121impl AgentThreadEntry {
122 pub fn to_markdown(&self, cx: &App) -> String {
123 match self {
124 Self::UserMessage(message) => message.to_markdown(cx),
125 Self::AssistantMessage(message) => message.to_markdown(cx),
126 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
127 }
128 }
129
130 pub fn user_message(&self) -> Option<&UserMessage> {
131 if let AgentThreadEntry::UserMessage(message) = self {
132 Some(message)
133 } else {
134 None
135 }
136 }
137
138 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.diffs())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
147 if let AgentThreadEntry::ToolCall(call) = self {
148 itertools::Either::Left(call.terminals())
149 } else {
150 itertools::Either::Right(std::iter::empty())
151 }
152 }
153
154 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
155 if let AgentThreadEntry::ToolCall(ToolCall {
156 locations,
157 resolved_locations,
158 ..
159 }) = self
160 {
161 Some((
162 locations.get(ix)?.clone(),
163 resolved_locations.get(ix)?.clone()?,
164 ))
165 } else {
166 None
167 }
168 }
169}
170
/// A tool call shown in an agent thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the tool call's title; a multi-line title is
    /// cut to its first line followed by "…".
    pub label: Entity<Markdown>,
    /// Protocol-level kind of the tool.
    pub kind: acp::ToolKind,
    /// Content items (content blocks, diffs, terminals) attached to the call.
    pub content: Vec<ToolCallContent>,
    /// Current status of the call.
    pub status: ToolCallStatus,
    /// Locations reported by the protocol for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`; `AgentThreadEntry::location`
    /// reads both lists with the same index.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
183
184impl ToolCall {
185 fn from_acp(
186 tool_call: acp::ToolCall,
187 status: ToolCallStatus,
188 language_registry: Arc<LanguageRegistry>,
189 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
190 cx: &mut App,
191 ) -> Result<Self> {
192 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
193 first_line.to_owned() + "…"
194 } else {
195 tool_call.title
196 };
197 let mut content = Vec::with_capacity(tool_call.content.len());
198 for item in tool_call.content {
199 content.push(ToolCallContent::from_acp(
200 item,
201 language_registry.clone(),
202 terminals,
203 cx,
204 )?);
205 }
206
207 let result = Self {
208 id: tool_call.id,
209 label: cx
210 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
211 kind: tool_call.kind,
212 content,
213 locations: tool_call.locations,
214 resolved_locations: Vec::default(),
215 status,
216 raw_input: tool_call.raw_input,
217 raw_output: tool_call.raw_output,
218 };
219 Ok(result)
220 }
221
222 fn update_fields(
223 &mut self,
224 fields: acp::ToolCallUpdateFields,
225 language_registry: Arc<LanguageRegistry>,
226 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
227 cx: &mut App,
228 ) -> Result<()> {
229 let acp::ToolCallUpdateFields {
230 kind,
231 status,
232 title,
233 content,
234 locations,
235 raw_input,
236 raw_output,
237 } = fields;
238
239 if let Some(kind) = kind {
240 self.kind = kind;
241 }
242
243 if let Some(status) = status {
244 self.status = status.into();
245 }
246
247 if let Some(title) = title {
248 self.label.update(cx, |label, cx| {
249 if let Some((first_line, _)) = title.split_once("\n") {
250 label.replace(first_line.to_owned() + "…", cx)
251 } else {
252 label.replace(title, cx);
253 }
254 });
255 }
256
257 if let Some(content) = content {
258 let new_content_len = content.len();
259 let mut content = content.into_iter();
260
261 // Reuse existing content if we can
262 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
263 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
264 }
265 for new in content {
266 self.content.push(ToolCallContent::from_acp(
267 new,
268 language_registry.clone(),
269 terminals,
270 cx,
271 )?)
272 }
273 self.content.truncate(new_content_len);
274 }
275
276 if let Some(locations) = locations {
277 self.locations = locations;
278 }
279
280 if let Some(raw_input) = raw_input {
281 self.raw_input = Some(raw_input);
282 }
283
284 if let Some(raw_output) = raw_output {
285 if self.content.is_empty()
286 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
287 {
288 self.content
289 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
290 markdown,
291 }));
292 }
293 self.raw_output = Some(raw_output);
294 }
295 Ok(())
296 }
297
298 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
299 self.content.iter().filter_map(|content| match content {
300 ToolCallContent::Diff(diff) => Some(diff),
301 ToolCallContent::ContentBlock(_) => None,
302 ToolCallContent::Terminal(_) => None,
303 })
304 }
305
306 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
307 self.content.iter().filter_map(|content| match content {
308 ToolCallContent::Terminal(terminal) => Some(terminal),
309 ToolCallContent::ContentBlock(_) => None,
310 ToolCallContent::Diff(_) => None,
311 })
312 }
313
314 fn to_markdown(&self, cx: &App) -> String {
315 let mut markdown = format!(
316 "**Tool Call: {}**\nStatus: {}\n\n",
317 self.label.read(cx).source(),
318 self.status
319 );
320 for content in &self.content {
321 markdown.push_str(content.to_markdown(cx).as_str());
322 markdown.push_str("\n\n");
323 }
324 markdown
325 }
326
327 async fn resolve_location(
328 location: acp::ToolCallLocation,
329 project: WeakEntity<Project>,
330 cx: &mut AsyncApp,
331 ) -> Option<AgentLocation> {
332 let buffer = project
333 .update(cx, |project, cx| {
334 project
335 .project_path_for_absolute_path(&location.path, cx)
336 .map(|path| project.open_buffer(path, cx))
337 })
338 .ok()??;
339 let buffer = buffer.await.log_err()?;
340 let position = buffer
341 .update(cx, |buffer, _| {
342 if let Some(row) = location.line {
343 let snapshot = buffer.snapshot();
344 let column = snapshot.indent_size_for_line(row).len;
345 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
346 snapshot.anchor_before(point)
347 } else {
348 Anchor::MIN
349 }
350 })
351 .ok()?;
352
353 Some(AgentLocation {
354 buffer: buffer.downgrade(),
355 position,
356 })
357 }
358
359 fn resolve_locations(
360 &self,
361 project: Entity<Project>,
362 cx: &mut App,
363 ) -> Task<Vec<Option<AgentLocation>>> {
364 let locations = self.locations.clone();
365 project.update(cx, |_, cx| {
366 cx.spawn(async move |project, cx| {
367 let mut new_locations = Vec::new();
368 for location in locations {
369 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
370 }
371 new_locations
372 })
373 })
374 }
375}
376
/// Status of a [`ToolCall`] as tracked by the thread.
///
/// `Pending`, `InProgress`, `Completed` and `Failed` have protocol
/// counterparts (see the `From<acp::ToolCallStatus>` impl); the remaining
/// variants do not.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// Sender for the id of a permission option.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
398
399impl From<acp::ToolCallStatus> for ToolCallStatus {
400 fn from(status: acp::ToolCallStatus) -> Self {
401 match status {
402 acp::ToolCallStatus::Pending => Self::Pending,
403 acp::ToolCallStatus::InProgress => Self::InProgress,
404 acp::ToolCallStatus::Completed => Self::Completed,
405 acp::ToolCallStatus::Failed => Self::Failed,
406 }
407 }
408}
409
410impl Display for ToolCallStatus {
411 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
412 write!(
413 f,
414 "{}",
415 match self {
416 ToolCallStatus::Pending => "Pending",
417 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
418 ToolCallStatus::InProgress => "In Progress",
419 ToolCallStatus::Completed => "Completed",
420 ToolCallStatus::Failed => "Failed",
421 ToolCallStatus::Rejected => "Rejected",
422 ToolCallStatus::Canceled => "Canceled",
423 }
424 )
425 }
426}
427
/// Displayable content accumulated from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text held in a `Markdown` entity.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link; `append` converts it to `Markdown` as soon as
    /// anything else is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
434
435impl ContentBlock {
436 pub fn new(
437 block: acp::ContentBlock,
438 language_registry: &Arc<LanguageRegistry>,
439 cx: &mut App,
440 ) -> Self {
441 let mut this = Self::Empty;
442 this.append(block, language_registry, cx);
443 this
444 }
445
446 pub fn new_combined(
447 blocks: impl IntoIterator<Item = acp::ContentBlock>,
448 language_registry: Arc<LanguageRegistry>,
449 cx: &mut App,
450 ) -> Self {
451 let mut this = Self::Empty;
452 for block in blocks {
453 this.append(block, &language_registry, cx);
454 }
455 this
456 }
457
458 pub fn append(
459 &mut self,
460 block: acp::ContentBlock,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) {
464 if matches!(self, ContentBlock::Empty)
465 && let acp::ContentBlock::ResourceLink(resource_link) = block
466 {
467 *self = ContentBlock::ResourceLink { resource_link };
468 return;
469 }
470
471 let new_content = self.block_string_contents(block);
472
473 match self {
474 ContentBlock::Empty => {
475 *self = Self::create_markdown_block(new_content, language_registry, cx);
476 }
477 ContentBlock::Markdown { markdown } => {
478 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
479 }
480 ContentBlock::ResourceLink { resource_link } => {
481 let existing_content = Self::resource_link_md(&resource_link.uri);
482 let combined = format!("{}\n{}", existing_content, new_content);
483
484 *self = Self::create_markdown_block(combined, language_registry, cx);
485 }
486 }
487 }
488
489 fn create_markdown_block(
490 content: String,
491 language_registry: &Arc<LanguageRegistry>,
492 cx: &mut App,
493 ) -> ContentBlock {
494 ContentBlock::Markdown {
495 markdown: cx
496 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
497 }
498 }
499
500 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
501 match block {
502 acp::ContentBlock::Text(text_content) => text_content.text,
503 acp::ContentBlock::ResourceLink(resource_link) => {
504 Self::resource_link_md(&resource_link.uri)
505 }
506 acp::ContentBlock::Resource(acp::EmbeddedResource {
507 resource:
508 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
509 uri,
510 ..
511 }),
512 ..
513 }) => Self::resource_link_md(&uri),
514 acp::ContentBlock::Image(image) => Self::image_md(&image),
515 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
516 }
517 }
518
519 fn resource_link_md(uri: &str) -> String {
520 if let Some(uri) = MentionUri::parse(uri).log_err() {
521 uri.as_link().to_string()
522 } else {
523 uri.to_string()
524 }
525 }
526
527 fn image_md(_image: &acp::ImageContent) -> String {
528 "`Image`".into()
529 }
530
531 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
532 match self {
533 ContentBlock::Empty => "",
534 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
535 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
536 }
537 }
538
539 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
540 match self {
541 ContentBlock::Empty => None,
542 ContentBlock::Markdown { markdown } => Some(markdown),
543 ContentBlock::ResourceLink { .. } => None,
544 }
545 }
546
547 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
548 match self {
549 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
550 _ => None,
551 }
552 }
553}
554
/// One content item attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Plain content (markdown text or a resource link).
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
561
562impl ToolCallContent {
563 pub fn from_acp(
564 content: acp::ToolCallContent,
565 language_registry: Arc<LanguageRegistry>,
566 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
567 cx: &mut App,
568 ) -> Result<Self> {
569 match content {
570 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
571 content,
572 &language_registry,
573 cx,
574 ))),
575 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
576 Diff::finalized(
577 diff.path.to_string_lossy().into_owned(),
578 diff.old_text,
579 diff.new_text,
580 language_registry,
581 cx,
582 )
583 }))),
584 acp::ToolCallContent::Terminal { terminal_id } => terminals
585 .get(&terminal_id)
586 .cloned()
587 .map(Self::Terminal)
588 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
589 }
590 }
591
592 pub fn update_from_acp(
593 &mut self,
594 new: acp::ToolCallContent,
595 language_registry: Arc<LanguageRegistry>,
596 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
597 cx: &mut App,
598 ) -> Result<()> {
599 let needs_update = match (&self, &new) {
600 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
601 old_diff.read(cx).needs_update(
602 new_diff.old_text.as_deref().unwrap_or(""),
603 &new_diff.new_text,
604 cx,
605 )
606 }
607 _ => true,
608 };
609
610 if needs_update {
611 *self = Self::from_acp(new, language_registry, terminals, cx)?;
612 }
613 Ok(())
614 }
615
616 pub fn to_markdown(&self, cx: &App) -> String {
617 match self {
618 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
619 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
620 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
621 }
622 }
623}
624
/// An update addressed to an existing tool call, identified by its id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level field update.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
631
632impl ToolCallUpdate {
633 fn id(&self) -> &acp::ToolCallId {
634 match self {
635 Self::UpdateFields(update) => &update.id,
636 Self::UpdateDiff(diff) => &diff.id,
637 Self::UpdateTerminal(terminal) => &terminal.id,
638 }
639 }
640}
641
642impl From<acp::ToolCallUpdate> for ToolCallUpdate {
643 fn from(update: acp::ToolCallUpdate) -> Self {
644 Self::UpdateFields(update)
645 }
646}
647
648impl From<ToolCallUpdateDiff> for ToolCallUpdate {
649 fn from(diff: ToolCallUpdateDiff) -> Self {
650 Self::UpdateDiff(diff)
651 }
652}
653
/// A tool call update that carries a diff entity.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
659
660impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
661 fn from(terminal: ToolCallUpdateTerminal) -> Self {
662 Self::UpdateTerminal(terminal)
663 }
664}
665
/// A tool call update that carries a terminal entity.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
671
/// The plan of a thread: an ordered list of entries. Defaults to empty.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan's entries.
    pub entries: Vec<PlanEntry>,
}
676
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
683
684impl Plan {
685 pub fn is_empty(&self) -> bool {
686 self.entries.is_empty()
687 }
688
689 pub fn stats(&self) -> PlanStats<'_> {
690 let mut stats = PlanStats {
691 in_progress_entry: None,
692 pending: 0,
693 completed: 0,
694 };
695
696 for entry in &self.entries {
697 match &entry.status {
698 acp::PlanEntryStatus::Pending => {
699 stats.pending += 1;
700 }
701 acp::PlanEntryStatus::InProgress => {
702 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
703 }
704 acp::PlanEntryStatus::Completed => {
705 stats.completed += 1;
706 }
707 }
708 }
709
710 stats
711 }
712}
713
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's content, held in a `Markdown` entity.
    pub content: Entity<Markdown>,
    /// Priority as reported by the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported by the protocol.
    pub status: acp::PlanEntryStatus,
}
720
721impl PlanEntry {
722 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
723 Self {
724 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
725 priority: entry.priority,
726 status: entry.status,
727 }
728 }
729}
730
/// Token usage of a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated by `ratio` as "unknown".
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
736
737impl TokenUsage {
738 pub fn ratio(&self) -> TokenUsageRatio {
739 #[cfg(debug_assertions)]
740 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
741 .unwrap_or("0.8".to_string())
742 .parse()
743 .unwrap();
744 #[cfg(not(debug_assertions))]
745 let warning_threshold: f32 = 0.8;
746
747 // When the maximum is unknown because there is no selected model,
748 // avoid showing the token limit warning.
749 if self.max_tokens == 0 {
750 TokenUsageRatio::Normal
751 } else if self.used_tokens >= self.max_tokens {
752 TokenUsageRatio::Exceeded
753 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
754 TokenUsageRatio::Warning
755 } else {
756 TokenUsageRatio::Normal
757 }
758 }
759}
760
/// Classification of token usage produced by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown (`0`).
    Normal,
    /// Usage has reached the warning threshold but not the maximum.
    Warning,
    /// Usage is at or above the maximum.
    Exceeded,
}
767
/// Retry information carried by `AcpThreadEvent::Retry`.
// NOTE(review): the field meanings below are inferred from names and types
// only; the producer of this struct is not visible in this section.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the most recent error — confirm with the producer.
    pub last_error: SharedString,
    /// Presumably the current attempt number — confirm whether it is 0- or 1-based.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    /// An instant associated with the retry (presumably when the wait began).
    pub started_at: Instant,
    /// A duration associated with the retry (presumably the wait length).
    pub duration: Duration,
}
776
/// State of a single agent thread (one session on an agent connection).
pub struct AcpThread {
    /// Title of the thread; changed through `set_title`.
    title: SharedString,
    /// Entries of the thread, in order.
    entries: Vec<AgentThreadEntry>,
    /// The thread's current plan.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log entity associated with this thread.
    action_log: Entity<ActionLog>,
    /// Buffers mapped to a snapshot of each.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While this is `Some`, `status` reports `ThreadStatus::Generating`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Id of the session on `connection`.
    session_id: acp::SessionId,
    /// Latest token usage, if any was reported.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the watch channel.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies new prompt capabilities into `prompt_capabilities`
    /// and emits `PromptCapabilitiesUpdated`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Known terminals by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids not yet in `terminals`; replayed when
    /// the matching `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids not yet in `terminals`.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
794
/// Events emitted by [`AcpThread`].
// NOTE(review): variants without a comment are emitted outside the section of
// the file reviewed here; their exact semantics were not verified.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The stored token usage was replaced.
    TokenUsageUpdated,
    /// The entry at the given index was updated in place.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    /// A retry status was reported.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// New prompt capabilities were received and stored.
    PromptCapabilitiesUpdated,
    Refusal,
    /// The session reported a new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The session reported a new current mode.
    ModeUpdated(acp::SessionModeId),
}
812
// Lets `AcpThread` emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
814
/// Events reported by a terminal provider, handled by
/// `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered with the thread.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered by the thread if the terminal is
    /// not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// The terminal's title changed; stored as the terminal's breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited with the given status.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
837
/// Commands addressed to a terminal, identified by terminal id.
// NOTE(review): no consumer of this enum is visible in the section reviewed;
// the variant descriptions follow the variant and field names only.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    /// Carries input bytes for the terminal.
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    /// Carries a new size in columns and rows.
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    /// Identifies a terminal to close.
    Close {
        terminal_id: acp::TerminalId,
    },
}
853
854impl AcpThread {
855 pub fn on_terminal_provider_event(
856 &mut self,
857 event: TerminalProviderEvent,
858 cx: &mut Context<Self>,
859 ) {
860 match event {
861 TerminalProviderEvent::Created {
862 terminal_id,
863 label,
864 cwd,
865 output_byte_limit,
866 terminal,
867 } => {
868 let entity = self.register_terminal_created(
869 terminal_id.clone(),
870 label,
871 cwd,
872 output_byte_limit,
873 terminal,
874 cx,
875 );
876
877 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
878 for data in chunks.drain(..) {
879 entity.update(cx, |term, cx| {
880 term.inner().update(cx, |inner, cx| {
881 inner.write_output(&data, cx);
882 })
883 });
884 }
885 }
886
887 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
888 entity.update(cx, |_term, cx| {
889 cx.notify();
890 });
891 }
892
893 cx.notify();
894 }
895 TerminalProviderEvent::Output { terminal_id, data } => {
896 if let Some(entity) = self.terminals.get(&terminal_id) {
897 entity.update(cx, |term, cx| {
898 term.inner().update(cx, |inner, cx| {
899 inner.write_output(&data, cx);
900 })
901 });
902 } else {
903 self.pending_terminal_output
904 .entry(terminal_id)
905 .or_default()
906 .push(data);
907 }
908 }
909 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
910 if let Some(entity) = self.terminals.get(&terminal_id) {
911 entity.update(cx, |term, cx| {
912 term.inner().update(cx, |inner, cx| {
913 inner.breadcrumb_text = title;
914 cx.emit(::terminal::Event::BreadcrumbsChanged);
915 })
916 });
917 }
918 }
919 TerminalProviderEvent::Exit {
920 terminal_id,
921 status,
922 } => {
923 if let Some(entity) = self.terminals.get(&terminal_id) {
924 entity.update(cx, |_term, cx| {
925 cx.notify();
926 });
927 } else {
928 self.pending_terminal_exit.insert(terminal_id, status);
929 }
930 }
931 }
932 }
933}
934
/// Coarse status of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is present.
    Idle,
    /// A send task is present.
    Generating,
}
940
/// Errors carried by `AcpThreadEvent::LoadError`; see the `Display` impl for
/// the message each variant produces.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported one.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; carries the failure message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other error; carries the message verbatim.
    Other(SharedString),
}
954
955impl Display for LoadError {
956 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
957 match self {
958 LoadError::Unsupported {
959 command: path,
960 current_version,
961 minimum_version,
962 } => {
963 write!(
964 f,
965 "version {current_version} from {path} is not supported (need at least {minimum_version})"
966 )
967 }
968 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
969 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
970 LoadError::Other(msg) => write!(f, "{msg}"),
971 }
972 }
973}
974
// Marker impl: `LoadError` uses the default `Error` methods on top of its
// `Debug` and `Display` impls.
impl Error for LoadError {}
976
977impl AcpThread {
978 pub fn new(
979 title: impl Into<SharedString>,
980 connection: Rc<dyn AgentConnection>,
981 project: Entity<Project>,
982 action_log: Entity<ActionLog>,
983 session_id: acp::SessionId,
984 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
985 cx: &mut Context<Self>,
986 ) -> Self {
987 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
988 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
989 loop {
990 let caps = prompt_capabilities_rx.recv().await?;
991 this.update(cx, |this, cx| {
992 this.prompt_capabilities = caps;
993 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
994 })?;
995 }
996 });
997
998 Self {
999 action_log,
1000 shared_buffers: Default::default(),
1001 entries: Default::default(),
1002 plan: Default::default(),
1003 title: title.into(),
1004 project,
1005 send_task: None,
1006 connection,
1007 session_id,
1008 token_usage: None,
1009 prompt_capabilities,
1010 _observe_prompt_capabilities: task,
1011 terminals: HashMap::default(),
1012 pending_terminal_output: HashMap::default(),
1013 pending_terminal_exit: HashMap::default(),
1014 }
1015 }
1016
/// Returns a clone of the prompt capabilities currently stored on the thread.
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
    self.prompt_capabilities.clone()
}
1020
/// Returns the agent connection this thread uses.
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
    &self.connection
}
1024
/// Returns the action log entity of this thread.
pub fn action_log(&self) -> &Entity<ActionLog> {
    &self.action_log
}
1028
/// Returns the project entity of this thread.
pub fn project(&self) -> &Entity<Project> {
    &self.project
}
1032
/// Returns a clone of the thread title.
pub fn title(&self) -> SharedString {
    self.title.clone()
}
1036
/// Returns the thread's entries in order.
pub fn entries(&self) -> &[AgentThreadEntry] {
    &self.entries
}
1040
/// Returns the id of the session this thread belongs to.
pub fn session_id(&self) -> &acp::SessionId {
    &self.session_id
}
1044
1045 pub fn status(&self) -> ThreadStatus {
1046 if self.send_task.is_some() {
1047 ThreadStatus::Generating
1048 } else {
1049 ThreadStatus::Idle
1050 }
1051 }
1052
/// Returns the latest token usage, if any has been recorded.
pub fn token_usage(&self) -> Option<&TokenUsage> {
    self.token_usage.as_ref()
}
1056
1057 pub fn has_pending_edit_tool_calls(&self) -> bool {
1058 for entry in self.entries.iter().rev() {
1059 match entry {
1060 AgentThreadEntry::UserMessage(_) => return false,
1061 AgentThreadEntry::ToolCall(
1062 call @ ToolCall {
1063 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1064 ..
1065 },
1066 ) if call.diffs().next().is_some() => {
1067 return true;
1068 }
1069 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1070 }
1071 }
1072
1073 false
1074 }
1075
1076 pub fn used_tools_since_last_user_message(&self) -> bool {
1077 for entry in self.entries.iter().rev() {
1078 match entry {
1079 AgentThreadEntry::UserMessage(..) => return false,
1080 AgentThreadEntry::AssistantMessage(..) => continue,
1081 AgentThreadEntry::ToolCall(..) => return true,
1082 }
1083 }
1084
1085 false
1086 }
1087
1088 pub fn handle_session_update(
1089 &mut self,
1090 update: acp::SessionUpdate,
1091 cx: &mut Context<Self>,
1092 ) -> Result<(), acp::Error> {
1093 match update {
1094 acp::SessionUpdate::UserMessageChunk { content } => {
1095 self.push_user_content_block(None, content, cx);
1096 }
1097 acp::SessionUpdate::AgentMessageChunk { content } => {
1098 self.push_assistant_content_block(content, false, cx);
1099 }
1100 acp::SessionUpdate::AgentThoughtChunk { content } => {
1101 self.push_assistant_content_block(content, true, cx);
1102 }
1103 acp::SessionUpdate::ToolCall(tool_call) => {
1104 self.upsert_tool_call(tool_call, cx)?;
1105 }
1106 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1107 self.update_tool_call(tool_call_update, cx)?;
1108 }
1109 acp::SessionUpdate::Plan(plan) => {
1110 self.update_plan(plan, cx);
1111 }
1112 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
1113 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
1114 }
1115 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
1116 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
1117 }
1118 }
1119 Ok(())
1120 }
1121
1122 pub fn push_user_content_block(
1123 &mut self,
1124 message_id: Option<UserMessageId>,
1125 chunk: acp::ContentBlock,
1126 cx: &mut Context<Self>,
1127 ) {
1128 let language_registry = self.project.read(cx).languages().clone();
1129 let entries_len = self.entries.len();
1130
1131 if let Some(last_entry) = self.entries.last_mut()
1132 && let AgentThreadEntry::UserMessage(UserMessage {
1133 id,
1134 content,
1135 chunks,
1136 ..
1137 }) = last_entry
1138 {
1139 *id = message_id.or(id.take());
1140 content.append(chunk.clone(), &language_registry, cx);
1141 chunks.push(chunk);
1142 let idx = entries_len - 1;
1143 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1144 } else {
1145 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1146 self.push_entry(
1147 AgentThreadEntry::UserMessage(UserMessage {
1148 id: message_id,
1149 content,
1150 chunks: vec![chunk],
1151 checkpoint: None,
1152 }),
1153 cx,
1154 );
1155 }
1156 }
1157
1158 pub fn push_assistant_content_block(
1159 &mut self,
1160 chunk: acp::ContentBlock,
1161 is_thought: bool,
1162 cx: &mut Context<Self>,
1163 ) {
1164 let language_registry = self.project.read(cx).languages().clone();
1165 let entries_len = self.entries.len();
1166 if let Some(last_entry) = self.entries.last_mut()
1167 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1168 {
1169 let idx = entries_len - 1;
1170 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1171 match (chunks.last_mut(), is_thought) {
1172 (Some(AssistantMessageChunk::Message { block }), false)
1173 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1174 block.append(chunk, &language_registry, cx)
1175 }
1176 _ => {
1177 let block = ContentBlock::new(chunk, &language_registry, cx);
1178 if is_thought {
1179 chunks.push(AssistantMessageChunk::Thought { block })
1180 } else {
1181 chunks.push(AssistantMessageChunk::Message { block })
1182 }
1183 }
1184 }
1185 } else {
1186 let block = ContentBlock::new(chunk, &language_registry, cx);
1187 let chunk = if is_thought {
1188 AssistantMessageChunk::Thought { block }
1189 } else {
1190 AssistantMessageChunk::Message { block }
1191 };
1192
1193 self.push_entry(
1194 AgentThreadEntry::AssistantMessage(AssistantMessage {
1195 chunks: vec![chunk],
1196 }),
1197 cx,
1198 );
1199 }
1200 }
1201
/// Appends `entry` to the thread and emits `AcpThreadEvent::NewEntry`.
fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
    self.entries.push(entry);
    cx.emit(AcpThreadEvent::NewEntry);
}
1206
/// Returns `true` when the connection provides a title setter for this
/// thread's session.
pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
    self.connection.set_title(&self.session_id, cx).is_some()
}
1210
1211 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1212 if title != self.title {
1213 self.title = title.clone();
1214 cx.emit(AcpThreadEvent::TitleUpdated);
1215 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1216 return set_title.run(title, cx);
1217 }
1218 }
1219 Task::ready(Ok(()))
1220 }
1221
1222 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1223 self.token_usage = usage;
1224 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1225 }
1226
1227 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1228 cx.emit(AcpThreadEvent::Retry(status));
1229 }
1230
1231 pub fn update_tool_call(
1232 &mut self,
1233 update: impl Into<ToolCallUpdate>,
1234 cx: &mut Context<Self>,
1235 ) -> Result<()> {
1236 let update = update.into();
1237 let languages = self.project.read(cx).languages().clone();
1238
1239 let ix = match self.index_for_tool_call(update.id()) {
1240 Some(ix) => ix,
1241 None => {
1242 // Tool call not found - create a failed tool call entry
1243 let failed_tool_call = ToolCall {
1244 id: update.id().clone(),
1245 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1246 kind: acp::ToolKind::Fetch,
1247 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1248 acp::ContentBlock::Text(acp::TextContent {
1249 text: "Tool call not found".to_string(),
1250 annotations: None,
1251 meta: None,
1252 }),
1253 &languages,
1254 cx,
1255 ))],
1256 status: ToolCallStatus::Failed,
1257 locations: Vec::new(),
1258 resolved_locations: Vec::new(),
1259 raw_input: None,
1260 raw_output: None,
1261 };
1262 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1263 return Ok(());
1264 }
1265 };
1266 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1267 unreachable!()
1268 };
1269
1270 match update {
1271 ToolCallUpdate::UpdateFields(update) => {
1272 let location_updated = update.fields.locations.is_some();
1273 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1274 if location_updated {
1275 self.resolve_locations(update.id, cx);
1276 }
1277 }
1278 ToolCallUpdate::UpdateDiff(update) => {
1279 call.content.clear();
1280 call.content.push(ToolCallContent::Diff(update.diff));
1281 }
1282 ToolCallUpdate::UpdateTerminal(update) => {
1283 call.content.clear();
1284 call.content
1285 .push(ToolCallContent::Terminal(update.terminal));
1286 }
1287 }
1288
1289 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1290
1291 Ok(())
1292 }
1293
1294 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1295 pub fn upsert_tool_call(
1296 &mut self,
1297 tool_call: acp::ToolCall,
1298 cx: &mut Context<Self>,
1299 ) -> Result<(), acp::Error> {
1300 let status = tool_call.status.into();
1301 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1302 }
1303
1304 /// Fails if id does not match an existing entry.
1305 pub fn upsert_tool_call_inner(
1306 &mut self,
1307 update: acp::ToolCallUpdate,
1308 status: ToolCallStatus,
1309 cx: &mut Context<Self>,
1310 ) -> Result<(), acp::Error> {
1311 let language_registry = self.project.read(cx).languages().clone();
1312 let id = update.id.clone();
1313
1314 if let Some(ix) = self.index_for_tool_call(&id) {
1315 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1316 unreachable!()
1317 };
1318
1319 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1320 call.status = status;
1321
1322 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1323 } else {
1324 let call = ToolCall::from_acp(
1325 update.try_into()?,
1326 status,
1327 language_registry,
1328 &self.terminals,
1329 cx,
1330 )?;
1331 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1332 };
1333
1334 self.resolve_locations(id, cx);
1335 Ok(())
1336 }
1337
1338 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1339 self.entries
1340 .iter()
1341 .enumerate()
1342 .rev()
1343 .find_map(|(index, entry)| {
1344 if let AgentThreadEntry::ToolCall(tool_call) = entry
1345 && &tool_call.id == id
1346 {
1347 Some(index)
1348 } else {
1349 None
1350 }
1351 })
1352 }
1353
1354 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1355 // The tool call we are looking for is typically the last one, or very close to the end.
1356 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1357 self.entries
1358 .iter_mut()
1359 .enumerate()
1360 .rev()
1361 .find_map(|(index, tool_call)| {
1362 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1363 && &tool_call.id == id
1364 {
1365 Some((index, tool_call))
1366 } else {
1367 None
1368 }
1369 })
1370 }
1371
1372 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1373 self.entries
1374 .iter()
1375 .enumerate()
1376 .rev()
1377 .find_map(|(index, tool_call)| {
1378 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1379 && &tool_call.id == id
1380 {
1381 Some((index, tool_call))
1382 } else {
1383 None
1384 }
1385 })
1386 }
1387
1388 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1389 let project = self.project.clone();
1390 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1391 return;
1392 };
1393 let task = tool_call.resolve_locations(project, cx);
1394 cx.spawn(async move |this, cx| {
1395 let resolved_locations = task.await;
1396 this.update(cx, |this, cx| {
1397 let project = this.project.clone();
1398 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1399 return;
1400 };
1401 if let Some(Some(location)) = resolved_locations.last() {
1402 project.update(cx, |project, cx| {
1403 if let Some(agent_location) = project.agent_location() {
1404 let should_ignore = agent_location.buffer == location.buffer
1405 && location
1406 .buffer
1407 .update(cx, |buffer, _| {
1408 let snapshot = buffer.snapshot();
1409 let old_position =
1410 agent_location.position.to_point(&snapshot);
1411 let new_position = location.position.to_point(&snapshot);
1412 // ignore this so that when we get updates from the edit tool
1413 // the position doesn't reset to the startof line
1414 old_position.row == new_position.row
1415 && old_position.column > new_position.column
1416 })
1417 .ok()
1418 .unwrap_or_default();
1419 if !should_ignore {
1420 project.set_agent_location(Some(location.clone()), cx);
1421 }
1422 }
1423 });
1424 }
1425 if tool_call.resolved_locations != resolved_locations {
1426 tool_call.resolved_locations = resolved_locations;
1427 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1428 }
1429 })
1430 })
1431 .detach();
1432 }
1433
1434 pub fn request_tool_call_authorization(
1435 &mut self,
1436 tool_call: acp::ToolCallUpdate,
1437 options: Vec<acp::PermissionOption>,
1438 respect_always_allow_setting: bool,
1439 cx: &mut Context<Self>,
1440 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1441 let (tx, rx) = oneshot::channel();
1442
1443 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1444 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1445 // some tools would (incorrectly) continue to auto-accept.
1446 if let Some(allow_once_option) = options.iter().find_map(|option| {
1447 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1448 Some(option.id.clone())
1449 } else {
1450 None
1451 }
1452 }) {
1453 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1454 return Ok(async {
1455 acp::RequestPermissionOutcome::Selected {
1456 option_id: allow_once_option,
1457 }
1458 }
1459 .boxed());
1460 }
1461 }
1462
1463 let status = ToolCallStatus::WaitingForConfirmation {
1464 options,
1465 respond_tx: tx,
1466 };
1467
1468 self.upsert_tool_call_inner(tool_call, status, cx)?;
1469 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1470
1471 let fut = async {
1472 match rx.await {
1473 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1474 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1475 }
1476 }
1477 .boxed();
1478
1479 Ok(fut)
1480 }
1481
1482 pub fn authorize_tool_call(
1483 &mut self,
1484 id: acp::ToolCallId,
1485 option_id: acp::PermissionOptionId,
1486 option_kind: acp::PermissionOptionKind,
1487 cx: &mut Context<Self>,
1488 ) {
1489 let Some((ix, call)) = self.tool_call_mut(&id) else {
1490 return;
1491 };
1492
1493 let new_status = match option_kind {
1494 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1495 ToolCallStatus::Rejected
1496 }
1497 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1498 ToolCallStatus::InProgress
1499 }
1500 };
1501
1502 let curr_status = mem::replace(&mut call.status, new_status);
1503
1504 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1505 respond_tx.send(option_id).log_err();
1506 } else if cfg!(debug_assertions) {
1507 panic!("tried to authorize an already authorized tool call");
1508 }
1509
1510 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1511 }
1512
1513 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1514 let mut first_tool_call = None;
1515
1516 for entry in self.entries.iter().rev() {
1517 match &entry {
1518 AgentThreadEntry::ToolCall(call) => {
1519 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1520 first_tool_call = Some(call);
1521 } else {
1522 continue;
1523 }
1524 }
1525 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1526 // Reached the beginning of the turn.
1527 // If we had pending permission requests in the previous turn, they have been cancelled.
1528 break;
1529 }
1530 }
1531 }
1532
1533 first_tool_call
1534 }
1535
1536 pub fn plan(&self) -> &Plan {
1537 &self.plan
1538 }
1539
1540 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1541 let new_entries_len = request.entries.len();
1542 let mut new_entries = request.entries.into_iter();
1543
1544 // Reuse existing markdown to prevent flickering
1545 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1546 let PlanEntry {
1547 content,
1548 priority,
1549 status,
1550 } = old;
1551 content.update(cx, |old, cx| {
1552 old.replace(new.content, cx);
1553 });
1554 *priority = new.priority;
1555 *status = new.status;
1556 }
1557 for new in new_entries {
1558 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1559 }
1560 self.plan.entries.truncate(new_entries_len);
1561
1562 cx.notify();
1563 }
1564
1565 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1566 self.plan
1567 .entries
1568 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1569 cx.notify();
1570 }
1571
/// Test helper: sends `message` as a single plain-text content block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text = acp::TextContent {
        text: message.to_string(),
        annotations: None,
        meta: None,
    };
    let prompt = vec![acp::ContentBlock::Text(text)];
    self.send(prompt, cx)
}
1587
1588 pub fn send(
1589 &mut self,
1590 message: Vec<acp::ContentBlock>,
1591 cx: &mut Context<Self>,
1592 ) -> BoxFuture<'static, Result<()>> {
1593 let block = ContentBlock::new_combined(
1594 message.clone(),
1595 self.project.read(cx).languages().clone(),
1596 cx,
1597 );
1598 let request = acp::PromptRequest {
1599 prompt: message.clone(),
1600 session_id: self.session_id.clone(),
1601 meta: None,
1602 };
1603 let git_store = self.project.read(cx).git_store().clone();
1604
1605 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1606 Some(UserMessageId::new())
1607 } else {
1608 None
1609 };
1610
1611 self.run_turn(cx, async move |this, cx| {
1612 this.update(cx, |this, cx| {
1613 this.push_entry(
1614 AgentThreadEntry::UserMessage(UserMessage {
1615 id: message_id.clone(),
1616 content: block,
1617 chunks: message,
1618 checkpoint: None,
1619 }),
1620 cx,
1621 );
1622 })
1623 .ok();
1624
1625 let old_checkpoint = git_store
1626 .update(cx, |git, cx| git.checkpoint(cx))?
1627 .await
1628 .context("failed to get old checkpoint")
1629 .log_err();
1630 this.update(cx, |this, cx| {
1631 if let Some((_ix, message)) = this.last_user_message() {
1632 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1633 git_checkpoint,
1634 show: false,
1635 });
1636 }
1637 this.connection.prompt(message_id, request, cx)
1638 })?
1639 .await
1640 })
1641 }
1642
1643 pub fn can_resume(&self, cx: &App) -> bool {
1644 self.connection.resume(&self.session_id, cx).is_some()
1645 }
1646
1647 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1648 self.run_turn(cx, async move |this, cx| {
1649 this.update(cx, |this, cx| {
1650 this.connection
1651 .resume(&this.session_id, cx)
1652 .map(|resume| resume.run(cx))
1653 })?
1654 .context("resuming a session is not supported")?
1655 .await
1656 })
1657 }
1658
1659 fn run_turn(
1660 &mut self,
1661 cx: &mut Context<Self>,
1662 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1663 ) -> BoxFuture<'static, Result<()>> {
1664 self.clear_completed_plan_entries(cx);
1665
1666 let (tx, rx) = oneshot::channel();
1667 let cancel_task = self.cancel(cx);
1668
1669 self.send_task = Some(cx.spawn(async move |this, cx| {
1670 cancel_task.await;
1671 tx.send(f(this, cx).await).ok();
1672 }));
1673
1674 cx.spawn(async move |this, cx| {
1675 let response = rx.await;
1676
1677 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1678 .await?;
1679
1680 this.update(cx, |this, cx| {
1681 this.project
1682 .update(cx, |project, cx| project.set_agent_location(None, cx));
1683 match response {
1684 Ok(Err(e)) => {
1685 this.send_task.take();
1686 cx.emit(AcpThreadEvent::Error);
1687 Err(e)
1688 }
1689 result => {
1690 let canceled = matches!(
1691 result,
1692 Ok(Ok(acp::PromptResponse {
1693 stop_reason: acp::StopReason::Cancelled,
1694 meta: None,
1695 }))
1696 );
1697
1698 // We only take the task if the current prompt wasn't canceled.
1699 //
1700 // This prompt may have been canceled because another one was sent
1701 // while it was still generating. In these cases, dropping `send_task`
1702 // would cause the next generation to be canceled.
1703 if !canceled {
1704 this.send_task.take();
1705 }
1706
1707 // Handle refusal - distinguish between user prompt and tool call refusals
1708 if let Ok(Ok(acp::PromptResponse {
1709 stop_reason: acp::StopReason::Refusal,
1710 meta: _,
1711 })) = result
1712 {
1713 if let Some((user_msg_ix, _)) = this.last_user_message() {
1714 // Check if there's a completed tool call with results after the last user message
1715 // This indicates the refusal is in response to tool output, not the user's prompt
1716 let has_completed_tool_call_after_user_msg =
1717 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1718 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1719 // Check if the tool call has completed and has output
1720 matches!(tool_call.status, ToolCallStatus::Completed)
1721 && tool_call.raw_output.is_some()
1722 } else {
1723 false
1724 }
1725 });
1726
1727 if has_completed_tool_call_after_user_msg {
1728 // Refusal is due to tool output - don't truncate, just notify
1729 // The model refused based on what the tool returned
1730 cx.emit(AcpThreadEvent::Refusal);
1731 } else {
1732 // User prompt was refused - truncate back to before the user message
1733 let range = user_msg_ix..this.entries.len();
1734 if range.start < range.end {
1735 this.entries.truncate(user_msg_ix);
1736 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1737 }
1738 cx.emit(AcpThreadEvent::Refusal);
1739 }
1740 } else {
1741 // No user message found, treat as general refusal
1742 cx.emit(AcpThreadEvent::Refusal);
1743 }
1744 }
1745
1746 cx.emit(AcpThreadEvent::Stopped);
1747 Ok(())
1748 }
1749 }
1750 })?
1751 })
1752 .boxed()
1753 }
1754
1755 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1756 let Some(send_task) = self.send_task.take() else {
1757 return Task::ready(());
1758 };
1759
1760 for entry in self.entries.iter_mut() {
1761 if let AgentThreadEntry::ToolCall(call) = entry {
1762 let cancel = matches!(
1763 call.status,
1764 ToolCallStatus::Pending
1765 | ToolCallStatus::WaitingForConfirmation { .. }
1766 | ToolCallStatus::InProgress
1767 );
1768
1769 if cancel {
1770 call.status = ToolCallStatus::Canceled;
1771 }
1772 }
1773 }
1774
1775 self.connection.cancel(&self.session_id, cx);
1776
1777 // Wait for the send task to complete
1778 cx.foreground_executor().spawn(send_task)
1779 }
1780
1781 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1782 pub fn restore_checkpoint(
1783 &mut self,
1784 id: UserMessageId,
1785 cx: &mut Context<Self>,
1786 ) -> Task<Result<()>> {
1787 let Some((_, message)) = self.user_message_mut(&id) else {
1788 return Task::ready(Err(anyhow!("message not found")));
1789 };
1790
1791 let checkpoint = message
1792 .checkpoint
1793 .as_ref()
1794 .map(|c| c.git_checkpoint.clone());
1795 let rewind = self.rewind(id.clone(), cx);
1796 let git_store = self.project.read(cx).git_store().clone();
1797
1798 cx.spawn(async move |_, cx| {
1799 rewind.await?;
1800 if let Some(checkpoint) = checkpoint {
1801 git_store
1802 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1803 .await?;
1804 }
1805
1806 Ok(())
1807 })
1808 }
1809
1810 /// Rewinds this thread to before the entry at `index`, removing it and all
1811 /// subsequent entries while rejecting any action_log changes made from that point.
1812 /// Unlike `restore_checkpoint`, this method does not restore from git.
1813 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1814 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1815 return Task::ready(Err(anyhow!("not supported")));
1816 };
1817
1818 cx.spawn(async move |this, cx| {
1819 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1820 this.update(cx, |this, cx| {
1821 if let Some((ix, _)) = this.user_message_mut(&id) {
1822 let range = ix..this.entries.len();
1823 this.entries.truncate(ix);
1824 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1825 }
1826 this.action_log()
1827 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1828 })?
1829 .await;
1830 Ok(())
1831 })
1832 }
1833
1834 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1835 let git_store = self.project.read(cx).git_store().clone();
1836
1837 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1838 if let Some(checkpoint) = message.checkpoint.as_ref() {
1839 checkpoint.git_checkpoint.clone()
1840 } else {
1841 return Task::ready(Ok(()));
1842 }
1843 } else {
1844 return Task::ready(Ok(()));
1845 };
1846
1847 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1848 cx.spawn(async move |this, cx| {
1849 let new_checkpoint = new_checkpoint
1850 .await
1851 .context("failed to get new checkpoint")
1852 .log_err();
1853 if let Some(new_checkpoint) = new_checkpoint {
1854 let equal = git_store
1855 .update(cx, |git, cx| {
1856 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1857 })?
1858 .await
1859 .unwrap_or(true);
1860 this.update(cx, |this, cx| {
1861 let (ix, message) = this.last_user_message().context("no user message")?;
1862 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1863 checkpoint.show = !equal;
1864 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1865 anyhow::Ok(())
1866 })??;
1867 }
1868
1869 Ok(())
1870 })
1871 }
1872
1873 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1874 self.entries
1875 .iter_mut()
1876 .enumerate()
1877 .rev()
1878 .find_map(|(ix, entry)| {
1879 if let AgentThreadEntry::UserMessage(message) = entry {
1880 Some((ix, message))
1881 } else {
1882 None
1883 }
1884 })
1885 }
1886
1887 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1888 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1889 if let AgentThreadEntry::UserMessage(message) = entry {
1890 if message.id.as_ref() == Some(id) {
1891 Some((ix, message))
1892 } else {
1893 None
1894 }
1895 } else {
1896 None
1897 }
1898 })
1899 }
1900
1901 pub fn read_text_file(
1902 &self,
1903 path: PathBuf,
1904 line: Option<u32>,
1905 limit: Option<u32>,
1906 reuse_shared_snapshot: bool,
1907 cx: &mut Context<Self>,
1908 ) -> Task<Result<String, acp::Error>> {
1909 // Args are 1-based, move to 0-based
1910 let line = line.unwrap_or_default().saturating_sub(1);
1911 let limit = limit.unwrap_or(u32::MAX);
1912 let project = self.project.clone();
1913 let action_log = self.action_log.clone();
1914 cx.spawn(async move |this, cx| {
1915 let load = project
1916 .update(cx, |project, cx| {
1917 let path = project
1918 .project_path_for_absolute_path(&path, cx)
1919 .ok_or_else(|| {
1920 acp::Error::resource_not_found(Some(path.display().to_string()))
1921 })?;
1922 Ok(project.open_buffer(path, cx))
1923 })
1924 .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
1925 .flatten()?;
1926
1927 let buffer = load.await?;
1928
1929 let snapshot = if reuse_shared_snapshot {
1930 this.read_with(cx, |this, _| {
1931 this.shared_buffers.get(&buffer.clone()).cloned()
1932 })
1933 .log_err()
1934 .flatten()
1935 } else {
1936 None
1937 };
1938
1939 let snapshot = if let Some(snapshot) = snapshot {
1940 snapshot
1941 } else {
1942 action_log.update(cx, |action_log, cx| {
1943 action_log.buffer_read(buffer.clone(), cx);
1944 })?;
1945
1946 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1947 this.update(cx, |this, _| {
1948 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1949 })?;
1950 snapshot
1951 };
1952
1953 let max_point = snapshot.max_point();
1954 let start_position = Point::new(line, 0);
1955
1956 if start_position > max_point {
1957 return Err(acp::Error::invalid_params().with_data(format!(
1958 "Attempting to read beyond the end of the file, line {}:{}",
1959 max_point.row + 1,
1960 max_point.column
1961 )));
1962 }
1963
1964 let start = snapshot.anchor_before(start_position);
1965 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1966
1967 project.update(cx, |project, cx| {
1968 project.set_agent_location(
1969 Some(AgentLocation {
1970 buffer: buffer.downgrade(),
1971 position: start,
1972 }),
1973 cx,
1974 );
1975 })?;
1976
1977 Ok(snapshot.text_for_range(start..end).collect::<String>())
1978 })
1979 }
1980
1981 pub fn write_text_file(
1982 &self,
1983 path: PathBuf,
1984 content: String,
1985 cx: &mut Context<Self>,
1986 ) -> Task<Result<()>> {
1987 let project = self.project.clone();
1988 let action_log = self.action_log.clone();
1989 cx.spawn(async move |this, cx| {
1990 let load = project.update(cx, |project, cx| {
1991 let path = project
1992 .project_path_for_absolute_path(&path, cx)
1993 .context("invalid path")?;
1994 anyhow::Ok(project.open_buffer(path, cx))
1995 });
1996 let buffer = load??.await?;
1997 let snapshot = this.update(cx, |this, cx| {
1998 this.shared_buffers
1999 .get(&buffer)
2000 .cloned()
2001 .unwrap_or_else(|| buffer.read(cx).snapshot())
2002 })?;
2003 let edits = cx
2004 .background_executor()
2005 .spawn(async move {
2006 let old_text = snapshot.text();
2007 text_diff(old_text.as_str(), &content)
2008 .into_iter()
2009 .map(|(range, replacement)| {
2010 (
2011 snapshot.anchor_after(range.start)
2012 ..snapshot.anchor_before(range.end),
2013 replacement,
2014 )
2015 })
2016 .collect::<Vec<_>>()
2017 })
2018 .await;
2019
2020 project.update(cx, |project, cx| {
2021 project.set_agent_location(
2022 Some(AgentLocation {
2023 buffer: buffer.downgrade(),
2024 position: edits
2025 .last()
2026 .map(|(range, _)| range.end)
2027 .unwrap_or(Anchor::MIN),
2028 }),
2029 cx,
2030 );
2031 })?;
2032
2033 let format_on_save = cx.update(|cx| {
2034 action_log.update(cx, |action_log, cx| {
2035 action_log.buffer_read(buffer.clone(), cx);
2036 });
2037
2038 let format_on_save = buffer.update(cx, |buffer, cx| {
2039 buffer.edit(edits, None, cx);
2040
2041 let settings = language::language_settings::language_settings(
2042 buffer.language().map(|l| l.name()),
2043 buffer.file(),
2044 cx,
2045 );
2046
2047 settings.format_on_save != FormatOnSave::Off
2048 });
2049 action_log.update(cx, |action_log, cx| {
2050 action_log.buffer_edited(buffer.clone(), cx);
2051 });
2052 format_on_save
2053 })?;
2054
2055 if format_on_save {
2056 let format_task = project.update(cx, |project, cx| {
2057 project.format(
2058 HashSet::from_iter([buffer.clone()]),
2059 LspFormatTarget::Buffers,
2060 false,
2061 FormatTrigger::Save,
2062 cx,
2063 )
2064 })?;
2065 format_task.await.log_err();
2066
2067 action_log.update(cx, |action_log, cx| {
2068 action_log.buffer_edited(buffer.clone(), cx);
2069 })?;
2070 }
2071
2072 project
2073 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
2074 .await
2075 })
2076 }
2077
2078 pub fn create_terminal(
2079 &self,
2080 command: String,
2081 args: Vec<String>,
2082 extra_env: Vec<acp::EnvVariable>,
2083 cwd: Option<PathBuf>,
2084 output_byte_limit: Option<u64>,
2085 cx: &mut Context<Self>,
2086 ) -> Task<Result<Entity<Terminal>>> {
2087 let env = match &cwd {
2088 Some(dir) => self.project.update(cx, |project, cx| {
2089 let shell = TerminalSettings::get_global(cx).shell.clone();
2090 project.directory_environment(&shell, dir.as_path().into(), cx)
2091 }),
2092 None => Task::ready(None).shared(),
2093 };
2094 let env = cx.spawn(async move |_, _| {
2095 let mut env = env.await.unwrap_or_default();
2096 // Disables paging for `git` and hopefully other commands
2097 env.insert("PAGER".into(), "".into());
2098 for var in extra_env {
2099 env.insert(var.name, var.value);
2100 }
2101 env
2102 });
2103
2104 let project = self.project.clone();
2105 let language_registry = project.read(cx).languages().clone();
2106
2107 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
2108 let terminal_task = cx.spawn({
2109 let terminal_id = terminal_id.clone();
2110 async move |_this, cx| {
2111 let env = env.await;
2112 let (task_command, task_args) = ShellBuilder::new(
2113 project
2114 .update(cx, |project, cx| {
2115 project
2116 .remote_client()
2117 .and_then(|r| r.read(cx).default_system_shell())
2118 })?
2119 .as_deref(),
2120 &Shell::Program(get_default_system_shell()),
2121 )
2122 .redirect_stdin_to_dev_null()
2123 .build(Some(command.clone()), &args);
2124 let terminal = project
2125 .update(cx, |project, cx| {
2126 project.create_terminal_task(
2127 task::SpawnInTerminal {
2128 command: Some(task_command),
2129 args: task_args,
2130 cwd: cwd.clone(),
2131 env,
2132 ..Default::default()
2133 },
2134 cx,
2135 )
2136 })?
2137 .await?;
2138
2139 cx.new(|cx| {
2140 Terminal::new(
2141 terminal_id,
2142 &format!("{} {}", command, args.join(" ")),
2143 cwd,
2144 output_byte_limit.map(|l| l as usize),
2145 terminal,
2146 language_registry,
2147 cx,
2148 )
2149 })
2150 }
2151 });
2152
2153 cx.spawn(async move |this, cx| {
2154 let terminal = terminal_task.await?;
2155 this.update(cx, |this, _cx| {
2156 this.terminals.insert(terminal_id, terminal.clone());
2157 terminal
2158 })
2159 })
2160 }
2161
2162 pub fn kill_terminal(
2163 &mut self,
2164 terminal_id: acp::TerminalId,
2165 cx: &mut Context<Self>,
2166 ) -> Result<()> {
2167 self.terminals
2168 .get(&terminal_id)
2169 .context("Terminal not found")?
2170 .update(cx, |terminal, cx| {
2171 terminal.kill(cx);
2172 });
2173
2174 Ok(())
2175 }
2176
2177 pub fn release_terminal(
2178 &mut self,
2179 terminal_id: acp::TerminalId,
2180 cx: &mut Context<Self>,
2181 ) -> Result<()> {
2182 self.terminals
2183 .remove(&terminal_id)
2184 .context("Terminal not found")?
2185 .update(cx, |terminal, cx| {
2186 terminal.kill(cx);
2187 });
2188
2189 Ok(())
2190 }
2191
2192 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2193 self.terminals
2194 .get(&terminal_id)
2195 .context("Terminal not found")
2196 .cloned()
2197 }
2198
2199 pub fn to_markdown(&self, cx: &App) -> String {
2200 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2201 }
2202
2203 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2204 cx.emit(AcpThreadEvent::LoadError(error));
2205 }
2206
2207 pub fn register_terminal_created(
2208 &mut self,
2209 terminal_id: acp::TerminalId,
2210 command_label: String,
2211 working_dir: Option<PathBuf>,
2212 output_byte_limit: Option<u64>,
2213 terminal: Entity<::terminal::Terminal>,
2214 cx: &mut Context<Self>,
2215 ) -> Entity<Terminal> {
2216 let language_registry = self.project.read(cx).languages().clone();
2217
2218 let entity = cx.new(|cx| {
2219 Terminal::new(
2220 terminal_id.clone(),
2221 &command_label,
2222 working_dir.clone(),
2223 output_byte_limit.map(|l| l as usize),
2224 terminal,
2225 language_registry,
2226 cx,
2227 )
2228 });
2229 self.terminals.insert(terminal_id.clone(), entity.clone());
2230 entity
2231 }
2232}
2233
2234fn markdown_for_raw_output(
2235 raw_output: &serde_json::Value,
2236 language_registry: &Arc<LanguageRegistry>,
2237 cx: &mut App,
2238) -> Option<Entity<Markdown>> {
2239 match raw_output {
2240 serde_json::Value::Null => None,
2241 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2242 Markdown::new(
2243 value.to_string().into(),
2244 Some(language_registry.clone()),
2245 None,
2246 cx,
2247 )
2248 })),
2249 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2250 Markdown::new(
2251 value.to_string().into(),
2252 Some(language_registry.clone()),
2253 None,
2254 cx,
2255 )
2256 })),
2257 serde_json::Value::String(value) => Some(cx.new(|cx| {
2258 Markdown::new(
2259 value.clone().into(),
2260 Some(language_registry.clone()),
2261 None,
2262 cx,
2263 )
2264 })),
2265 value => Some(cx.new(|cx| {
2266 Markdown::new(
2267 format!("```json\n{}\n```", value).into(),
2268 Some(language_registry.clone()),
2269 None,
2270 cx,
2271 )
2272 })),
2273 }
2274}
2275
2276#[cfg(test)]
2277mod tests {
2278 use super::*;
2279 use anyhow::anyhow;
2280 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2281 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2282 use indoc::indoc;
2283 use project::{FakeFs, Fs};
2284 use rand::{distr, prelude::*};
2285 use serde_json::json;
2286 use settings::SettingsStore;
2287 use smol::stream::StreamExt as _;
2288 use std::{
2289 any::Any,
2290 cell::RefCell,
2291 path::Path,
2292 rc::Rc,
2293 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2294 time::Duration,
2295 };
2296 use util::path;
2297
2298 fn init_test(cx: &mut TestAppContext) {
2299 env_logger::try_init().ok();
2300 cx.update(|cx| {
2301 let settings_store = SettingsStore::test(cx);
2302 cx.set_global(settings_store);
2303 Project::init_settings(cx);
2304 language::init(cx);
2305 });
2306 }
2307
2308 #[gpui::test]
2309 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2310 init_test(cx);
2311
2312 let fs = FakeFs::new(cx.executor());
2313 let project = Project::test(fs, [], cx).await;
2314 let connection = Rc::new(FakeAgentConnection::new());
2315 let thread = cx
2316 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2317 .await
2318 .unwrap();
2319
2320 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2321
2322 // Send Output BEFORE Created - should be buffered by acp_thread
2323 thread.update(cx, |thread, cx| {
2324 thread.on_terminal_provider_event(
2325 TerminalProviderEvent::Output {
2326 terminal_id: terminal_id.clone(),
2327 data: b"hello buffered".to_vec(),
2328 },
2329 cx,
2330 );
2331 });
2332
2333 // Create a display-only terminal and then send Created
2334 let lower = cx.new(|cx| {
2335 let builder = ::terminal::TerminalBuilder::new_display_only(
2336 None,
2337 ::terminal::terminal_settings::CursorShape::default(),
2338 ::terminal::terminal_settings::AlternateScroll::On,
2339 None,
2340 0,
2341 cx,
2342 )
2343 .unwrap();
2344 builder.subscribe(cx)
2345 });
2346
2347 thread.update(cx, |thread, cx| {
2348 thread.on_terminal_provider_event(
2349 TerminalProviderEvent::Created {
2350 terminal_id: terminal_id.clone(),
2351 label: "Buffered Test".to_string(),
2352 cwd: None,
2353 output_byte_limit: None,
2354 terminal: lower.clone(),
2355 },
2356 cx,
2357 );
2358 });
2359
2360 // After Created, buffered Output should have been flushed into the renderer
2361 let content = thread.read_with(cx, |thread, cx| {
2362 let term = thread.terminal(terminal_id.clone()).unwrap();
2363 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2364 });
2365
2366 assert!(
2367 content.contains("hello buffered"),
2368 "expected buffered output to render, got: {content}"
2369 );
2370 }
2371
2372 #[gpui::test]
2373 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2374 init_test(cx);
2375
2376 let fs = FakeFs::new(cx.executor());
2377 let project = Project::test(fs, [], cx).await;
2378 let connection = Rc::new(FakeAgentConnection::new());
2379 let thread = cx
2380 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2381 .await
2382 .unwrap();
2383
2384 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2385
2386 // Send Output BEFORE Created
2387 thread.update(cx, |thread, cx| {
2388 thread.on_terminal_provider_event(
2389 TerminalProviderEvent::Output {
2390 terminal_id: terminal_id.clone(),
2391 data: b"pre-exit data".to_vec(),
2392 },
2393 cx,
2394 );
2395 });
2396
2397 // Send Exit BEFORE Created
2398 thread.update(cx, |thread, cx| {
2399 thread.on_terminal_provider_event(
2400 TerminalProviderEvent::Exit {
2401 terminal_id: terminal_id.clone(),
2402 status: acp::TerminalExitStatus {
2403 exit_code: Some(0),
2404 signal: None,
2405 meta: None,
2406 },
2407 },
2408 cx,
2409 );
2410 });
2411
2412 // Now create a display-only lower-level terminal and send Created
2413 let lower = cx.new(|cx| {
2414 let builder = ::terminal::TerminalBuilder::new_display_only(
2415 None,
2416 ::terminal::terminal_settings::CursorShape::default(),
2417 ::terminal::terminal_settings::AlternateScroll::On,
2418 None,
2419 0,
2420 cx,
2421 )
2422 .unwrap();
2423 builder.subscribe(cx)
2424 });
2425
2426 thread.update(cx, |thread, cx| {
2427 thread.on_terminal_provider_event(
2428 TerminalProviderEvent::Created {
2429 terminal_id: terminal_id.clone(),
2430 label: "Buffered Exit Test".to_string(),
2431 cwd: None,
2432 output_byte_limit: None,
2433 terminal: lower.clone(),
2434 },
2435 cx,
2436 );
2437 });
2438
2439 // Output should be present after Created (flushed from buffer)
2440 let content = thread.read_with(cx, |thread, cx| {
2441 let term = thread.terminal(terminal_id.clone()).unwrap();
2442 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2443 });
2444
2445 assert!(
2446 content.contains("pre-exit data"),
2447 "expected pre-exit data to render, got: {content}"
2448 );
2449 }
2450
2451 #[gpui::test]
2452 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2453 init_test(cx);
2454
2455 let fs = FakeFs::new(cx.executor());
2456 let project = Project::test(fs, [], cx).await;
2457 let connection = Rc::new(FakeAgentConnection::new());
2458 let thread = cx
2459 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2460 .await
2461 .unwrap();
2462
2463 // Test creating a new user message
2464 thread.update(cx, |thread, cx| {
2465 thread.push_user_content_block(
2466 None,
2467 acp::ContentBlock::Text(acp::TextContent {
2468 annotations: None,
2469 text: "Hello, ".to_string(),
2470 meta: None,
2471 }),
2472 cx,
2473 );
2474 });
2475
2476 thread.update(cx, |thread, cx| {
2477 assert_eq!(thread.entries.len(), 1);
2478 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2479 assert_eq!(user_msg.id, None);
2480 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2481 } else {
2482 panic!("Expected UserMessage");
2483 }
2484 });
2485
2486 // Test appending to existing user message
2487 let message_1_id = UserMessageId::new();
2488 thread.update(cx, |thread, cx| {
2489 thread.push_user_content_block(
2490 Some(message_1_id.clone()),
2491 acp::ContentBlock::Text(acp::TextContent {
2492 annotations: None,
2493 text: "world!".to_string(),
2494 meta: None,
2495 }),
2496 cx,
2497 );
2498 });
2499
2500 thread.update(cx, |thread, cx| {
2501 assert_eq!(thread.entries.len(), 1);
2502 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2503 assert_eq!(user_msg.id, Some(message_1_id));
2504 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2505 } else {
2506 panic!("Expected UserMessage");
2507 }
2508 });
2509
2510 // Test creating new user message after assistant message
2511 thread.update(cx, |thread, cx| {
2512 thread.push_assistant_content_block(
2513 acp::ContentBlock::Text(acp::TextContent {
2514 annotations: None,
2515 text: "Assistant response".to_string(),
2516 meta: None,
2517 }),
2518 false,
2519 cx,
2520 );
2521 });
2522
2523 let message_2_id = UserMessageId::new();
2524 thread.update(cx, |thread, cx| {
2525 thread.push_user_content_block(
2526 Some(message_2_id.clone()),
2527 acp::ContentBlock::Text(acp::TextContent {
2528 annotations: None,
2529 text: "New user message".to_string(),
2530 meta: None,
2531 }),
2532 cx,
2533 );
2534 });
2535
2536 thread.update(cx, |thread, cx| {
2537 assert_eq!(thread.entries.len(), 3);
2538 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2539 assert_eq!(user_msg.id, Some(message_2_id));
2540 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2541 } else {
2542 panic!("Expected UserMessage at index 2");
2543 }
2544 });
2545 }
2546
2547 #[gpui::test]
2548 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2549 init_test(cx);
2550
2551 let fs = FakeFs::new(cx.executor());
2552 let project = Project::test(fs, [], cx).await;
2553 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2554 |_, thread, mut cx| {
2555 async move {
2556 thread.update(&mut cx, |thread, cx| {
2557 thread
2558 .handle_session_update(
2559 acp::SessionUpdate::AgentThoughtChunk {
2560 content: "Thinking ".into(),
2561 },
2562 cx,
2563 )
2564 .unwrap();
2565 thread
2566 .handle_session_update(
2567 acp::SessionUpdate::AgentThoughtChunk {
2568 content: "hard!".into(),
2569 },
2570 cx,
2571 )
2572 .unwrap();
2573 })?;
2574 Ok(acp::PromptResponse {
2575 stop_reason: acp::StopReason::EndTurn,
2576 meta: None,
2577 })
2578 }
2579 .boxed_local()
2580 },
2581 ));
2582
2583 let thread = cx
2584 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2585 .await
2586 .unwrap();
2587
2588 thread
2589 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2590 .await
2591 .unwrap();
2592
2593 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2594 assert_eq!(
2595 output,
2596 indoc! {r#"
2597 ## User
2598
2599 Hello from Zed!
2600
2601 ## Assistant
2602
2603 <thinking>
2604 Thinking hard!
2605 </thinking>
2606
2607 "#}
2608 );
2609 }
2610
2611 #[gpui::test]
2612 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2613 init_test(cx);
2614
2615 let fs = FakeFs::new(cx.executor());
2616 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2617 .await;
2618 let project = Project::test(fs.clone(), [], cx).await;
2619 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2620 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2621 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2622 move |_, thread, mut cx| {
2623 let read_file_tx = read_file_tx.clone();
2624 async move {
2625 let content = thread
2626 .update(&mut cx, |thread, cx| {
2627 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2628 })
2629 .unwrap()
2630 .await
2631 .unwrap();
2632 assert_eq!(content, "one\ntwo\nthree\n");
2633 read_file_tx.take().unwrap().send(()).unwrap();
2634 thread
2635 .update(&mut cx, |thread, cx| {
2636 thread.write_text_file(
2637 path!("/tmp/foo").into(),
2638 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2639 cx,
2640 )
2641 })
2642 .unwrap()
2643 .await
2644 .unwrap();
2645 Ok(acp::PromptResponse {
2646 stop_reason: acp::StopReason::EndTurn,
2647 meta: None,
2648 })
2649 }
2650 .boxed_local()
2651 },
2652 ));
2653
2654 let (worktree, pathbuf) = project
2655 .update(cx, |project, cx| {
2656 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2657 })
2658 .await
2659 .unwrap();
2660 let buffer = project
2661 .update(cx, |project, cx| {
2662 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2663 })
2664 .await
2665 .unwrap();
2666
2667 let thread = cx
2668 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2669 .await
2670 .unwrap();
2671
2672 let request = thread.update(cx, |thread, cx| {
2673 thread.send_raw("Extend the count in /tmp/foo", cx)
2674 });
2675 read_file_rx.await.ok();
2676 buffer.update(cx, |buffer, cx| {
2677 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2678 });
2679 cx.run_until_parked();
2680 assert_eq!(
2681 buffer.read_with(cx, |buffer, _| buffer.text()),
2682 "zero\none\ntwo\nthree\nfour\nfive\n"
2683 );
2684 assert_eq!(
2685 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2686 "zero\none\ntwo\nthree\nfour\nfive\n"
2687 );
2688 request.await.unwrap();
2689 }
2690
2691 #[gpui::test]
2692 async fn test_reading_from_line(cx: &mut TestAppContext) {
2693 init_test(cx);
2694
2695 let fs = FakeFs::new(cx.executor());
2696 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2697 .await;
2698 let project = Project::test(fs.clone(), [], cx).await;
2699 project
2700 .update(cx, |project, cx| {
2701 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2702 })
2703 .await
2704 .unwrap();
2705
2706 let connection = Rc::new(FakeAgentConnection::new());
2707
2708 let thread = cx
2709 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2710 .await
2711 .unwrap();
2712
2713 // Whole file
2714 let content = thread
2715 .update(cx, |thread, cx| {
2716 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2717 })
2718 .await
2719 .unwrap();
2720
2721 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2722
2723 // Only start line
2724 let content = thread
2725 .update(cx, |thread, cx| {
2726 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2727 })
2728 .await
2729 .unwrap();
2730
2731 assert_eq!(content, "three\nfour\n");
2732
2733 // Only limit
2734 let content = thread
2735 .update(cx, |thread, cx| {
2736 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2737 })
2738 .await
2739 .unwrap();
2740
2741 assert_eq!(content, "one\ntwo\n");
2742
2743 // Range
2744 let content = thread
2745 .update(cx, |thread, cx| {
2746 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2747 })
2748 .await
2749 .unwrap();
2750
2751 assert_eq!(content, "two\nthree\n");
2752
2753 // Invalid
2754 let err = thread
2755 .update(cx, |thread, cx| {
2756 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2757 })
2758 .await
2759 .unwrap_err();
2760
2761 assert_eq!(
2762 err.to_string(),
2763 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2764 );
2765 }
2766
2767 #[gpui::test]
2768 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2769 init_test(cx);
2770
2771 let fs = FakeFs::new(cx.executor());
2772 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2773 let project = Project::test(fs.clone(), [], cx).await;
2774 project
2775 .update(cx, |project, cx| {
2776 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2777 })
2778 .await
2779 .unwrap();
2780
2781 let connection = Rc::new(FakeAgentConnection::new());
2782
2783 let thread = cx
2784 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2785 .await
2786 .unwrap();
2787
2788 // Whole file
2789 let content = thread
2790 .update(cx, |thread, cx| {
2791 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2792 })
2793 .await
2794 .unwrap();
2795
2796 assert_eq!(content, "");
2797
2798 // Only start line
2799 let content = thread
2800 .update(cx, |thread, cx| {
2801 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2802 })
2803 .await
2804 .unwrap();
2805
2806 assert_eq!(content, "");
2807
2808 // Only limit
2809 let content = thread
2810 .update(cx, |thread, cx| {
2811 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2812 })
2813 .await
2814 .unwrap();
2815
2816 assert_eq!(content, "");
2817
2818 // Range
2819 let content = thread
2820 .update(cx, |thread, cx| {
2821 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2822 })
2823 .await
2824 .unwrap();
2825
2826 assert_eq!(content, "");
2827
2828 // Invalid
2829 let err = thread
2830 .update(cx, |thread, cx| {
2831 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2832 })
2833 .await
2834 .unwrap_err();
2835
2836 assert_eq!(
2837 err.to_string(),
2838 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2839 );
2840 }
2841 #[gpui::test]
2842 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2843 init_test(cx);
2844
2845 let fs = FakeFs::new(cx.executor());
2846 fs.insert_tree(path!("/tmp"), json!({})).await;
2847 let project = Project::test(fs.clone(), [], cx).await;
2848 project
2849 .update(cx, |project, cx| {
2850 project.find_or_create_worktree(path!("/tmp"), true, cx)
2851 })
2852 .await
2853 .unwrap();
2854
2855 let connection = Rc::new(FakeAgentConnection::new());
2856
2857 let thread = cx
2858 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2859 .await
2860 .unwrap();
2861
2862 // Out of project file
2863 let err = thread
2864 .update(cx, |thread, cx| {
2865 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2866 })
2867 .await
2868 .unwrap_err();
2869
2870 assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2871 }
2872
2873 #[gpui::test]
2874 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2875 init_test(cx);
2876
2877 let fs = FakeFs::new(cx.executor());
2878 let project = Project::test(fs, [], cx).await;
2879 let id = acp::ToolCallId("test".into());
2880
2881 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2882 let id = id.clone();
2883 move |_, thread, mut cx| {
2884 let id = id.clone();
2885 async move {
2886 thread
2887 .update(&mut cx, |thread, cx| {
2888 thread.handle_session_update(
2889 acp::SessionUpdate::ToolCall(acp::ToolCall {
2890 id: id.clone(),
2891 title: "Label".into(),
2892 kind: acp::ToolKind::Fetch,
2893 status: acp::ToolCallStatus::InProgress,
2894 content: vec![],
2895 locations: vec![],
2896 raw_input: None,
2897 raw_output: None,
2898 meta: None,
2899 }),
2900 cx,
2901 )
2902 })
2903 .unwrap()
2904 .unwrap();
2905 Ok(acp::PromptResponse {
2906 stop_reason: acp::StopReason::EndTurn,
2907 meta: None,
2908 })
2909 }
2910 .boxed_local()
2911 }
2912 }));
2913
2914 let thread = cx
2915 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2916 .await
2917 .unwrap();
2918
2919 let request = thread.update(cx, |thread, cx| {
2920 thread.send_raw("Fetch https://example.com", cx)
2921 });
2922
2923 run_until_first_tool_call(&thread, cx).await;
2924
2925 thread.read_with(cx, |thread, _| {
2926 assert!(matches!(
2927 thread.entries[1],
2928 AgentThreadEntry::ToolCall(ToolCall {
2929 status: ToolCallStatus::InProgress,
2930 ..
2931 })
2932 ));
2933 });
2934
2935 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2936
2937 thread.read_with(cx, |thread, _| {
2938 assert!(matches!(
2939 &thread.entries[1],
2940 AgentThreadEntry::ToolCall(ToolCall {
2941 status: ToolCallStatus::Canceled,
2942 ..
2943 })
2944 ));
2945 });
2946
2947 thread
2948 .update(cx, |thread, cx| {
2949 thread.handle_session_update(
2950 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2951 id,
2952 fields: acp::ToolCallUpdateFields {
2953 status: Some(acp::ToolCallStatus::Completed),
2954 ..Default::default()
2955 },
2956 meta: None,
2957 }),
2958 cx,
2959 )
2960 })
2961 .unwrap();
2962
2963 request.await.unwrap();
2964
2965 thread.read_with(cx, |thread, _| {
2966 assert!(matches!(
2967 thread.entries[1],
2968 AgentThreadEntry::ToolCall(ToolCall {
2969 status: ToolCallStatus::Completed,
2970 ..
2971 })
2972 ));
2973 });
2974 }
2975
2976 #[gpui::test]
2977 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2978 init_test(cx);
2979 let fs = FakeFs::new(cx.background_executor.clone());
2980 fs.insert_tree(path!("/test"), json!({})).await;
2981 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2982
2983 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2984 move |_, thread, mut cx| {
2985 async move {
2986 thread
2987 .update(&mut cx, |thread, cx| {
2988 thread.handle_session_update(
2989 acp::SessionUpdate::ToolCall(acp::ToolCall {
2990 id: acp::ToolCallId("test".into()),
2991 title: "Label".into(),
2992 kind: acp::ToolKind::Edit,
2993 status: acp::ToolCallStatus::Completed,
2994 content: vec![acp::ToolCallContent::Diff {
2995 diff: acp::Diff {
2996 path: "/test/test.txt".into(),
2997 old_text: None,
2998 new_text: "foo".into(),
2999 meta: None,
3000 },
3001 }],
3002 locations: vec![],
3003 raw_input: None,
3004 raw_output: None,
3005 meta: None,
3006 }),
3007 cx,
3008 )
3009 })
3010 .unwrap()
3011 .unwrap();
3012 Ok(acp::PromptResponse {
3013 stop_reason: acp::StopReason::EndTurn,
3014 meta: None,
3015 })
3016 }
3017 .boxed_local()
3018 }
3019 }));
3020
3021 let thread = cx
3022 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3023 .await
3024 .unwrap();
3025
3026 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3027 .await
3028 .unwrap();
3029
3030 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3031 }
3032
3033 #[gpui::test(iterations = 10)]
3034 async fn test_checkpoints(cx: &mut TestAppContext) {
3035 init_test(cx);
3036 let fs = FakeFs::new(cx.background_executor.clone());
3037 fs.insert_tree(
3038 path!("/test"),
3039 json!({
3040 ".git": {}
3041 }),
3042 )
3043 .await;
3044 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3045
3046 let simulate_changes = Arc::new(AtomicBool::new(true));
3047 let next_filename = Arc::new(AtomicUsize::new(0));
3048 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3049 let simulate_changes = simulate_changes.clone();
3050 let next_filename = next_filename.clone();
3051 let fs = fs.clone();
3052 move |request, thread, mut cx| {
3053 let fs = fs.clone();
3054 let simulate_changes = simulate_changes.clone();
3055 let next_filename = next_filename.clone();
3056 async move {
3057 if simulate_changes.load(SeqCst) {
3058 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3059 fs.write(Path::new(&filename), b"").await?;
3060 }
3061
3062 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3063 panic!("expected text content block");
3064 };
3065 thread.update(&mut cx, |thread, cx| {
3066 thread
3067 .handle_session_update(
3068 acp::SessionUpdate::AgentMessageChunk {
3069 content: content.text.to_uppercase().into(),
3070 },
3071 cx,
3072 )
3073 .unwrap();
3074 })?;
3075 Ok(acp::PromptResponse {
3076 stop_reason: acp::StopReason::EndTurn,
3077 meta: None,
3078 })
3079 }
3080 .boxed_local()
3081 }
3082 }));
3083 let thread = cx
3084 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3085 .await
3086 .unwrap();
3087
3088 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3089 .await
3090 .unwrap();
3091 thread.read_with(cx, |thread, cx| {
3092 assert_eq!(
3093 thread.to_markdown(cx),
3094 indoc! {"
3095 ## User (checkpoint)
3096
3097 Lorem
3098
3099 ## Assistant
3100
3101 LOREM
3102
3103 "}
3104 );
3105 });
3106 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3107
3108 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3109 .await
3110 .unwrap();
3111 thread.read_with(cx, |thread, cx| {
3112 assert_eq!(
3113 thread.to_markdown(cx),
3114 indoc! {"
3115 ## User (checkpoint)
3116
3117 Lorem
3118
3119 ## Assistant
3120
3121 LOREM
3122
3123 ## User (checkpoint)
3124
3125 ipsum
3126
3127 ## Assistant
3128
3129 IPSUM
3130
3131 "}
3132 );
3133 });
3134 assert_eq!(
3135 fs.files(),
3136 vec![
3137 Path::new(path!("/test/file-0")),
3138 Path::new(path!("/test/file-1"))
3139 ]
3140 );
3141
3142 // Checkpoint isn't stored when there are no changes.
3143 simulate_changes.store(false, SeqCst);
3144 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3145 .await
3146 .unwrap();
3147 thread.read_with(cx, |thread, cx| {
3148 assert_eq!(
3149 thread.to_markdown(cx),
3150 indoc! {"
3151 ## User (checkpoint)
3152
3153 Lorem
3154
3155 ## Assistant
3156
3157 LOREM
3158
3159 ## User (checkpoint)
3160
3161 ipsum
3162
3163 ## Assistant
3164
3165 IPSUM
3166
3167 ## User
3168
3169 dolor
3170
3171 ## Assistant
3172
3173 DOLOR
3174
3175 "}
3176 );
3177 });
3178 assert_eq!(
3179 fs.files(),
3180 vec![
3181 Path::new(path!("/test/file-0")),
3182 Path::new(path!("/test/file-1"))
3183 ]
3184 );
3185
3186 // Rewinding the conversation truncates the history and restores the checkpoint.
3187 thread
3188 .update(cx, |thread, cx| {
3189 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3190 panic!("unexpected entries {:?}", thread.entries)
3191 };
3192 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3193 })
3194 .await
3195 .unwrap();
3196 thread.read_with(cx, |thread, cx| {
3197 assert_eq!(
3198 thread.to_markdown(cx),
3199 indoc! {"
3200 ## User (checkpoint)
3201
3202 Lorem
3203
3204 ## Assistant
3205
3206 LOREM
3207
3208 "}
3209 );
3210 });
3211 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3212 }
3213
3214 #[gpui::test]
3215 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3216 use std::sync::atomic::AtomicUsize;
3217 init_test(cx);
3218
3219 let fs = FakeFs::new(cx.executor());
3220 let project = Project::test(fs, None, cx).await;
3221
3222 // Create a connection that simulates refusal after tool result
3223 let prompt_count = Arc::new(AtomicUsize::new(0));
3224 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3225 let prompt_count = prompt_count.clone();
3226 move |_request, thread, mut cx| {
3227 let count = prompt_count.fetch_add(1, SeqCst);
3228 async move {
3229 if count == 0 {
3230 // First prompt: Generate a tool call with result
3231 thread.update(&mut cx, |thread, cx| {
3232 thread
3233 .handle_session_update(
3234 acp::SessionUpdate::ToolCall(acp::ToolCall {
3235 id: acp::ToolCallId("tool1".into()),
3236 title: "Test Tool".into(),
3237 kind: acp::ToolKind::Fetch,
3238 status: acp::ToolCallStatus::Completed,
3239 content: vec![],
3240 locations: vec![],
3241 raw_input: Some(serde_json::json!({"query": "test"})),
3242 raw_output: Some(
3243 serde_json::json!({"result": "inappropriate content"}),
3244 ),
3245 meta: None,
3246 }),
3247 cx,
3248 )
3249 .unwrap();
3250 })?;
3251
3252 // Now return refusal because of the tool result
3253 Ok(acp::PromptResponse {
3254 stop_reason: acp::StopReason::Refusal,
3255 meta: None,
3256 })
3257 } else {
3258 Ok(acp::PromptResponse {
3259 stop_reason: acp::StopReason::EndTurn,
3260 meta: None,
3261 })
3262 }
3263 }
3264 .boxed_local()
3265 }
3266 }));
3267
3268 let thread = cx
3269 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3270 .await
3271 .unwrap();
3272
3273 // Track if we see a Refusal event
3274 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3275 let saw_refusal_event_captured = saw_refusal_event.clone();
3276 thread.update(cx, |_thread, cx| {
3277 cx.subscribe(
3278 &thread,
3279 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3280 if matches!(event, AcpThreadEvent::Refusal) {
3281 *saw_refusal_event_captured.lock().unwrap() = true;
3282 }
3283 },
3284 )
3285 .detach();
3286 });
3287
3288 // Send a user message - this will trigger tool call and then refusal
3289 let send_task = thread.update(cx, |thread, cx| {
3290 thread.send(
3291 vec![acp::ContentBlock::Text(acp::TextContent {
3292 text: "Hello".into(),
3293 annotations: None,
3294 meta: None,
3295 })],
3296 cx,
3297 )
3298 });
3299 cx.background_executor.spawn(send_task).detach();
3300 cx.run_until_parked();
3301
3302 // Verify that:
3303 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3304 // 2. The user message was NOT truncated
3305 assert!(
3306 *saw_refusal_event.lock().unwrap(),
3307 "Refusal event should be emitted for tool result refusals"
3308 );
3309
3310 thread.read_with(cx, |thread, _| {
3311 let entries = thread.entries();
3312 assert!(entries.len() >= 2, "Should have user message and tool call");
3313
3314 // Verify user message is still there
3315 assert!(
3316 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3317 "User message should not be truncated"
3318 );
3319
3320 // Verify tool call is there with result
3321 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3322 assert!(
3323 tool_call.raw_output.is_some(),
3324 "Tool call should have output"
3325 );
3326 } else {
3327 panic!("Expected tool call at index 1");
3328 }
3329 });
3330 }
3331
3332 #[gpui::test]
3333 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3334 init_test(cx);
3335
3336 let fs = FakeFs::new(cx.executor());
3337 let project = Project::test(fs, None, cx).await;
3338
3339 let refuse_next = Arc::new(AtomicBool::new(false));
3340 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3341 let refuse_next = refuse_next.clone();
3342 move |_request, _thread, _cx| {
3343 if refuse_next.load(SeqCst) {
3344 async move {
3345 Ok(acp::PromptResponse {
3346 stop_reason: acp::StopReason::Refusal,
3347 meta: None,
3348 })
3349 }
3350 .boxed_local()
3351 } else {
3352 async move {
3353 Ok(acp::PromptResponse {
3354 stop_reason: acp::StopReason::EndTurn,
3355 meta: None,
3356 })
3357 }
3358 .boxed_local()
3359 }
3360 }
3361 }));
3362
3363 let thread = cx
3364 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3365 .await
3366 .unwrap();
3367
3368 // Track if we see a Refusal event
3369 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3370 let saw_refusal_event_captured = saw_refusal_event.clone();
3371 thread.update(cx, |_thread, cx| {
3372 cx.subscribe(
3373 &thread,
3374 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3375 if matches!(event, AcpThreadEvent::Refusal) {
3376 *saw_refusal_event_captured.lock().unwrap() = true;
3377 }
3378 },
3379 )
3380 .detach();
3381 });
3382
3383 // Send a message that will be refused
3384 refuse_next.store(true, SeqCst);
3385 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3386 .await
3387 .unwrap();
3388
3389 // Verify that a Refusal event WAS emitted for user prompt refusal
3390 assert!(
3391 *saw_refusal_event.lock().unwrap(),
3392 "Refusal event should be emitted for user prompt refusals"
3393 );
3394
3395 // Verify the message was truncated (user prompt refusal)
3396 thread.read_with(cx, |thread, cx| {
3397 assert_eq!(thread.to_markdown(cx), "");
3398 });
3399 }
3400
3401 #[gpui::test]
3402 async fn test_refusal(cx: &mut TestAppContext) {
3403 init_test(cx);
3404 let fs = FakeFs::new(cx.background_executor.clone());
3405 fs.insert_tree(path!("/"), json!({})).await;
3406 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3407
3408 let refuse_next = Arc::new(AtomicBool::new(false));
3409 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3410 let refuse_next = refuse_next.clone();
3411 move |request, thread, mut cx| {
3412 let refuse_next = refuse_next.clone();
3413 async move {
3414 if refuse_next.load(SeqCst) {
3415 return Ok(acp::PromptResponse {
3416 stop_reason: acp::StopReason::Refusal,
3417 meta: None,
3418 });
3419 }
3420
3421 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3422 panic!("expected text content block");
3423 };
3424 thread.update(&mut cx, |thread, cx| {
3425 thread
3426 .handle_session_update(
3427 acp::SessionUpdate::AgentMessageChunk {
3428 content: content.text.to_uppercase().into(),
3429 },
3430 cx,
3431 )
3432 .unwrap();
3433 })?;
3434 Ok(acp::PromptResponse {
3435 stop_reason: acp::StopReason::EndTurn,
3436 meta: None,
3437 })
3438 }
3439 .boxed_local()
3440 }
3441 }));
3442 let thread = cx
3443 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3444 .await
3445 .unwrap();
3446
3447 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3448 .await
3449 .unwrap();
3450 thread.read_with(cx, |thread, cx| {
3451 assert_eq!(
3452 thread.to_markdown(cx),
3453 indoc! {"
3454 ## User
3455
3456 hello
3457
3458 ## Assistant
3459
3460 HELLO
3461
3462 "}
3463 );
3464 });
3465
3466 // Simulate refusing the second message. The message should be truncated
3467 // when a user prompt is refused.
3468 refuse_next.store(true, SeqCst);
3469 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3470 .await
3471 .unwrap();
3472 thread.read_with(cx, |thread, cx| {
3473 assert_eq!(
3474 thread.to_markdown(cx),
3475 indoc! {"
3476 ## User
3477
3478 hello
3479
3480 ## Assistant
3481
3482 HELLO
3483
3484 "}
3485 );
3486 });
3487 }
3488
3489 async fn run_until_first_tool_call(
3490 thread: &Entity<AcpThread>,
3491 cx: &mut TestAppContext,
3492 ) -> usize {
3493 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3494
3495 let subscription = cx.update(|cx| {
3496 cx.subscribe(thread, move |thread, _, cx| {
3497 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3498 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3499 return tx.try_send(ix).unwrap();
3500 }
3501 }
3502 })
3503 });
3504
3505 select! {
3506 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3507 panic!("Timeout waiting for tool call")
3508 }
3509 ix = rx.next().fuse() => {
3510 drop(subscription);
3511 ix.unwrap()
3512 }
3513 }
3514 }
3515
3516 #[derive(Clone, Default)]
3517 struct FakeAgentConnection {
3518 auth_methods: Vec<acp::AuthMethod>,
3519 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3520 on_user_message: Option<
3521 Rc<
3522 dyn Fn(
3523 acp::PromptRequest,
3524 WeakEntity<AcpThread>,
3525 AsyncApp,
3526 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3527 + 'static,
3528 >,
3529 >,
3530 }
3531
3532 impl FakeAgentConnection {
3533 fn new() -> Self {
3534 Self {
3535 auth_methods: Vec::new(),
3536 on_user_message: None,
3537 sessions: Arc::default(),
3538 }
3539 }
3540
3541 #[expect(unused)]
3542 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3543 self.auth_methods = auth_methods;
3544 self
3545 }
3546
3547 fn on_user_message(
3548 mut self,
3549 handler: impl Fn(
3550 acp::PromptRequest,
3551 WeakEntity<AcpThread>,
3552 AsyncApp,
3553 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3554 + 'static,
3555 ) -> Self {
3556 self.on_user_message.replace(Rc::new(handler));
3557 self
3558 }
3559 }
3560
3561 impl AgentConnection for FakeAgentConnection {
3562 fn auth_methods(&self) -> &[acp::AuthMethod] {
3563 &self.auth_methods
3564 }
3565
3566 fn new_thread(
3567 self: Rc<Self>,
3568 project: Entity<Project>,
3569 _cwd: &Path,
3570 cx: &mut App,
3571 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3572 let session_id = acp::SessionId(
3573 rand::rng()
3574 .sample_iter(&distr::Alphanumeric)
3575 .take(7)
3576 .map(char::from)
3577 .collect::<String>()
3578 .into(),
3579 );
3580 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3581 let thread = cx.new(|cx| {
3582 AcpThread::new(
3583 "Test",
3584 self.clone(),
3585 project,
3586 action_log,
3587 session_id.clone(),
3588 watch::Receiver::constant(acp::PromptCapabilities {
3589 image: true,
3590 audio: true,
3591 embedded_context: true,
3592 meta: None,
3593 }),
3594 cx,
3595 )
3596 });
3597 self.sessions.lock().insert(session_id, thread.downgrade());
3598 Task::ready(Ok(thread))
3599 }
3600
3601 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3602 if self.auth_methods().iter().any(|m| m.id == method) {
3603 Task::ready(Ok(()))
3604 } else {
3605 Task::ready(Err(anyhow!("Invalid Auth Method")))
3606 }
3607 }
3608
3609 fn prompt(
3610 &self,
3611 _id: Option<UserMessageId>,
3612 params: acp::PromptRequest,
3613 cx: &mut App,
3614 ) -> Task<gpui::Result<acp::PromptResponse>> {
3615 let sessions = self.sessions.lock();
3616 let thread = sessions.get(¶ms.session_id).unwrap();
3617 if let Some(handler) = &self.on_user_message {
3618 let handler = handler.clone();
3619 let thread = thread.clone();
3620 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3621 } else {
3622 Task::ready(Ok(acp::PromptResponse {
3623 stop_reason: acp::StopReason::EndTurn,
3624 meta: None,
3625 }))
3626 }
3627 }
3628
3629 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3630 let sessions = self.sessions.lock();
3631 let thread = sessions.get(session_id).unwrap().clone();
3632
3633 cx.spawn(async move |cx| {
3634 thread
3635 .update(cx, |thread, cx| thread.cancel(cx))
3636 .unwrap()
3637 .await
3638 })
3639 .detach();
3640 }
3641
3642 fn truncate(
3643 &self,
3644 session_id: &acp::SessionId,
3645 _cx: &App,
3646 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3647 Some(Rc::new(FakeAgentSessionEditor {
3648 _session_id: session_id.clone(),
3649 }))
3650 }
3651
3652 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3653 self
3654 }
3655 }
3656
/// Truncation handle handed out by `FakeAgentConnection::truncate`.
struct FakeAgentSessionEditor {
    // Session this editor was created for. The leading underscore marks it as
    // intentionally unused; nothing in the visible code reads it.
    _session_id: acp::SessionId,
}
3660
impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// No-op truncation: ignores both arguments and returns an already-completed
    /// task carrying `Ok(())`.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
3666
3667 #[gpui::test]
3668 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3669 init_test(cx);
3670
3671 let fs = FakeFs::new(cx.executor());
3672 let project = Project::test(fs, [], cx).await;
3673 let connection = Rc::new(FakeAgentConnection::new());
3674 let thread = cx
3675 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3676 .await
3677 .unwrap();
3678
3679 // Try to update a tool call that doesn't exist
3680 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3681 thread.update(cx, |thread, cx| {
3682 let result = thread.handle_session_update(
3683 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3684 id: nonexistent_id.clone(),
3685 fields: acp::ToolCallUpdateFields {
3686 status: Some(acp::ToolCallStatus::Completed),
3687 ..Default::default()
3688 },
3689 meta: None,
3690 }),
3691 cx,
3692 );
3693
3694 // The update should succeed (not return an error)
3695 assert!(result.is_ok());
3696
3697 // There should now be exactly one entry in the thread
3698 assert_eq!(thread.entries.len(), 1);
3699
3700 // The entry should be a failed tool call
3701 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3702 assert_eq!(tool_call.id, nonexistent_id);
3703 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3704 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3705
3706 // Check that the content contains the error message
3707 assert_eq!(tool_call.content.len(), 1);
3708 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3709 match content_block {
3710 ContentBlock::Markdown { markdown } => {
3711 let markdown_text = markdown.read(cx).source();
3712 assert!(markdown_text.contains("Tool call not found"));
3713 }
3714 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3715 ContentBlock::ResourceLink { .. } => {
3716 panic!("Expected markdown content, got resource link")
3717 }
3718 }
3719 } else {
3720 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3721 }
3722 } else {
3723 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3724 }
3725 });
3726 }
3727}