1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use ::terminal::terminal_settings::TerminalSettings;
7use agent_settings::AgentSettings;
8use collections::HashSet;
9pub use connection::*;
10pub use diff::*;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::{Settings as _, SettingsLocation};
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::ActionLog;
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash};
39use uuid::Uuid;
40
41#[derive(Debug)]
42pub struct UserMessage {
43 pub id: Option<UserMessageId>,
44 pub content: ContentBlock,
45 pub chunks: Vec<acp::ContentBlock>,
46 pub checkpoint: Option<Checkpoint>,
47}
48
49#[derive(Debug)]
50pub struct Checkpoint {
51 git_checkpoint: GitStoreCheckpoint,
52 pub show: bool,
53}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
74#[derive(Debug, PartialEq)]
75pub struct AssistantMessage {
76 pub chunks: Vec<AssistantMessageChunk>,
77}
78
79impl AssistantMessage {
80 pub fn to_markdown(&self, cx: &App) -> String {
81 format!(
82 "## Assistant\n\n{}\n\n",
83 self.chunks
84 .iter()
85 .map(|chunk| chunk.to_markdown(cx))
86 .join("\n\n")
87 )
88 }
89}
90
91#[derive(Debug, PartialEq)]
92pub enum AssistantMessageChunk {
93 Message { block: ContentBlock },
94 Thought { block: ContentBlock },
95}
96
97impl AssistantMessageChunk {
98 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Message {
100 block: ContentBlock::new(chunk.into(), language_registry, cx),
101 }
102 }
103
104 fn to_markdown(&self, cx: &App) -> String {
105 match self {
106 Self::Message { block } => block.to_markdown(cx).to_string(),
107 Self::Thought { block } => {
108 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
109 }
110 }
111 }
112}
113
114#[derive(Debug)]
115pub enum AgentThreadEntry {
116 UserMessage(UserMessage),
117 AssistantMessage(AssistantMessage),
118 ToolCall(ToolCall),
119}
120
121impl AgentThreadEntry {
122 pub fn to_markdown(&self, cx: &App) -> String {
123 match self {
124 Self::UserMessage(message) => message.to_markdown(cx),
125 Self::AssistantMessage(message) => message.to_markdown(cx),
126 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
127 }
128 }
129
130 pub fn user_message(&self) -> Option<&UserMessage> {
131 if let AgentThreadEntry::UserMessage(message) = self {
132 Some(message)
133 } else {
134 None
135 }
136 }
137
138 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.diffs())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
147 if let AgentThreadEntry::ToolCall(call) = self {
148 itertools::Either::Left(call.terminals())
149 } else {
150 itertools::Either::Right(std::iter::empty())
151 }
152 }
153
154 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
155 if let AgentThreadEntry::ToolCall(ToolCall {
156 locations,
157 resolved_locations,
158 ..
159 }) = self
160 {
161 Some((
162 locations.get(ix)?.clone(),
163 resolved_locations.get(ix)?.clone()?,
164 ))
165 } else {
166 None
167 }
168 }
169}
170
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// Rendered title. Titles containing a newline are cut to their first line plus `…`.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content items (text/markdown, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported over ACP.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer locations resolved from `locations`, looked up by the same index.
    /// Starts empty when the tool call is created from ACP.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the tool call, if reported.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output of the tool call, if reported.
    pub raw_output: Option<serde_json::Value>,
}
183
184impl ToolCall {
185 fn from_acp(
186 tool_call: acp::ToolCall,
187 status: ToolCallStatus,
188 language_registry: Arc<LanguageRegistry>,
189 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
190 cx: &mut App,
191 ) -> Result<Self> {
192 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
193 first_line.to_owned() + "…"
194 } else {
195 tool_call.title
196 };
197 let mut content = Vec::with_capacity(tool_call.content.len());
198 for item in tool_call.content {
199 content.push(ToolCallContent::from_acp(
200 item,
201 language_registry.clone(),
202 terminals,
203 cx,
204 )?);
205 }
206
207 let result = Self {
208 id: tool_call.id,
209 label: cx
210 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
211 kind: tool_call.kind,
212 content,
213 locations: tool_call.locations,
214 resolved_locations: Vec::default(),
215 status,
216 raw_input: tool_call.raw_input,
217 raw_output: tool_call.raw_output,
218 };
219 Ok(result)
220 }
221
222 fn update_fields(
223 &mut self,
224 fields: acp::ToolCallUpdateFields,
225 language_registry: Arc<LanguageRegistry>,
226 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
227 cx: &mut App,
228 ) -> Result<()> {
229 let acp::ToolCallUpdateFields {
230 kind,
231 status,
232 title,
233 content,
234 locations,
235 raw_input,
236 raw_output,
237 } = fields;
238
239 if let Some(kind) = kind {
240 self.kind = kind;
241 }
242
243 if let Some(status) = status {
244 self.status = status.into();
245 }
246
247 if let Some(title) = title {
248 self.label.update(cx, |label, cx| {
249 if let Some((first_line, _)) = title.split_once("\n") {
250 label.replace(first_line.to_owned() + "…", cx)
251 } else {
252 label.replace(title, cx);
253 }
254 });
255 }
256
257 if let Some(content) = content {
258 let new_content_len = content.len();
259 let mut content = content.into_iter();
260
261 // Reuse existing content if we can
262 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
263 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
264 }
265 for new in content {
266 self.content.push(ToolCallContent::from_acp(
267 new,
268 language_registry.clone(),
269 terminals,
270 cx,
271 )?)
272 }
273 self.content.truncate(new_content_len);
274 }
275
276 if let Some(locations) = locations {
277 self.locations = locations;
278 }
279
280 if let Some(raw_input) = raw_input {
281 self.raw_input = Some(raw_input);
282 }
283
284 if let Some(raw_output) = raw_output {
285 if self.content.is_empty()
286 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
287 {
288 self.content
289 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
290 markdown,
291 }));
292 }
293 self.raw_output = Some(raw_output);
294 }
295 Ok(())
296 }
297
298 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
299 self.content.iter().filter_map(|content| match content {
300 ToolCallContent::Diff(diff) => Some(diff),
301 ToolCallContent::ContentBlock(_) => None,
302 ToolCallContent::Terminal(_) => None,
303 })
304 }
305
306 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
307 self.content.iter().filter_map(|content| match content {
308 ToolCallContent::Terminal(terminal) => Some(terminal),
309 ToolCallContent::ContentBlock(_) => None,
310 ToolCallContent::Diff(_) => None,
311 })
312 }
313
314 fn to_markdown(&self, cx: &App) -> String {
315 let mut markdown = format!(
316 "**Tool Call: {}**\nStatus: {}\n\n",
317 self.label.read(cx).source(),
318 self.status
319 );
320 for content in &self.content {
321 markdown.push_str(content.to_markdown(cx).as_str());
322 markdown.push_str("\n\n");
323 }
324 markdown
325 }
326
327 async fn resolve_location(
328 location: acp::ToolCallLocation,
329 project: WeakEntity<Project>,
330 cx: &mut AsyncApp,
331 ) -> Option<ResolvedLocation> {
332 let buffer = project
333 .update(cx, |project, cx| {
334 project
335 .project_path_for_absolute_path(&location.path, cx)
336 .map(|path| project.open_buffer(path, cx))
337 })
338 .ok()??;
339 let buffer = buffer.await.log_err()?;
340 let position = buffer
341 .update(cx, |buffer, _| {
342 if let Some(row) = location.line {
343 let snapshot = buffer.snapshot();
344 let column = snapshot.indent_size_for_line(row).len;
345 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
346 snapshot.anchor_before(point)
347 } else {
348 Anchor::MIN
349 }
350 })
351 .ok()?;
352
353 Some(ResolvedLocation { buffer, position })
354 }
355
356 fn resolve_locations(
357 &self,
358 project: Entity<Project>,
359 cx: &mut App,
360 ) -> Task<Vec<Option<ResolvedLocation>>> {
361 let locations = self.locations.clone();
362 project.update(cx, |_, cx| {
363 cx.spawn(async move |project, cx| {
364 let mut new_locations = Vec::new();
365 for location in locations {
366 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
367 }
368 new_locations
369 })
370 })
371 }
372}
373
374// Separate so we can hold a strong reference to the buffer
375// for saving on the thread
376#[derive(Clone, Debug, PartialEq, Eq)]
377struct ResolvedLocation {
378 buffer: Entity<Buffer>,
379 position: Anchor,
380}
381
382impl From<&ResolvedLocation> for AgentLocation {
383 fn from(value: &ResolvedLocation) -> Self {
384 Self {
385 buffer: value.buffer.downgrade(),
386 position: value.position,
387 }
388 }
389}
390
391#[derive(Debug)]
392pub enum ToolCallStatus {
393 /// The tool call hasn't started running yet, but we start showing it to
394 /// the user.
395 Pending,
396 /// The tool call is waiting for confirmation from the user.
397 WaitingForConfirmation {
398 options: Vec<acp::PermissionOption>,
399 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
400 },
401 /// The tool call is currently running.
402 InProgress,
403 /// The tool call completed successfully.
404 Completed,
405 /// The tool call failed.
406 Failed,
407 /// The user rejected the tool call.
408 Rejected,
409 /// The user canceled generation so the tool call was canceled.
410 Canceled,
411}
412
413impl From<acp::ToolCallStatus> for ToolCallStatus {
414 fn from(status: acp::ToolCallStatus) -> Self {
415 match status {
416 acp::ToolCallStatus::Pending => Self::Pending,
417 acp::ToolCallStatus::InProgress => Self::InProgress,
418 acp::ToolCallStatus::Completed => Self::Completed,
419 acp::ToolCallStatus::Failed => Self::Failed,
420 }
421 }
422}
423
424impl Display for ToolCallStatus {
425 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
426 write!(
427 f,
428 "{}",
429 match self {
430 ToolCallStatus::Pending => "Pending",
431 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
432 ToolCallStatus::InProgress => "In Progress",
433 ToolCallStatus::Completed => "Completed",
434 ToolCallStatus::Failed => "Failed",
435 ToolCallStatus::Rejected => "Rejected",
436 ToolCallStatus::Canceled => "Canceled",
437 }
438 )
439 }
440}
441
442#[derive(Debug, PartialEq, Clone)]
443pub enum ContentBlock {
444 Empty,
445 Markdown { markdown: Entity<Markdown> },
446 ResourceLink { resource_link: acp::ResourceLink },
447}
448
449impl ContentBlock {
450 pub fn new(
451 block: acp::ContentBlock,
452 language_registry: &Arc<LanguageRegistry>,
453 cx: &mut App,
454 ) -> Self {
455 let mut this = Self::Empty;
456 this.append(block, language_registry, cx);
457 this
458 }
459
460 pub fn new_combined(
461 blocks: impl IntoIterator<Item = acp::ContentBlock>,
462 language_registry: Arc<LanguageRegistry>,
463 cx: &mut App,
464 ) -> Self {
465 let mut this = Self::Empty;
466 for block in blocks {
467 this.append(block, &language_registry, cx);
468 }
469 this
470 }
471
472 pub fn append(
473 &mut self,
474 block: acp::ContentBlock,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) {
478 if matches!(self, ContentBlock::Empty)
479 && let acp::ContentBlock::ResourceLink(resource_link) = block
480 {
481 *self = ContentBlock::ResourceLink { resource_link };
482 return;
483 }
484
485 let new_content = self.block_string_contents(block);
486
487 match self {
488 ContentBlock::Empty => {
489 *self = Self::create_markdown_block(new_content, language_registry, cx);
490 }
491 ContentBlock::Markdown { markdown } => {
492 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
493 }
494 ContentBlock::ResourceLink { resource_link } => {
495 let existing_content = Self::resource_link_md(&resource_link.uri);
496 let combined = format!("{}\n{}", existing_content, new_content);
497
498 *self = Self::create_markdown_block(combined, language_registry, cx);
499 }
500 }
501 }
502
503 fn create_markdown_block(
504 content: String,
505 language_registry: &Arc<LanguageRegistry>,
506 cx: &mut App,
507 ) -> ContentBlock {
508 ContentBlock::Markdown {
509 markdown: cx
510 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
511 }
512 }
513
514 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
515 match block {
516 acp::ContentBlock::Text(text_content) => text_content.text,
517 acp::ContentBlock::ResourceLink(resource_link) => {
518 Self::resource_link_md(&resource_link.uri)
519 }
520 acp::ContentBlock::Resource(acp::EmbeddedResource {
521 resource:
522 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
523 uri,
524 ..
525 }),
526 ..
527 }) => Self::resource_link_md(&uri),
528 acp::ContentBlock::Image(image) => Self::image_md(&image),
529 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
530 }
531 }
532
533 fn resource_link_md(uri: &str) -> String {
534 if let Some(uri) = MentionUri::parse(uri).log_err() {
535 uri.as_link().to_string()
536 } else {
537 uri.to_string()
538 }
539 }
540
541 fn image_md(_image: &acp::ImageContent) -> String {
542 "`Image`".into()
543 }
544
545 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
546 match self {
547 ContentBlock::Empty => "",
548 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
549 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
550 }
551 }
552
553 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
554 match self {
555 ContentBlock::Empty => None,
556 ContentBlock::Markdown { markdown } => Some(markdown),
557 ContentBlock::ResourceLink { .. } => None,
558 }
559 }
560
561 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
562 match self {
563 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
564 _ => None,
565 }
566 }
567}
568
569#[derive(Debug)]
570pub enum ToolCallContent {
571 ContentBlock(ContentBlock),
572 Diff(Entity<Diff>),
573 Terminal(Entity<Terminal>),
574}
575
576impl ToolCallContent {
577 pub fn from_acp(
578 content: acp::ToolCallContent,
579 language_registry: Arc<LanguageRegistry>,
580 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
581 cx: &mut App,
582 ) -> Result<Self> {
583 match content {
584 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
585 content,
586 &language_registry,
587 cx,
588 ))),
589 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
590 Diff::finalized(
591 diff.path.to_string_lossy().into_owned(),
592 diff.old_text,
593 diff.new_text,
594 language_registry,
595 cx,
596 )
597 }))),
598 acp::ToolCallContent::Terminal { terminal_id } => terminals
599 .get(&terminal_id)
600 .cloned()
601 .map(Self::Terminal)
602 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
603 }
604 }
605
606 pub fn update_from_acp(
607 &mut self,
608 new: acp::ToolCallContent,
609 language_registry: Arc<LanguageRegistry>,
610 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
611 cx: &mut App,
612 ) -> Result<()> {
613 let needs_update = match (&self, &new) {
614 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
615 old_diff.read(cx).needs_update(
616 new_diff.old_text.as_deref().unwrap_or(""),
617 &new_diff.new_text,
618 cx,
619 )
620 }
621 _ => true,
622 };
623
624 if needs_update {
625 *self = Self::from_acp(new, language_registry, terminals, cx)?;
626 }
627 Ok(())
628 }
629
630 pub fn to_markdown(&self, cx: &App) -> String {
631 match self {
632 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
633 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
634 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
635 }
636 }
637}
638
639#[derive(Debug, PartialEq)]
640pub enum ToolCallUpdate {
641 UpdateFields(acp::ToolCallUpdate),
642 UpdateDiff(ToolCallUpdateDiff),
643 UpdateTerminal(ToolCallUpdateTerminal),
644}
645
646impl ToolCallUpdate {
647 fn id(&self) -> &acp::ToolCallId {
648 match self {
649 Self::UpdateFields(update) => &update.id,
650 Self::UpdateDiff(diff) => &diff.id,
651 Self::UpdateTerminal(terminal) => &terminal.id,
652 }
653 }
654}
655
656impl From<acp::ToolCallUpdate> for ToolCallUpdate {
657 fn from(update: acp::ToolCallUpdate) -> Self {
658 Self::UpdateFields(update)
659 }
660}
661
662impl From<ToolCallUpdateDiff> for ToolCallUpdate {
663 fn from(diff: ToolCallUpdateDiff) -> Self {
664 Self::UpdateDiff(diff)
665 }
666}
667
668#[derive(Debug, PartialEq)]
669pub struct ToolCallUpdateDiff {
670 pub id: acp::ToolCallId,
671 pub diff: Entity<Diff>,
672}
673
674impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
675 fn from(terminal: ToolCallUpdateTerminal) -> Self {
676 Self::UpdateTerminal(terminal)
677 }
678}
679
680#[derive(Debug, PartialEq)]
681pub struct ToolCallUpdateTerminal {
682 pub id: acp::ToolCallId,
683 pub terminal: Entity<Terminal>,
684}
685
686#[derive(Debug, Default)]
687pub struct Plan {
688 pub entries: Vec<PlanEntry>,
689}
690
691#[derive(Debug)]
692pub struct PlanStats<'a> {
693 pub in_progress_entry: Option<&'a PlanEntry>,
694 pub pending: u32,
695 pub completed: u32,
696}
697
698impl Plan {
699 pub fn is_empty(&self) -> bool {
700 self.entries.is_empty()
701 }
702
703 pub fn stats(&self) -> PlanStats<'_> {
704 let mut stats = PlanStats {
705 in_progress_entry: None,
706 pending: 0,
707 completed: 0,
708 };
709
710 for entry in &self.entries {
711 match &entry.status {
712 acp::PlanEntryStatus::Pending => {
713 stats.pending += 1;
714 }
715 acp::PlanEntryStatus::InProgress => {
716 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
717 }
718 acp::PlanEntryStatus::Completed => {
719 stats.completed += 1;
720 }
721 }
722 }
723
724 stats
725 }
726}
727
728#[derive(Debug)]
729pub struct PlanEntry {
730 pub content: Entity<Markdown>,
731 pub priority: acp::PlanEntryPriority,
732 pub status: acp::PlanEntryStatus,
733}
734
735impl PlanEntry {
736 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
737 Self {
738 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
739 priority: entry.priority,
740 status: entry.status,
741 }
742 }
743}
744
/// Token accounting for a thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" when computing the ratio.
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
750
751impl TokenUsage {
752 pub fn ratio(&self) -> TokenUsageRatio {
753 #[cfg(debug_assertions)]
754 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
755 .unwrap_or("0.8".to_string())
756 .parse()
757 .unwrap();
758 #[cfg(not(debug_assertions))]
759 let warning_threshold: f32 = 0.8;
760
761 // When the maximum is unknown because there is no selected model,
762 // avoid showing the token limit warning.
763 if self.max_tokens == 0 {
764 TokenUsageRatio::Normal
765 } else if self.used_tokens >= self.max_tokens {
766 TokenUsageRatio::Exceeded
767 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
768 TokenUsageRatio::Warning
769 } else {
770 TokenUsageRatio::Normal
771 }
772 }
773}
774
/// Coarse classification of token usage returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown.
    Normal,
    /// At or above the warning threshold but below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
781
/// Retry information carried by `AcpThreadEvent::Retry`.
// NOTE(review): field meanings below are inferred from their names — confirm at the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the error message that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number.
    pub attempt: usize,
    /// Presumably the maximum number of attempts allowed.
    pub max_attempts: usize,
    /// Presumably when the retry started.
    pub started_at: Instant,
    /// Presumably the delay before (or duration of) the retry.
    pub duration: Duration,
}
790
/// State of one agent conversation bound to an ACP session: its entries, plan, terminals,
/// and bookkeeping for the in-flight request.
pub struct AcpThread {
    /// Display title; changed via `set_title`.
    title: SharedString,
    /// Transcript entries, in order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): snapshot per buffer; how it is populated is not visible here — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `status()` reports `Generating` while this is `Some` and `Idle` otherwise.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies each prompt-capability update into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to this thread, keyed by ACP terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminal ids that are not in `terminals` yet; replayed when the
    /// terminal's `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminal ids that are not in `terminals` yet.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
808
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// Token usage was replaced.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    // NOTE(review): the variants below are not emitted in this section; semantics inferred from names.
    ToolAuthorizationRequired,
    /// A retry is in progress.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The prompt capabilities were updated.
    PromptCapabilitiesUpdated,
    Refusal,
    /// The agent reported a new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The agent reported a new current session mode.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
828
/// Events from a terminal provider, consumed by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created; the thread registers it and replays any buffered output.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered if the terminal is not registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// The terminal's title changed; it is used as the breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; buffered if the terminal is not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}

/// Commands addressed to a terminal provider.
// NOTE(review): not referenced in this section; presumably sent from the UI to the provider — confirm.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    Close {
        terminal_id: acp::TerminalId,
    },
}
867
868impl AcpThread {
869 pub fn on_terminal_provider_event(
870 &mut self,
871 event: TerminalProviderEvent,
872 cx: &mut Context<Self>,
873 ) {
874 match event {
875 TerminalProviderEvent::Created {
876 terminal_id,
877 label,
878 cwd,
879 output_byte_limit,
880 terminal,
881 } => {
882 let entity = self.register_terminal_created(
883 terminal_id.clone(),
884 label,
885 cwd,
886 output_byte_limit,
887 terminal,
888 cx,
889 );
890
891 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
892 for data in chunks.drain(..) {
893 entity.update(cx, |term, cx| {
894 term.inner().update(cx, |inner, cx| {
895 inner.write_output(&data, cx);
896 })
897 });
898 }
899 }
900
901 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
902 entity.update(cx, |_term, cx| {
903 cx.notify();
904 });
905 }
906
907 cx.notify();
908 }
909 TerminalProviderEvent::Output { terminal_id, data } => {
910 if let Some(entity) = self.terminals.get(&terminal_id) {
911 entity.update(cx, |term, cx| {
912 term.inner().update(cx, |inner, cx| {
913 inner.write_output(&data, cx);
914 })
915 });
916 } else {
917 self.pending_terminal_output
918 .entry(terminal_id)
919 .or_default()
920 .push(data);
921 }
922 }
923 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
924 if let Some(entity) = self.terminals.get(&terminal_id) {
925 entity.update(cx, |term, cx| {
926 term.inner().update(cx, |inner, cx| {
927 inner.breadcrumb_text = title;
928 cx.emit(::terminal::Event::BreadcrumbsChanged);
929 })
930 });
931 }
932 }
933 TerminalProviderEvent::Exit {
934 terminal_id,
935 status,
936 } => {
937 if let Some(entity) = self.terminals.get(&terminal_id) {
938 entity.update(cx, |_term, cx| {
939 cx.notify();
940 });
941 } else {
942 self.pending_terminal_exit.insert(terminal_id, status);
943 }
944 }
945 }
946 }
947}
948
/// Whether the thread currently has a request in flight (see `AcpThread::status`).
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running.
    Generating,
}
954
955#[derive(Debug, Clone)]
956pub enum LoadError {
957 Unsupported {
958 command: SharedString,
959 current_version: SharedString,
960 minimum_version: SharedString,
961 },
962 FailedToInstall(SharedString),
963 Exited {
964 status: ExitStatus,
965 },
966 Other(SharedString),
967}
968
969impl Display for LoadError {
970 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
971 match self {
972 LoadError::Unsupported {
973 command: path,
974 current_version,
975 minimum_version,
976 } => {
977 write!(
978 f,
979 "version {current_version} from {path} is not supported (need at least {minimum_version})"
980 )
981 }
982 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
983 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
984 LoadError::Other(msg) => write!(f, "{msg}"),
985 }
986 }
987}
988
989impl Error for LoadError {}
990
991impl AcpThread {
992 pub fn new(
993 title: impl Into<SharedString>,
994 connection: Rc<dyn AgentConnection>,
995 project: Entity<Project>,
996 action_log: Entity<ActionLog>,
997 session_id: acp::SessionId,
998 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
999 cx: &mut Context<Self>,
1000 ) -> Self {
1001 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1002 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1003 loop {
1004 let caps = prompt_capabilities_rx.recv().await?;
1005 this.update(cx, |this, cx| {
1006 this.prompt_capabilities = caps;
1007 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1008 })?;
1009 }
1010 });
1011
1012 Self {
1013 action_log,
1014 shared_buffers: Default::default(),
1015 entries: Default::default(),
1016 plan: Default::default(),
1017 title: title.into(),
1018 project,
1019 send_task: None,
1020 connection,
1021 session_id,
1022 token_usage: None,
1023 prompt_capabilities,
1024 _observe_prompt_capabilities: task,
1025 terminals: HashMap::default(),
1026 pending_terminal_output: HashMap::default(),
1027 pending_terminal_exit: HashMap::default(),
1028 }
1029 }
1030
1031 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
1032 self.prompt_capabilities.clone()
1033 }
1034
1035 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
1036 &self.connection
1037 }
1038
1039 pub fn action_log(&self) -> &Entity<ActionLog> {
1040 &self.action_log
1041 }
1042
1043 pub fn project(&self) -> &Entity<Project> {
1044 &self.project
1045 }
1046
1047 pub fn title(&self) -> SharedString {
1048 self.title.clone()
1049 }
1050
1051 pub fn entries(&self) -> &[AgentThreadEntry] {
1052 &self.entries
1053 }
1054
1055 pub fn session_id(&self) -> &acp::SessionId {
1056 &self.session_id
1057 }
1058
1059 pub fn status(&self) -> ThreadStatus {
1060 if self.send_task.is_some() {
1061 ThreadStatus::Generating
1062 } else {
1063 ThreadStatus::Idle
1064 }
1065 }
1066
1067 pub fn token_usage(&self) -> Option<&TokenUsage> {
1068 self.token_usage.as_ref()
1069 }
1070
1071 pub fn has_pending_edit_tool_calls(&self) -> bool {
1072 for entry in self.entries.iter().rev() {
1073 match entry {
1074 AgentThreadEntry::UserMessage(_) => return false,
1075 AgentThreadEntry::ToolCall(
1076 call @ ToolCall {
1077 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1078 ..
1079 },
1080 ) if call.diffs().next().is_some() => {
1081 return true;
1082 }
1083 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1084 }
1085 }
1086
1087 false
1088 }
1089
1090 pub fn used_tools_since_last_user_message(&self) -> bool {
1091 for entry in self.entries.iter().rev() {
1092 match entry {
1093 AgentThreadEntry::UserMessage(..) => return false,
1094 AgentThreadEntry::AssistantMessage(..) => continue,
1095 AgentThreadEntry::ToolCall(..) => return true,
1096 }
1097 }
1098
1099 false
1100 }
1101
1102 pub fn handle_session_update(
1103 &mut self,
1104 update: acp::SessionUpdate,
1105 cx: &mut Context<Self>,
1106 ) -> Result<(), acp::Error> {
1107 match update {
1108 acp::SessionUpdate::UserMessageChunk { content } => {
1109 self.push_user_content_block(None, content, cx);
1110 }
1111 acp::SessionUpdate::AgentMessageChunk { content } => {
1112 self.push_assistant_content_block(content, false, cx);
1113 }
1114 acp::SessionUpdate::AgentThoughtChunk { content } => {
1115 self.push_assistant_content_block(content, true, cx);
1116 }
1117 acp::SessionUpdate::ToolCall(tool_call) => {
1118 self.upsert_tool_call(tool_call, cx)?;
1119 }
1120 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1121 self.update_tool_call(tool_call_update, cx)?;
1122 }
1123 acp::SessionUpdate::Plan(plan) => {
1124 self.update_plan(plan, cx);
1125 }
1126 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
1127 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
1128 }
1129 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
1130 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
1131 }
1132 }
1133 Ok(())
1134 }
1135
1136 pub fn push_user_content_block(
1137 &mut self,
1138 message_id: Option<UserMessageId>,
1139 chunk: acp::ContentBlock,
1140 cx: &mut Context<Self>,
1141 ) {
1142 let language_registry = self.project.read(cx).languages().clone();
1143 let entries_len = self.entries.len();
1144
1145 if let Some(last_entry) = self.entries.last_mut()
1146 && let AgentThreadEntry::UserMessage(UserMessage {
1147 id,
1148 content,
1149 chunks,
1150 ..
1151 }) = last_entry
1152 {
1153 *id = message_id.or(id.take());
1154 content.append(chunk.clone(), &language_registry, cx);
1155 chunks.push(chunk);
1156 let idx = entries_len - 1;
1157 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1158 } else {
1159 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1160 self.push_entry(
1161 AgentThreadEntry::UserMessage(UserMessage {
1162 id: message_id,
1163 content,
1164 chunks: vec![chunk],
1165 checkpoint: None,
1166 }),
1167 cx,
1168 );
1169 }
1170 }
1171
1172 pub fn push_assistant_content_block(
1173 &mut self,
1174 chunk: acp::ContentBlock,
1175 is_thought: bool,
1176 cx: &mut Context<Self>,
1177 ) {
1178 let language_registry = self.project.read(cx).languages().clone();
1179 let entries_len = self.entries.len();
1180 if let Some(last_entry) = self.entries.last_mut()
1181 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1182 {
1183 let idx = entries_len - 1;
1184 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1185 match (chunks.last_mut(), is_thought) {
1186 (Some(AssistantMessageChunk::Message { block }), false)
1187 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1188 block.append(chunk, &language_registry, cx)
1189 }
1190 _ => {
1191 let block = ContentBlock::new(chunk, &language_registry, cx);
1192 if is_thought {
1193 chunks.push(AssistantMessageChunk::Thought { block })
1194 } else {
1195 chunks.push(AssistantMessageChunk::Message { block })
1196 }
1197 }
1198 }
1199 } else {
1200 let block = ContentBlock::new(chunk, &language_registry, cx);
1201 let chunk = if is_thought {
1202 AssistantMessageChunk::Thought { block }
1203 } else {
1204 AssistantMessageChunk::Message { block }
1205 };
1206
1207 self.push_entry(
1208 AgentThreadEntry::AssistantMessage(AssistantMessage {
1209 chunks: vec![chunk],
1210 }),
1211 cx,
1212 );
1213 }
1214 }
1215
1216 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1217 self.entries.push(entry);
1218 cx.emit(AcpThreadEvent::NewEntry);
1219 }
1220
1221 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1222 self.connection.set_title(&self.session_id, cx).is_some()
1223 }
1224
1225 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1226 if title != self.title {
1227 self.title = title.clone();
1228 cx.emit(AcpThreadEvent::TitleUpdated);
1229 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1230 return set_title.run(title, cx);
1231 }
1232 }
1233 Task::ready(Ok(()))
1234 }
1235
1236 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1237 self.token_usage = usage;
1238 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1239 }
1240
1241 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1242 cx.emit(AcpThreadEvent::Retry(status));
1243 }
1244
    /// Applies `update` to the most recent tool call entry with the matching id and
    /// emits [`AcpThreadEvent::EntryUpdated`] for that entry.
    ///
    /// If no tool call with that id exists, a placeholder tool call with status
    /// [`ToolCallStatus::Failed`] and the text "Tool call not found" is pushed
    /// instead and `Ok(())` is returned.
    ///
    /// # Errors
    /// Propagates any error from `ToolCall::update_fields`.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();

        let ix = match self.index_for_tool_call(update.id()) {
            Some(ix) => ix,
            None => {
                // Tool call not found - create a failed tool call entry
                // NOTE(review): the placeholder is tagged `ToolKind::Fetch`, which looks
                // arbitrary for a synthetic error entry — confirm this is intended.
                let failed_tool_call = ToolCall {
                    id: update.id().clone(),
                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
                    kind: acp::ToolKind::Fetch,
                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
                        acp::ContentBlock::Text(acp::TextContent {
                            text: "Tool call not found".to_string(),
                            annotations: None,
                            meta: None,
                        }),
                        &languages,
                        cx,
                    ))],
                    status: ToolCallStatus::Failed,
                    locations: Vec::new(),
                    resolved_locations: Vec::new(),
                    raw_input: None,
                    raw_output: None,
                };
                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
                return Ok(());
            }
        };
        // `index_for_tool_call` only returns indices of `ToolCall` entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Only re-resolve locations when the update actually carries them.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.id, cx);
                }
            }
            // Diff and terminal updates replace the tool call's content wholesale.
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1307
1308 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1309 pub fn upsert_tool_call(
1310 &mut self,
1311 tool_call: acp::ToolCall,
1312 cx: &mut Context<Self>,
1313 ) -> Result<(), acp::Error> {
1314 let status = tool_call.status.into();
1315 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1316 }
1317
1318 /// Fails if id does not match an existing entry.
1319 pub fn upsert_tool_call_inner(
1320 &mut self,
1321 update: acp::ToolCallUpdate,
1322 status: ToolCallStatus,
1323 cx: &mut Context<Self>,
1324 ) -> Result<(), acp::Error> {
1325 let language_registry = self.project.read(cx).languages().clone();
1326 let id = update.id.clone();
1327
1328 if let Some(ix) = self.index_for_tool_call(&id) {
1329 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1330 unreachable!()
1331 };
1332
1333 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1334 call.status = status;
1335
1336 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1337 } else {
1338 let call = ToolCall::from_acp(
1339 update.try_into()?,
1340 status,
1341 language_registry,
1342 &self.terminals,
1343 cx,
1344 )?;
1345 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1346 };
1347
1348 self.resolve_locations(id, cx);
1349 Ok(())
1350 }
1351
1352 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1353 self.entries
1354 .iter()
1355 .enumerate()
1356 .rev()
1357 .find_map(|(index, entry)| {
1358 if let AgentThreadEntry::ToolCall(tool_call) = entry
1359 && &tool_call.id == id
1360 {
1361 Some(index)
1362 } else {
1363 None
1364 }
1365 })
1366 }
1367
1368 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1369 // The tool call we are looking for is typically the last one, or very close to the end.
1370 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1371 self.entries
1372 .iter_mut()
1373 .enumerate()
1374 .rev()
1375 .find_map(|(index, tool_call)| {
1376 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1377 && &tool_call.id == id
1378 {
1379 Some((index, tool_call))
1380 } else {
1381 None
1382 }
1383 })
1384 }
1385
1386 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1387 self.entries
1388 .iter()
1389 .enumerate()
1390 .rev()
1391 .find_map(|(index, tool_call)| {
1392 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1393 && &tool_call.id == id
1394 {
1395 Some((index, tool_call))
1396 } else {
1397 None
1398 }
1399 })
1400 }
1401
    /// Resolves the locations of tool call `id` in the background.
    ///
    /// When resolution finishes:
    /// - the current snapshot of every resolved buffer is cached in `shared_buffers`;
    /// - the project's agent location moves to the last resolved location, unless that
    ///   would only move it to an earlier column on the same row of the same buffer;
    /// - the tool call's `resolved_locations` are replaced, emitting
    ///   [`AcpThreadEvent::EntryUpdated`] only when they actually changed.
    ///
    /// Does nothing if no tool call with `id` exists, either now or when resolution
    /// completes.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;

            this.update(cx, |this, cx| {
                let project = this.project.clone();

                for location in resolved_locations.iter().flatten() {
                    this.shared_buffers
                        .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
                }
                // The tool call is looked up again because it may have been removed
                // while resolution was in flight.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };

                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        let should_ignore = if let Some(agent_location) = project
                            .agent_location()
                            .filter(|agent_location| agent_location.buffer == location.buffer)
                        {
                            let snapshot = location.buffer.read(cx).snapshot();
                            let old_position = agent_location.position.to_point(&snapshot);
                            let new_position = location.position.to_point(&snapshot);

                            // ignore this so that when we get updates from the edit tool
                            // the position doesn't reset to the start of line
                            old_position.row == new_position.row
                                && old_position.column > new_position.column
                        } else {
                            false
                        };
                        if !should_ignore {
                            project.set_agent_location(Some(location.into()), cx);
                        }
                    });
                }

                let resolved_locations = resolved_locations
                    .iter()
                    .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
                    .collect::<Vec<_>>();

                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1458
1459 pub fn request_tool_call_authorization(
1460 &mut self,
1461 tool_call: acp::ToolCallUpdate,
1462 options: Vec<acp::PermissionOption>,
1463 respect_always_allow_setting: bool,
1464 cx: &mut Context<Self>,
1465 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1466 let (tx, rx) = oneshot::channel();
1467
1468 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1469 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1470 // some tools would (incorrectly) continue to auto-accept.
1471 if let Some(allow_once_option) = options.iter().find_map(|option| {
1472 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1473 Some(option.id.clone())
1474 } else {
1475 None
1476 }
1477 }) {
1478 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1479 return Ok(async {
1480 acp::RequestPermissionOutcome::Selected {
1481 option_id: allow_once_option,
1482 }
1483 }
1484 .boxed());
1485 }
1486 }
1487
1488 let status = ToolCallStatus::WaitingForConfirmation {
1489 options,
1490 respond_tx: tx,
1491 };
1492
1493 self.upsert_tool_call_inner(tool_call, status, cx)?;
1494 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1495
1496 let fut = async {
1497 match rx.await {
1498 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1499 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1500 }
1501 }
1502 .boxed();
1503
1504 Ok(fut)
1505 }
1506
1507 pub fn authorize_tool_call(
1508 &mut self,
1509 id: acp::ToolCallId,
1510 option_id: acp::PermissionOptionId,
1511 option_kind: acp::PermissionOptionKind,
1512 cx: &mut Context<Self>,
1513 ) {
1514 let Some((ix, call)) = self.tool_call_mut(&id) else {
1515 return;
1516 };
1517
1518 let new_status = match option_kind {
1519 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1520 ToolCallStatus::Rejected
1521 }
1522 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1523 ToolCallStatus::InProgress
1524 }
1525 };
1526
1527 let curr_status = mem::replace(&mut call.status, new_status);
1528
1529 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1530 respond_tx.send(option_id).log_err();
1531 } else if cfg!(debug_assertions) {
1532 panic!("tried to authorize an already authorized tool call");
1533 }
1534
1535 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1536 }
1537
1538 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1539 let mut first_tool_call = None;
1540
1541 for entry in self.entries.iter().rev() {
1542 match &entry {
1543 AgentThreadEntry::ToolCall(call) => {
1544 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1545 first_tool_call = Some(call);
1546 } else {
1547 continue;
1548 }
1549 }
1550 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1551 // Reached the beginning of the turn.
1552 // If we had pending permission requests in the previous turn, they have been cancelled.
1553 break;
1554 }
1555 }
1556 }
1557
1558 first_tool_call
1559 }
1560
1561 pub fn plan(&self) -> &Plan {
1562 &self.plan
1563 }
1564
    /// Replaces the plan with the entries of `request`.
    ///
    /// Existing entries are updated in place, position by position, so their
    /// `Markdown` entities are reused. Extra incoming entries are appended, and
    /// surplus old entries are truncated away. Requests a re-render via `cx.notify()`.
    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
        let new_entries_len = request.entries.len();
        let mut new_entries = request.entries.into_iter();

        // Reuse existing markdown to prevent flickering
        // `zip` polls the old entries first, so when they run out no incoming entry is
        // consumed and lost; `by_ref` leaves the remainder for the loop below.
        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
            let PlanEntry {
                content,
                priority,
                status,
            } = old;
            content.update(cx, |old, cx| {
                old.replace(new.content, cx);
            });
            *priority = new.priority;
            *status = new.status;
        }
        // Incoming entries beyond the old length become new plan entries.
        for new in new_entries {
            self.plan.entries.push(PlanEntry::from_acp(new, cx))
        }
        // Drops leftover old entries when the new plan is shorter.
        self.plan.entries.truncate(new_entries_len);

        cx.notify();
    }
1589
1590 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1591 self.plan
1592 .entries
1593 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1594 cx.notify();
1595 }
1596
1597 #[cfg(any(test, feature = "test-support"))]
1598 pub fn send_raw(
1599 &mut self,
1600 message: &str,
1601 cx: &mut Context<Self>,
1602 ) -> BoxFuture<'static, Result<()>> {
1603 self.send(
1604 vec![acp::ContentBlock::Text(acp::TextContent {
1605 text: message.to_string(),
1606 annotations: None,
1607 meta: None,
1608 })],
1609 cx,
1610 )
1611 }
1612
    /// Sends `message` as a new user prompt and runs it as the current turn.
    ///
    /// Inside the turn:
    /// - the user message entry is pushed;
    /// - a git checkpoint is taken and, if that succeeds, attached (hidden) to the last
    ///   user message;
    /// - the prompt is forwarded to the connection.
    ///
    /// A [`UserMessageId`] is assigned only when the connection supports truncation.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
            meta: None,
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are what `rewind` passes to the connection's truncate handler,
        // so they are only generated when that handler exists.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A failed checkpoint is logged and simply leaves the message without one.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1667
1668 pub fn can_resume(&self, cx: &App) -> bool {
1669 self.connection.resume(&self.session_id, cx).is_some()
1670 }
1671
1672 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1673 self.run_turn(cx, async move |this, cx| {
1674 this.update(cx, |this, cx| {
1675 this.connection
1676 .resume(&this.session_id, cx)
1677 .map(|resume| resume.run(cx))
1678 })?
1679 .context("resuming a session is not supported")?
1680 .await
1681 })
1682 }
1683
1684 fn run_turn(
1685 &mut self,
1686 cx: &mut Context<Self>,
1687 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1688 ) -> BoxFuture<'static, Result<()>> {
1689 self.clear_completed_plan_entries(cx);
1690
1691 let (tx, rx) = oneshot::channel();
1692 let cancel_task = self.cancel(cx);
1693
1694 self.send_task = Some(cx.spawn(async move |this, cx| {
1695 cancel_task.await;
1696 tx.send(f(this, cx).await).ok();
1697 }));
1698
1699 cx.spawn(async move |this, cx| {
1700 let response = rx.await;
1701
1702 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1703 .await?;
1704
1705 this.update(cx, |this, cx| {
1706 this.project
1707 .update(cx, |project, cx| project.set_agent_location(None, cx));
1708 match response {
1709 Ok(Err(e)) => {
1710 this.send_task.take();
1711 cx.emit(AcpThreadEvent::Error);
1712 Err(e)
1713 }
1714 result => {
1715 let canceled = matches!(
1716 result,
1717 Ok(Ok(acp::PromptResponse {
1718 stop_reason: acp::StopReason::Cancelled,
1719 meta: None,
1720 }))
1721 );
1722
1723 // We only take the task if the current prompt wasn't canceled.
1724 //
1725 // This prompt may have been canceled because another one was sent
1726 // while it was still generating. In these cases, dropping `send_task`
1727 // would cause the next generation to be canceled.
1728 if !canceled {
1729 this.send_task.take();
1730 }
1731
1732 // Handle refusal - distinguish between user prompt and tool call refusals
1733 if let Ok(Ok(acp::PromptResponse {
1734 stop_reason: acp::StopReason::Refusal,
1735 meta: _,
1736 })) = result
1737 {
1738 if let Some((user_msg_ix, _)) = this.last_user_message() {
1739 // Check if there's a completed tool call with results after the last user message
1740 // This indicates the refusal is in response to tool output, not the user's prompt
1741 let has_completed_tool_call_after_user_msg =
1742 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1743 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1744 // Check if the tool call has completed and has output
1745 matches!(tool_call.status, ToolCallStatus::Completed)
1746 && tool_call.raw_output.is_some()
1747 } else {
1748 false
1749 }
1750 });
1751
1752 if has_completed_tool_call_after_user_msg {
1753 // Refusal is due to tool output - don't truncate, just notify
1754 // The model refused based on what the tool returned
1755 cx.emit(AcpThreadEvent::Refusal);
1756 } else {
1757 // User prompt was refused - truncate back to before the user message
1758 let range = user_msg_ix..this.entries.len();
1759 if range.start < range.end {
1760 this.entries.truncate(user_msg_ix);
1761 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1762 }
1763 cx.emit(AcpThreadEvent::Refusal);
1764 }
1765 } else {
1766 // No user message found, treat as general refusal
1767 cx.emit(AcpThreadEvent::Refusal);
1768 }
1769 }
1770
1771 cx.emit(AcpThreadEvent::Stopped);
1772 Ok(())
1773 }
1774 }
1775 })?
1776 })
1777 .boxed()
1778 }
1779
1780 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1781 let Some(send_task) = self.send_task.take() else {
1782 return Task::ready(());
1783 };
1784
1785 for entry in self.entries.iter_mut() {
1786 if let AgentThreadEntry::ToolCall(call) = entry {
1787 let cancel = matches!(
1788 call.status,
1789 ToolCallStatus::Pending
1790 | ToolCallStatus::WaitingForConfirmation { .. }
1791 | ToolCallStatus::InProgress
1792 );
1793
1794 if cancel {
1795 call.status = ToolCallStatus::Canceled;
1796 }
1797 }
1798 }
1799
1800 self.connection.cancel(&self.session_id, cx);
1801
1802 // Wait for the send task to complete
1803 cx.foreground_executor().spawn(send_task)
1804 }
1805
1806 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1807 pub fn restore_checkpoint(
1808 &mut self,
1809 id: UserMessageId,
1810 cx: &mut Context<Self>,
1811 ) -> Task<Result<()>> {
1812 let Some((_, message)) = self.user_message_mut(&id) else {
1813 return Task::ready(Err(anyhow!("message not found")));
1814 };
1815
1816 let checkpoint = message
1817 .checkpoint
1818 .as_ref()
1819 .map(|c| c.git_checkpoint.clone());
1820 let rewind = self.rewind(id.clone(), cx);
1821 let git_store = self.project.read(cx).git_store().clone();
1822
1823 cx.spawn(async move |_, cx| {
1824 rewind.await?;
1825 if let Some(checkpoint) = checkpoint {
1826 git_store
1827 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1828 .await?;
1829 }
1830
1831 Ok(())
1832 })
1833 }
1834
1835 /// Rewinds this thread to before the entry at `index`, removing it and all
1836 /// subsequent entries while rejecting any action_log changes made from that point.
1837 /// Unlike `restore_checkpoint`, this method does not restore from git.
1838 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1839 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1840 return Task::ready(Err(anyhow!("not supported")));
1841 };
1842
1843 cx.spawn(async move |this, cx| {
1844 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1845 this.update(cx, |this, cx| {
1846 if let Some((ix, _)) = this.user_message_mut(&id) {
1847 let range = ix..this.entries.len();
1848 this.entries.truncate(ix);
1849 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1850 }
1851 this.action_log()
1852 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1853 })?
1854 .await;
1855 Ok(())
1856 })
1857 }
1858
    /// Decides whether the last user message's checkpoint should be shown.
    ///
    /// Takes a fresh git checkpoint and compares it with the one stored on the last
    /// user message. `show` is set when they differ, presumably meaning the working
    /// tree changed since the message was sent. Behavior in special cases:
    /// - no user message, or no checkpoint on it: returns immediately;
    /// - taking the new checkpoint fails: the error is logged and nothing changes;
    /// - the comparison fails: the checkpoints are treated as equal (hidden).
    ///
    /// NOTE(review): if the last user message lost its checkpoint while this ran, the
    /// task fails with "no checkpoint". `run_turn` propagates that error before it
    /// emits `Stopped`.
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
            if let Some(checkpoint) = message.checkpoint.as_ref() {
                checkpoint.git_checkpoint.clone()
            } else {
                return Task::ready(Ok(()));
            }
        } else {
            return Task::ready(Ok(()));
        };

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            let new_checkpoint = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err();
            if let Some(new_checkpoint) = new_checkpoint {
                let equal = git_store
                    .update(cx, |git, cx| {
                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                    })?
                    .await
                    .unwrap_or(true);
                this.update(cx, |this, cx| {
                    let (ix, message) = this.last_user_message().context("no user message")?;
                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    anyhow::Ok(())
                })??;
            }

            Ok(())
        })
    }
1897
1898 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1899 self.entries
1900 .iter_mut()
1901 .enumerate()
1902 .rev()
1903 .find_map(|(ix, entry)| {
1904 if let AgentThreadEntry::UserMessage(message) = entry {
1905 Some((ix, message))
1906 } else {
1907 None
1908 }
1909 })
1910 }
1911
1912 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1913 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1914 if let AgentThreadEntry::UserMessage(message) = entry {
1915 if message.id.as_ref() == Some(id) {
1916 Some((ix, message))
1917 } else {
1918 None
1919 }
1920 } else {
1921 None
1922 }
1923 })
1924 }
1925
    /// Reads up to `limit` lines of `path`, starting at the 1-based `line`.
    ///
    /// `line` defaults to the first line and `limit` to the rest of the file. Which
    /// snapshot is read:
    /// - if `reuse_shared_snapshot` is set and a snapshot of the buffer is cached in
    ///   `shared_buffers`, that cached snapshot;
    /// - otherwise the live buffer's snapshot. In this case the read is also recorded
    ///   with the action log and the snapshot is cached.
    ///
    /// The project's agent location is moved to the start of the returned range.
    ///
    /// # Errors
    /// - `resource_not_found` if `path` does not map to a project path;
    /// - `invalid_params` if `line` lies past the end of the file;
    /// - errors from opening the buffer or from updating entities.
    pub fn read_text_file(
        &self,
        path: PathBuf,
        line: Option<u32>,
        limit: Option<u32>,
        reuse_shared_snapshot: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<String, acp::Error>> {
        // Args are 1-based, move to 0-based
        let line = line.unwrap_or_default().saturating_sub(1);
        let limit = limit.unwrap_or(u32::MAX);
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project
                .update(cx, |project, cx| {
                    let path = project
                        .project_path_for_absolute_path(&path, cx)
                        .ok_or_else(|| {
                            acp::Error::resource_not_found(Some(path.display().to_string()))
                        })?;
                    Ok(project.open_buffer(path, cx))
                })
                .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
                .flatten()?;

            let buffer = load.await?;

            let snapshot = if reuse_shared_snapshot {
                this.read_with(cx, |this, _| {
                    this.shared_buffers.get(&buffer.clone()).cloned()
                })
                .log_err()
                .flatten()
            } else {
                None
            };

            let snapshot = if let Some(snapshot) = snapshot {
                snapshot
            } else {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                })?;

                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
                this.update(cx, |this, _| {
                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
                })?;
                snapshot
            };

            let max_point = snapshot.max_point();
            let start_position = Point::new(line, 0);

            if start_position > max_point {
                return Err(acp::Error::invalid_params().with_data(format!(
                    "Attempting to read beyond the end of the file, line {}:{}",
                    max_point.row + 1,
                    max_point.column
                )));
            }

            // The end row is exclusive; `saturating_add` keeps a huge `limit` from
            // overflowing.
            let start = snapshot.anchor_before(start_position);
            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: start,
                    }),
                    cx,
                );
            })?;

            Ok(snapshot.text_for_range(start..end).collect::<String>())
        })
    }
2005
    /// Replaces the contents of `path` with `content` and saves the buffer.
    ///
    /// Steps:
    /// 1. The change is computed as a text diff (on the background executor) against
    ///    the buffer snapshot cached in `shared_buffers`, falling back to the live
    ///    buffer's snapshot.
    /// 2. The agent location is moved to the end of the last edit.
    /// 3. The edit is applied and reported to the action log.
    /// 4. If format-on-save is enabled for the buffer's language, the buffer is
    ///    formatted. Format errors are logged and ignored.
    /// 5. The buffer is saved.
    ///
    /// # Errors
    /// Fails with "invalid path" if `path` does not map to a project path, or with
    /// errors from opening, updating or saving the buffer.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Presumably records a baseline with the action log before editing, so
                // the edit below is tracked as the agent's change — confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Formatting changed the buffer again, so report it to the action log.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
2102
    /// Spawns `command` with `args` in a new terminal and registers it in `terminals`
    /// under a freshly generated UUID id.
    ///
    /// Environment:
    /// - the project's directory environment for `cwd` when `cwd` is given (empty
    ///   otherwise);
    /// - `PAGER` set to the empty string;
    /// - `extra_env` applied last, so it may override `PAGER`.
    ///
    /// The command is wrapped by [`ShellBuilder`] using the remote client's default
    /// shell when available, falling back to `get_default_system_shell_preferring_bash`.
    /// `output_byte_limit` is forwarded to `Terminal::new`.
    ///
    /// # Errors
    /// Propagates failures from creating the terminal task or updating entities.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                let worktree = project.find_worktree(dir.as_path(), cx);
                let shell = TerminalSettings::get(
                    worktree.as_ref().map(|(worktree, path)| SettingsLocation {
                        worktree_id: worktree.read(cx).id(),
                        path: &path,
                    }),
                    cx,
                )
                .shell
                .clone();
                project.directory_environment(&shell, dir.as_path().into(), cx)
            }),
            None => Task::ready(None).shared(),
        };
        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // Disables paging for `git` and hopefully other commands
            env.insert("PAGER".into(), "".into());
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();
        let is_windows = project.read(cx).path_style(cx).is_windows();

        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
        let terminal_task = cx.spawn({
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                let shell = project
                    .update(cx, |project, cx| {
                        project
                            .remote_client()
                            .and_then(|r| r.read(cx).default_system_shell())
                    })?
                    .unwrap_or_else(|| get_default_system_shell_preferring_bash());
                let (task_command, task_args) =
                    ShellBuilder::new(&Shell::Program(shell), is_windows)
                        .redirect_stdin_to_dev_null()
                        .build(Some(command.clone()), &args);
                let terminal = project
                    .update(cx, |project, cx| {
                        project.create_terminal_task(
                            task::SpawnInTerminal {
                                command: Some(task_command),
                                args: task_args,
                                cwd: cwd.clone(),
                                env,
                                ..Default::default()
                            },
                            cx,
                        )
                    })?
                    .await?;

                cx.new(|cx| {
                    Terminal::new(
                        terminal_id,
                        &format!("{} {}", command, args.join(" ")),
                        cwd,
                        output_byte_limit.map(|l| l as usize),
                        terminal,
                        language_registry,
                        cx,
                    )
                })
            }
        });

        // Registration happens only after the terminal entity exists.
        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2195
2196 pub fn kill_terminal(
2197 &mut self,
2198 terminal_id: acp::TerminalId,
2199 cx: &mut Context<Self>,
2200 ) -> Result<()> {
2201 self.terminals
2202 .get(&terminal_id)
2203 .context("Terminal not found")?
2204 .update(cx, |terminal, cx| {
2205 terminal.kill(cx);
2206 });
2207
2208 Ok(())
2209 }
2210
2211 pub fn release_terminal(
2212 &mut self,
2213 terminal_id: acp::TerminalId,
2214 cx: &mut Context<Self>,
2215 ) -> Result<()> {
2216 self.terminals
2217 .remove(&terminal_id)
2218 .context("Terminal not found")?
2219 .update(cx, |terminal, cx| {
2220 terminal.kill(cx);
2221 });
2222
2223 Ok(())
2224 }
2225
2226 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2227 self.terminals
2228 .get(&terminal_id)
2229 .context("Terminal not found")
2230 .cloned()
2231 }
2232
2233 pub fn to_markdown(&self, cx: &App) -> String {
2234 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2235 }
2236
2237 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2238 cx.emit(AcpThreadEvent::LoadError(error));
2239 }
2240
2241 pub fn register_terminal_created(
2242 &mut self,
2243 terminal_id: acp::TerminalId,
2244 command_label: String,
2245 working_dir: Option<PathBuf>,
2246 output_byte_limit: Option<u64>,
2247 terminal: Entity<::terminal::Terminal>,
2248 cx: &mut Context<Self>,
2249 ) -> Entity<Terminal> {
2250 let language_registry = self.project.read(cx).languages().clone();
2251
2252 let entity = cx.new(|cx| {
2253 Terminal::new(
2254 terminal_id.clone(),
2255 &command_label,
2256 working_dir.clone(),
2257 output_byte_limit.map(|l| l as usize),
2258 terminal,
2259 language_registry,
2260 cx,
2261 )
2262 });
2263 self.terminals.insert(terminal_id.clone(), entity.clone());
2264 entity
2265 }
2266}
2267
2268fn markdown_for_raw_output(
2269 raw_output: &serde_json::Value,
2270 language_registry: &Arc<LanguageRegistry>,
2271 cx: &mut App,
2272) -> Option<Entity<Markdown>> {
2273 match raw_output {
2274 serde_json::Value::Null => None,
2275 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2276 Markdown::new(
2277 value.to_string().into(),
2278 Some(language_registry.clone()),
2279 None,
2280 cx,
2281 )
2282 })),
2283 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2284 Markdown::new(
2285 value.to_string().into(),
2286 Some(language_registry.clone()),
2287 None,
2288 cx,
2289 )
2290 })),
2291 serde_json::Value::String(value) => Some(cx.new(|cx| {
2292 Markdown::new(
2293 value.clone().into(),
2294 Some(language_registry.clone()),
2295 None,
2296 cx,
2297 )
2298 })),
2299 value => Some(cx.new(|cx| {
2300 Markdown::new(
2301 format!("```json\n{}\n```", value).into(),
2302 Some(language_registry.clone()),
2303 None,
2304 cx,
2305 )
2306 })),
2307 }
2308}
2309
2310#[cfg(test)]
2311mod tests {
2312 use super::*;
2313 use anyhow::anyhow;
2314 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2315 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2316 use indoc::indoc;
2317 use project::{FakeFs, Fs};
2318 use rand::{distr, prelude::*};
2319 use serde_json::json;
2320 use settings::SettingsStore;
2321 use smol::stream::StreamExt as _;
2322 use std::{
2323 any::Any,
2324 cell::RefCell,
2325 path::Path,
2326 rc::Rc,
2327 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2328 time::Duration,
2329 };
2330 use util::path;
2331
2332 fn init_test(cx: &mut TestAppContext) {
2333 env_logger::try_init().ok();
2334 cx.update(|cx| {
2335 let settings_store = SettingsStore::test(cx);
2336 cx.set_global(settings_store);
2337 Project::init_settings(cx);
2338 language::init(cx);
2339 });
2340 }
2341
2342 #[gpui::test]
2343 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2344 init_test(cx);
2345
2346 let fs = FakeFs::new(cx.executor());
2347 let project = Project::test(fs, [], cx).await;
2348 let connection = Rc::new(FakeAgentConnection::new());
2349 let thread = cx
2350 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2351 .await
2352 .unwrap();
2353
2354 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2355
2356 // Send Output BEFORE Created - should be buffered by acp_thread
2357 thread.update(cx, |thread, cx| {
2358 thread.on_terminal_provider_event(
2359 TerminalProviderEvent::Output {
2360 terminal_id: terminal_id.clone(),
2361 data: b"hello buffered".to_vec(),
2362 },
2363 cx,
2364 );
2365 });
2366
2367 // Create a display-only terminal and then send Created
2368 let lower = cx.new(|cx| {
2369 let builder = ::terminal::TerminalBuilder::new_display_only(
2370 ::terminal::terminal_settings::CursorShape::default(),
2371 ::terminal::terminal_settings::AlternateScroll::On,
2372 None,
2373 0,
2374 )
2375 .unwrap();
2376 builder.subscribe(cx)
2377 });
2378
2379 thread.update(cx, |thread, cx| {
2380 thread.on_terminal_provider_event(
2381 TerminalProviderEvent::Created {
2382 terminal_id: terminal_id.clone(),
2383 label: "Buffered Test".to_string(),
2384 cwd: None,
2385 output_byte_limit: None,
2386 terminal: lower.clone(),
2387 },
2388 cx,
2389 );
2390 });
2391
2392 // After Created, buffered Output should have been flushed into the renderer
2393 let content = thread.read_with(cx, |thread, cx| {
2394 let term = thread.terminal(terminal_id.clone()).unwrap();
2395 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2396 });
2397
2398 assert!(
2399 content.contains("hello buffered"),
2400 "expected buffered output to render, got: {content}"
2401 );
2402 }
2403
2404 #[gpui::test]
2405 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2406 init_test(cx);
2407
2408 let fs = FakeFs::new(cx.executor());
2409 let project = Project::test(fs, [], cx).await;
2410 let connection = Rc::new(FakeAgentConnection::new());
2411 let thread = cx
2412 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2413 .await
2414 .unwrap();
2415
2416 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2417
2418 // Send Output BEFORE Created
2419 thread.update(cx, |thread, cx| {
2420 thread.on_terminal_provider_event(
2421 TerminalProviderEvent::Output {
2422 terminal_id: terminal_id.clone(),
2423 data: b"pre-exit data".to_vec(),
2424 },
2425 cx,
2426 );
2427 });
2428
2429 // Send Exit BEFORE Created
2430 thread.update(cx, |thread, cx| {
2431 thread.on_terminal_provider_event(
2432 TerminalProviderEvent::Exit {
2433 terminal_id: terminal_id.clone(),
2434 status: acp::TerminalExitStatus {
2435 exit_code: Some(0),
2436 signal: None,
2437 meta: None,
2438 },
2439 },
2440 cx,
2441 );
2442 });
2443
2444 // Now create a display-only lower-level terminal and send Created
2445 let lower = cx.new(|cx| {
2446 let builder = ::terminal::TerminalBuilder::new_display_only(
2447 ::terminal::terminal_settings::CursorShape::default(),
2448 ::terminal::terminal_settings::AlternateScroll::On,
2449 None,
2450 0,
2451 )
2452 .unwrap();
2453 builder.subscribe(cx)
2454 });
2455
2456 thread.update(cx, |thread, cx| {
2457 thread.on_terminal_provider_event(
2458 TerminalProviderEvent::Created {
2459 terminal_id: terminal_id.clone(),
2460 label: "Buffered Exit Test".to_string(),
2461 cwd: None,
2462 output_byte_limit: None,
2463 terminal: lower.clone(),
2464 },
2465 cx,
2466 );
2467 });
2468
2469 // Output should be present after Created (flushed from buffer)
2470 let content = thread.read_with(cx, |thread, cx| {
2471 let term = thread.terminal(terminal_id.clone()).unwrap();
2472 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2473 });
2474
2475 assert!(
2476 content.contains("pre-exit data"),
2477 "expected pre-exit data to render, got: {content}"
2478 );
2479 }
2480
2481 #[gpui::test]
2482 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2483 init_test(cx);
2484
2485 let fs = FakeFs::new(cx.executor());
2486 let project = Project::test(fs, [], cx).await;
2487 let connection = Rc::new(FakeAgentConnection::new());
2488 let thread = cx
2489 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2490 .await
2491 .unwrap();
2492
2493 // Test creating a new user message
2494 thread.update(cx, |thread, cx| {
2495 thread.push_user_content_block(
2496 None,
2497 acp::ContentBlock::Text(acp::TextContent {
2498 annotations: None,
2499 text: "Hello, ".to_string(),
2500 meta: None,
2501 }),
2502 cx,
2503 );
2504 });
2505
2506 thread.update(cx, |thread, cx| {
2507 assert_eq!(thread.entries.len(), 1);
2508 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2509 assert_eq!(user_msg.id, None);
2510 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2511 } else {
2512 panic!("Expected UserMessage");
2513 }
2514 });
2515
2516 // Test appending to existing user message
2517 let message_1_id = UserMessageId::new();
2518 thread.update(cx, |thread, cx| {
2519 thread.push_user_content_block(
2520 Some(message_1_id.clone()),
2521 acp::ContentBlock::Text(acp::TextContent {
2522 annotations: None,
2523 text: "world!".to_string(),
2524 meta: None,
2525 }),
2526 cx,
2527 );
2528 });
2529
2530 thread.update(cx, |thread, cx| {
2531 assert_eq!(thread.entries.len(), 1);
2532 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2533 assert_eq!(user_msg.id, Some(message_1_id));
2534 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2535 } else {
2536 panic!("Expected UserMessage");
2537 }
2538 });
2539
2540 // Test creating new user message after assistant message
2541 thread.update(cx, |thread, cx| {
2542 thread.push_assistant_content_block(
2543 acp::ContentBlock::Text(acp::TextContent {
2544 annotations: None,
2545 text: "Assistant response".to_string(),
2546 meta: None,
2547 }),
2548 false,
2549 cx,
2550 );
2551 });
2552
2553 let message_2_id = UserMessageId::new();
2554 thread.update(cx, |thread, cx| {
2555 thread.push_user_content_block(
2556 Some(message_2_id.clone()),
2557 acp::ContentBlock::Text(acp::TextContent {
2558 annotations: None,
2559 text: "New user message".to_string(),
2560 meta: None,
2561 }),
2562 cx,
2563 );
2564 });
2565
2566 thread.update(cx, |thread, cx| {
2567 assert_eq!(thread.entries.len(), 3);
2568 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2569 assert_eq!(user_msg.id, Some(message_2_id));
2570 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2571 } else {
2572 panic!("Expected UserMessage at index 2");
2573 }
2574 });
2575 }
2576
    /// Two consecutive `AgentThoughtChunk` updates in one turn are exported to
    /// Markdown as a single `<thinking>` block containing their concatenation.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks must land in one thinking block ("Thinking " + "hard!").
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2640
    /// The agent reads `/tmp/foo`, the user then edits the open buffer, and only
    /// afterwards does the agent write extended content. Both the buffer and the
    /// file in the (fake) file system must end up with the user's inserted line
    /// *and* the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the fake agent once it has read the original file contents.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    // 1. Read the file as it was before the user's edit.
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // 2. Let the test perform its concurrent user edit.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // 3. Write content derived from the stale read (two lines appended).
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file as a buffer so the test can edit it the way a user would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then insert a line at the top.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // The user's "zero" line survives alongside the agent's "four"/"five" lines.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2720
2721 #[gpui::test]
2722 async fn test_reading_from_line(cx: &mut TestAppContext) {
2723 init_test(cx);
2724
2725 let fs = FakeFs::new(cx.executor());
2726 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2727 .await;
2728 let project = Project::test(fs.clone(), [], cx).await;
2729 project
2730 .update(cx, |project, cx| {
2731 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2732 })
2733 .await
2734 .unwrap();
2735
2736 let connection = Rc::new(FakeAgentConnection::new());
2737
2738 let thread = cx
2739 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2740 .await
2741 .unwrap();
2742
2743 // Whole file
2744 let content = thread
2745 .update(cx, |thread, cx| {
2746 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2747 })
2748 .await
2749 .unwrap();
2750
2751 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2752
2753 // Only start line
2754 let content = thread
2755 .update(cx, |thread, cx| {
2756 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2757 })
2758 .await
2759 .unwrap();
2760
2761 assert_eq!(content, "three\nfour\n");
2762
2763 // Only limit
2764 let content = thread
2765 .update(cx, |thread, cx| {
2766 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2767 })
2768 .await
2769 .unwrap();
2770
2771 assert_eq!(content, "one\ntwo\n");
2772
2773 // Range
2774 let content = thread
2775 .update(cx, |thread, cx| {
2776 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2777 })
2778 .await
2779 .unwrap();
2780
2781 assert_eq!(content, "two\nthree\n");
2782
2783 // Invalid
2784 let err = thread
2785 .update(cx, |thread, cx| {
2786 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2787 })
2788 .await
2789 .unwrap_err();
2790
2791 assert_eq!(
2792 err.to_string(),
2793 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2794 );
2795 }
2796
2797 #[gpui::test]
2798 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2799 init_test(cx);
2800
2801 let fs = FakeFs::new(cx.executor());
2802 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2803 let project = Project::test(fs.clone(), [], cx).await;
2804 project
2805 .update(cx, |project, cx| {
2806 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2807 })
2808 .await
2809 .unwrap();
2810
2811 let connection = Rc::new(FakeAgentConnection::new());
2812
2813 let thread = cx
2814 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2815 .await
2816 .unwrap();
2817
2818 // Whole file
2819 let content = thread
2820 .update(cx, |thread, cx| {
2821 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2822 })
2823 .await
2824 .unwrap();
2825
2826 assert_eq!(content, "");
2827
2828 // Only start line
2829 let content = thread
2830 .update(cx, |thread, cx| {
2831 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2832 })
2833 .await
2834 .unwrap();
2835
2836 assert_eq!(content, "");
2837
2838 // Only limit
2839 let content = thread
2840 .update(cx, |thread, cx| {
2841 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2842 })
2843 .await
2844 .unwrap();
2845
2846 assert_eq!(content, "");
2847
2848 // Range
2849 let content = thread
2850 .update(cx, |thread, cx| {
2851 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2852 })
2853 .await
2854 .unwrap();
2855
2856 assert_eq!(content, "");
2857
2858 // Invalid
2859 let err = thread
2860 .update(cx, |thread, cx| {
2861 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2862 })
2863 .await
2864 .unwrap_err();
2865
2866 assert_eq!(
2867 err.to_string(),
2868 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2869 );
2870 }
2871 #[gpui::test]
2872 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2873 init_test(cx);
2874
2875 let fs = FakeFs::new(cx.executor());
2876 fs.insert_tree(path!("/tmp"), json!({})).await;
2877 let project = Project::test(fs.clone(), [], cx).await;
2878 project
2879 .update(cx, |project, cx| {
2880 project.find_or_create_worktree(path!("/tmp"), true, cx)
2881 })
2882 .await
2883 .unwrap();
2884
2885 let connection = Rc::new(FakeAgentConnection::new());
2886
2887 let thread = cx
2888 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2889 .await
2890 .unwrap();
2891
2892 // Out of project file
2893 let err = thread
2894 .update(cx, |thread, cx| {
2895 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2896 })
2897 .await
2898 .unwrap_err();
2899
2900 assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2901 }
2902
    /// A tool call that `cancel` marked as `Canceled` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call and then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 1 is the tool call (entry 0 holds the user's prompt).
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent must still be applied after cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                        meta: None,
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
3005
    /// An edit tool call that arrives already `Completed` (carrying a diff) must
    /// not be reported by `has_pending_edit_tool_calls` once the prompt finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent emits a single, already-completed edit tool call with a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                            meta: None,
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                    meta: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
3062
    /// A prompt during which files change is shown with a checkpoint; a prompt
    /// without changes is not. Restoring a checkpoint truncates the conversation
    /// at that message and reverts the file system to its state at that point.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // NOTE(review): the `.git` directory is presumably what lets the git
        // store take checkpoints for this worktree — confirm.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt writes a new `/test/file-N`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The agent's reply is the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum"); restoring it drops that
        // message and everything after it, and removes `file-1` written during it.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
3243
3244 #[gpui::test]
3245 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3246 use std::sync::atomic::AtomicUsize;
3247 init_test(cx);
3248
3249 let fs = FakeFs::new(cx.executor());
3250 let project = Project::test(fs, None, cx).await;
3251
3252 // Create a connection that simulates refusal after tool result
3253 let prompt_count = Arc::new(AtomicUsize::new(0));
3254 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3255 let prompt_count = prompt_count.clone();
3256 move |_request, thread, mut cx| {
3257 let count = prompt_count.fetch_add(1, SeqCst);
3258 async move {
3259 if count == 0 {
3260 // First prompt: Generate a tool call with result
3261 thread.update(&mut cx, |thread, cx| {
3262 thread
3263 .handle_session_update(
3264 acp::SessionUpdate::ToolCall(acp::ToolCall {
3265 id: acp::ToolCallId("tool1".into()),
3266 title: "Test Tool".into(),
3267 kind: acp::ToolKind::Fetch,
3268 status: acp::ToolCallStatus::Completed,
3269 content: vec![],
3270 locations: vec![],
3271 raw_input: Some(serde_json::json!({"query": "test"})),
3272 raw_output: Some(
3273 serde_json::json!({"result": "inappropriate content"}),
3274 ),
3275 meta: None,
3276 }),
3277 cx,
3278 )
3279 .unwrap();
3280 })?;
3281
3282 // Now return refusal because of the tool result
3283 Ok(acp::PromptResponse {
3284 stop_reason: acp::StopReason::Refusal,
3285 meta: None,
3286 })
3287 } else {
3288 Ok(acp::PromptResponse {
3289 stop_reason: acp::StopReason::EndTurn,
3290 meta: None,
3291 })
3292 }
3293 }
3294 .boxed_local()
3295 }
3296 }));
3297
3298 let thread = cx
3299 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3300 .await
3301 .unwrap();
3302
3303 // Track if we see a Refusal event
3304 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3305 let saw_refusal_event_captured = saw_refusal_event.clone();
3306 thread.update(cx, |_thread, cx| {
3307 cx.subscribe(
3308 &thread,
3309 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3310 if matches!(event, AcpThreadEvent::Refusal) {
3311 *saw_refusal_event_captured.lock().unwrap() = true;
3312 }
3313 },
3314 )
3315 .detach();
3316 });
3317
3318 // Send a user message - this will trigger tool call and then refusal
3319 let send_task = thread.update(cx, |thread, cx| {
3320 thread.send(
3321 vec![acp::ContentBlock::Text(acp::TextContent {
3322 text: "Hello".into(),
3323 annotations: None,
3324 meta: None,
3325 })],
3326 cx,
3327 )
3328 });
3329 cx.background_executor.spawn(send_task).detach();
3330 cx.run_until_parked();
3331
3332 // Verify that:
3333 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3334 // 2. The user message was NOT truncated
3335 assert!(
3336 *saw_refusal_event.lock().unwrap(),
3337 "Refusal event should be emitted for tool result refusals"
3338 );
3339
3340 thread.read_with(cx, |thread, _| {
3341 let entries = thread.entries();
3342 assert!(entries.len() >= 2, "Should have user message and tool call");
3343
3344 // Verify user message is still there
3345 assert!(
3346 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3347 "User message should not be truncated"
3348 );
3349
3350 // Verify tool call is there with result
3351 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3352 assert!(
3353 tool_call.raw_output.is_some(),
3354 "Tool call should have output"
3355 );
3356 } else {
3357 panic!("Expected tool call at index 1");
3358 }
3359 });
3360 }
3361
3362 #[gpui::test]
3363 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3364 init_test(cx);
3365
3366 let fs = FakeFs::new(cx.executor());
3367 let project = Project::test(fs, None, cx).await;
3368
3369 let refuse_next = Arc::new(AtomicBool::new(false));
3370 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3371 let refuse_next = refuse_next.clone();
3372 move |_request, _thread, _cx| {
3373 if refuse_next.load(SeqCst) {
3374 async move {
3375 Ok(acp::PromptResponse {
3376 stop_reason: acp::StopReason::Refusal,
3377 meta: None,
3378 })
3379 }
3380 .boxed_local()
3381 } else {
3382 async move {
3383 Ok(acp::PromptResponse {
3384 stop_reason: acp::StopReason::EndTurn,
3385 meta: None,
3386 })
3387 }
3388 .boxed_local()
3389 }
3390 }
3391 }));
3392
3393 let thread = cx
3394 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3395 .await
3396 .unwrap();
3397
3398 // Track if we see a Refusal event
3399 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3400 let saw_refusal_event_captured = saw_refusal_event.clone();
3401 thread.update(cx, |_thread, cx| {
3402 cx.subscribe(
3403 &thread,
3404 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3405 if matches!(event, AcpThreadEvent::Refusal) {
3406 *saw_refusal_event_captured.lock().unwrap() = true;
3407 }
3408 },
3409 )
3410 .detach();
3411 });
3412
3413 // Send a message that will be refused
3414 refuse_next.store(true, SeqCst);
3415 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3416 .await
3417 .unwrap();
3418
3419 // Verify that a Refusal event WAS emitted for user prompt refusal
3420 assert!(
3421 *saw_refusal_event.lock().unwrap(),
3422 "Refusal event should be emitted for user prompt refusals"
3423 );
3424
3425 // Verify the message was truncated (user prompt refusal)
3426 thread.read_with(cx, |thread, cx| {
3427 assert_eq!(thread.to_markdown(cx), "");
3428 });
3429 }
3430
    /// When a later prompt is refused, that prompt is removed from the thread
    /// while the earlier, successful exchange stays intact.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // The fake agent echoes the prompt in upper case unless `refuse_next` is set,
        // in which case it refuses without producing any output.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                            meta: None,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
3518
3519 async fn run_until_first_tool_call(
3520 thread: &Entity<AcpThread>,
3521 cx: &mut TestAppContext,
3522 ) -> usize {
3523 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3524
3525 let subscription = cx.update(|cx| {
3526 cx.subscribe(thread, move |thread, _, cx| {
3527 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3528 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3529 return tx.try_send(ix).unwrap();
3530 }
3531 }
3532 })
3533 });
3534
3535 select! {
3536 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3537 panic!("Timeout waiting for tool call")
3538 }
3539 ix = rx.next().fuse() => {
3540 drop(subscription);
3541 ix.unwrap()
3542 }
3543 }
3544 }
3545
3546 #[derive(Clone, Default)]
3547 struct FakeAgentConnection {
3548 auth_methods: Vec<acp::AuthMethod>,
3549 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3550 on_user_message: Option<
3551 Rc<
3552 dyn Fn(
3553 acp::PromptRequest,
3554 WeakEntity<AcpThread>,
3555 AsyncApp,
3556 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3557 + 'static,
3558 >,
3559 >,
3560 }
3561
3562 impl FakeAgentConnection {
3563 fn new() -> Self {
3564 Self {
3565 auth_methods: Vec::new(),
3566 on_user_message: None,
3567 sessions: Arc::default(),
3568 }
3569 }
3570
3571 #[expect(unused)]
3572 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3573 self.auth_methods = auth_methods;
3574 self
3575 }
3576
3577 fn on_user_message(
3578 mut self,
3579 handler: impl Fn(
3580 acp::PromptRequest,
3581 WeakEntity<AcpThread>,
3582 AsyncApp,
3583 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3584 + 'static,
3585 ) -> Self {
3586 self.on_user_message.replace(Rc::new(handler));
3587 self
3588 }
3589 }
3590
3591 impl AgentConnection for FakeAgentConnection {
3592 fn auth_methods(&self) -> &[acp::AuthMethod] {
3593 &self.auth_methods
3594 }
3595
3596 fn new_thread(
3597 self: Rc<Self>,
3598 project: Entity<Project>,
3599 _cwd: &Path,
3600 cx: &mut App,
3601 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3602 let session_id = acp::SessionId(
3603 rand::rng()
3604 .sample_iter(&distr::Alphanumeric)
3605 .take(7)
3606 .map(char::from)
3607 .collect::<String>()
3608 .into(),
3609 );
3610 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3611 let thread = cx.new(|cx| {
3612 AcpThread::new(
3613 "Test",
3614 self.clone(),
3615 project,
3616 action_log,
3617 session_id.clone(),
3618 watch::Receiver::constant(acp::PromptCapabilities {
3619 image: true,
3620 audio: true,
3621 embedded_context: true,
3622 meta: None,
3623 }),
3624 cx,
3625 )
3626 });
3627 self.sessions.lock().insert(session_id, thread.downgrade());
3628 Task::ready(Ok(thread))
3629 }
3630
3631 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3632 if self.auth_methods().iter().any(|m| m.id == method) {
3633 Task::ready(Ok(()))
3634 } else {
3635 Task::ready(Err(anyhow!("Invalid Auth Method")))
3636 }
3637 }
3638
3639 fn prompt(
3640 &self,
3641 _id: Option<UserMessageId>,
3642 params: acp::PromptRequest,
3643 cx: &mut App,
3644 ) -> Task<gpui::Result<acp::PromptResponse>> {
3645 let sessions = self.sessions.lock();
3646 let thread = sessions.get(¶ms.session_id).unwrap();
3647 if let Some(handler) = &self.on_user_message {
3648 let handler = handler.clone();
3649 let thread = thread.clone();
3650 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3651 } else {
3652 Task::ready(Ok(acp::PromptResponse {
3653 stop_reason: acp::StopReason::EndTurn,
3654 meta: None,
3655 }))
3656 }
3657 }
3658
3659 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3660 let sessions = self.sessions.lock();
3661 let thread = sessions.get(session_id).unwrap().clone();
3662
3663 cx.spawn(async move |cx| {
3664 thread
3665 .update(cx, |thread, cx| thread.cancel(cx))
3666 .unwrap()
3667 .await
3668 })
3669 .detach();
3670 }
3671
3672 fn truncate(
3673 &self,
3674 session_id: &acp::SessionId,
3675 _cx: &App,
3676 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3677 Some(Rc::new(FakeAgentSessionEditor {
3678 _session_id: session_id.clone(),
3679 }))
3680 }
3681
3682 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3683 self
3684 }
3685 }
3686
3687 struct FakeAgentSessionEditor {
3688 _session_id: acp::SessionId,
3689 }
3690
3691 impl AgentSessionTruncate for FakeAgentSessionEditor {
3692 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3693 Task::ready(Ok(()))
3694 }
3695 }
3696
3697 #[gpui::test]
3698 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3699 init_test(cx);
3700
3701 let fs = FakeFs::new(cx.executor());
3702 let project = Project::test(fs, [], cx).await;
3703 let connection = Rc::new(FakeAgentConnection::new());
3704 let thread = cx
3705 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3706 .await
3707 .unwrap();
3708
3709 // Try to update a tool call that doesn't exist
3710 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3711 thread.update(cx, |thread, cx| {
3712 let result = thread.handle_session_update(
3713 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3714 id: nonexistent_id.clone(),
3715 fields: acp::ToolCallUpdateFields {
3716 status: Some(acp::ToolCallStatus::Completed),
3717 ..Default::default()
3718 },
3719 meta: None,
3720 }),
3721 cx,
3722 );
3723
3724 // The update should succeed (not return an error)
3725 assert!(result.is_ok());
3726
3727 // There should now be exactly one entry in the thread
3728 assert_eq!(thread.entries.len(), 1);
3729
3730 // The entry should be a failed tool call
3731 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3732 assert_eq!(tool_call.id, nonexistent_id);
3733 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3734 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3735
3736 // Check that the content contains the error message
3737 assert_eq!(tool_call.content.len(), 1);
3738 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3739 match content_block {
3740 ContentBlock::Markdown { markdown } => {
3741 let markdown_text = markdown.read(cx).source();
3742 assert!(markdown_text.contains("Tool call not found"));
3743 }
3744 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3745 ContentBlock::ResourceLink { .. } => {
3746 panic!("Expected markdown content, got resource link")
3747 }
3748 }
3749 } else {
3750 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3751 }
3752 } else {
3753 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3754 }
3755 });
3756 }
3757}