1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use ::terminal::terminal_settings::TerminalSettings;
7use agent_settings::AgentSettings;
8use collections::HashSet;
9pub use connection::*;
10pub use diff::*;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::{Settings as _, SettingsLocation};
16use task::{Shell, ShellBuilder};
17pub use terminal::*;
18
19use action_log::ActionLog;
20use agent_client_protocol::{self as acp};
21use anyhow::{Context as _, Result, anyhow};
22use editor::Bias;
23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
25use itertools::Itertools;
26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
27use markdown::Markdown;
28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
29use std::collections::HashMap;
30use std::error::Error;
31use std::fmt::{Formatter, Write};
32use std::ops::Range;
33use std::process::ExitStatus;
34use std::rc::Rc;
35use std::time::{Duration, Instant};
36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
37use ui::App;
38use util::{ResultExt, get_default_system_shell_preferring_bash};
39use uuid::Uuid;
40
/// A message authored by the user within an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one has been assigned.
    pub id: Option<UserMessageId>,
    /// Content accumulated from every chunk appended to this message.
    pub content: ContentBlock,
    /// The raw protocol content blocks, in the order they were appended.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
48
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the message's markdown heading is rendered as `## User (checkpoint)`.
    pub show: bool,
}
54
55impl UserMessage {
56 fn to_markdown(&self, cx: &App) -> String {
57 let mut markdown = String::new();
58 if self
59 .checkpoint
60 .as_ref()
61 .is_some_and(|checkpoint| checkpoint.show)
62 {
63 writeln!(markdown, "## User (checkpoint)").unwrap();
64 } else {
65 writeln!(markdown, "## User").unwrap();
66 }
67 writeln!(markdown).unwrap();
68 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
69 writeln!(markdown).unwrap();
70 markdown
71 }
72}
73
/// A message produced by the assistant, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in the order they were received.
    pub chunks: Vec<AssistantMessageChunk>,
}
78
79impl AssistantMessage {
80 pub fn to_markdown(&self, cx: &App) -> String {
81 format!(
82 "## Assistant\n\n{}\n\n",
83 self.chunks
84 .iter()
85 .map(|chunk| chunk.to_markdown(cx))
86 .join("\n\n")
87 )
88 }
89}
90
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered inside `<thinking>` tags when exported to markdown.
    Thought { block: ContentBlock },
}
96
97impl AssistantMessageChunk {
98 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
99 Self::Message {
100 block: ContentBlock::new(chunk.into(), language_registry, cx),
101 }
102 }
103
104 fn to_markdown(&self, cx: &App) -> String {
105 match self {
106 Self::Message { block } => block.to_markdown(cx).to_string(),
107 Self::Thought { block } => {
108 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
109 }
110 }
111 }
112}
113
/// A single entry of an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message from the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool call reported by the agent.
    ToolCall(ToolCall),
}
120
121impl AgentThreadEntry {
122 pub fn to_markdown(&self, cx: &App) -> String {
123 match self {
124 Self::UserMessage(message) => message.to_markdown(cx),
125 Self::AssistantMessage(message) => message.to_markdown(cx),
126 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
127 }
128 }
129
130 pub fn user_message(&self) -> Option<&UserMessage> {
131 if let AgentThreadEntry::UserMessage(message) = self {
132 Some(message)
133 } else {
134 None
135 }
136 }
137
138 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.diffs())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
147 if let AgentThreadEntry::ToolCall(call) = self {
148 itertools::Either::Left(call.terminals())
149 } else {
150 itertools::Either::Right(std::iter::empty())
151 }
152 }
153
154 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
155 if let AgentThreadEntry::ToolCall(ToolCall {
156 locations,
157 resolved_locations,
158 ..
159 }) = self
160 {
161 Some((
162 locations.get(ix)?.clone(),
163 resolved_locations.get(ix)?.clone()?,
164 ))
165 } else {
166 None
167 }
168 }
169}
170
/// A tool invocation reported by the agent, together with its display state.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the tool call's title; only the first line of a
    /// multi-line title is kept.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content items attached to the tool call (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported for this tool call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
183
184impl ToolCall {
185 fn from_acp(
186 tool_call: acp::ToolCall,
187 status: ToolCallStatus,
188 language_registry: Arc<LanguageRegistry>,
189 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
190 cx: &mut App,
191 ) -> Result<Self> {
192 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
193 first_line.to_owned() + "…"
194 } else {
195 tool_call.title
196 };
197 let mut content = Vec::with_capacity(tool_call.content.len());
198 for item in tool_call.content {
199 content.push(ToolCallContent::from_acp(
200 item,
201 language_registry.clone(),
202 terminals,
203 cx,
204 )?);
205 }
206
207 let result = Self {
208 id: tool_call.id,
209 label: cx
210 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
211 kind: tool_call.kind,
212 content,
213 locations: tool_call.locations,
214 resolved_locations: Vec::default(),
215 status,
216 raw_input: tool_call.raw_input,
217 raw_output: tool_call.raw_output,
218 };
219 Ok(result)
220 }
221
222 fn update_fields(
223 &mut self,
224 fields: acp::ToolCallUpdateFields,
225 language_registry: Arc<LanguageRegistry>,
226 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
227 cx: &mut App,
228 ) -> Result<()> {
229 let acp::ToolCallUpdateFields {
230 kind,
231 status,
232 title,
233 content,
234 locations,
235 raw_input,
236 raw_output,
237 } = fields;
238
239 if let Some(kind) = kind {
240 self.kind = kind;
241 }
242
243 if let Some(status) = status {
244 self.status = status.into();
245 }
246
247 if let Some(title) = title {
248 self.label.update(cx, |label, cx| {
249 if let Some((first_line, _)) = title.split_once("\n") {
250 label.replace(first_line.to_owned() + "…", cx)
251 } else {
252 label.replace(title, cx);
253 }
254 });
255 }
256
257 if let Some(content) = content {
258 let new_content_len = content.len();
259 let mut content = content.into_iter();
260
261 // Reuse existing content if we can
262 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
263 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
264 }
265 for new in content {
266 self.content.push(ToolCallContent::from_acp(
267 new,
268 language_registry.clone(),
269 terminals,
270 cx,
271 )?)
272 }
273 self.content.truncate(new_content_len);
274 }
275
276 if let Some(locations) = locations {
277 self.locations = locations;
278 }
279
280 if let Some(raw_input) = raw_input {
281 self.raw_input = Some(raw_input);
282 }
283
284 if let Some(raw_output) = raw_output {
285 if self.content.is_empty()
286 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
287 {
288 self.content
289 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
290 markdown,
291 }));
292 }
293 self.raw_output = Some(raw_output);
294 }
295 Ok(())
296 }
297
298 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
299 self.content.iter().filter_map(|content| match content {
300 ToolCallContent::Diff(diff) => Some(diff),
301 ToolCallContent::ContentBlock(_) => None,
302 ToolCallContent::Terminal(_) => None,
303 })
304 }
305
306 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
307 self.content.iter().filter_map(|content| match content {
308 ToolCallContent::Terminal(terminal) => Some(terminal),
309 ToolCallContent::ContentBlock(_) => None,
310 ToolCallContent::Diff(_) => None,
311 })
312 }
313
314 fn to_markdown(&self, cx: &App) -> String {
315 let mut markdown = format!(
316 "**Tool Call: {}**\nStatus: {}\n\n",
317 self.label.read(cx).source(),
318 self.status
319 );
320 for content in &self.content {
321 markdown.push_str(content.to_markdown(cx).as_str());
322 markdown.push_str("\n\n");
323 }
324 markdown
325 }
326
327 async fn resolve_location(
328 location: acp::ToolCallLocation,
329 project: WeakEntity<Project>,
330 cx: &mut AsyncApp,
331 ) -> Option<ResolvedLocation> {
332 let buffer = project
333 .update(cx, |project, cx| {
334 project
335 .project_path_for_absolute_path(&location.path, cx)
336 .map(|path| project.open_buffer(path, cx))
337 })
338 .ok()??;
339 let buffer = buffer.await.log_err()?;
340 let position = buffer
341 .update(cx, |buffer, _| {
342 if let Some(row) = location.line {
343 let snapshot = buffer.snapshot();
344 let column = snapshot.indent_size_for_line(row).len;
345 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
346 snapshot.anchor_before(point)
347 } else {
348 Anchor::MIN
349 }
350 })
351 .ok()?;
352
353 Some(ResolvedLocation { buffer, position })
354 }
355
356 fn resolve_locations(
357 &self,
358 project: Entity<Project>,
359 cx: &mut App,
360 ) -> Task<Vec<Option<ResolvedLocation>>> {
361 let locations = self.locations.clone();
362 project.update(cx, |_, cx| {
363 cx.spawn(async move |project, cx| {
364 let mut new_locations = Vec::new();
365 for location in locations {
366 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
367 }
368 new_locations
369 })
370 })
371 }
372}
373
// Separate so we can hold a strong reference to the buffer
// for saving on the thread
/// A tool call location resolved to an open buffer and a position inside it.
///
/// Converting it into an [`AgentLocation`] downgrades the buffer handle.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResolvedLocation {
    /// Strong handle to the buffer the location points into.
    buffer: Entity<Buffer>,
    /// Anchor of the location within `buffer`.
    position: Anchor,
}
381
382impl From<&ResolvedLocation> for AgentLocation {
383 fn from(value: &ResolvedLocation) -> Self {
384 Self {
385 buffer: value.buffer.downgrade(),
386 position: value.position,
387 }
388 }
389}
390
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options offered for this tool call.
        options: Vec<acp::PermissionOption>,
        /// Channel carrying the id of the chosen option.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
412
413impl From<acp::ToolCallStatus> for ToolCallStatus {
414 fn from(status: acp::ToolCallStatus) -> Self {
415 match status {
416 acp::ToolCallStatus::Pending => Self::Pending,
417 acp::ToolCallStatus::InProgress => Self::InProgress,
418 acp::ToolCallStatus::Completed => Self::Completed,
419 acp::ToolCallStatus::Failed => Self::Failed,
420 }
421 }
422}
423
424impl Display for ToolCallStatus {
425 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
426 write!(
427 f,
428 "{}",
429 match self {
430 ToolCallStatus::Pending => "Pending",
431 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
432 ToolCallStatus::InProgress => "In Progress",
433 ToolCallStatus::Completed => "Completed",
434 ToolCallStatus::Failed => "Failed",
435 ToolCallStatus::Rejected => "Rejected",
436 ToolCallStatus::Canceled => "Canceled",
437 }
438 )
439 }
440}
441
/// Displayable content accumulated from protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered through a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; it is converted to markdown once more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
448
449impl ContentBlock {
450 pub fn new(
451 block: acp::ContentBlock,
452 language_registry: &Arc<LanguageRegistry>,
453 cx: &mut App,
454 ) -> Self {
455 let mut this = Self::Empty;
456 this.append(block, language_registry, cx);
457 this
458 }
459
460 pub fn new_combined(
461 blocks: impl IntoIterator<Item = acp::ContentBlock>,
462 language_registry: Arc<LanguageRegistry>,
463 cx: &mut App,
464 ) -> Self {
465 let mut this = Self::Empty;
466 for block in blocks {
467 this.append(block, &language_registry, cx);
468 }
469 this
470 }
471
472 pub fn append(
473 &mut self,
474 block: acp::ContentBlock,
475 language_registry: &Arc<LanguageRegistry>,
476 cx: &mut App,
477 ) {
478 if matches!(self, ContentBlock::Empty)
479 && let acp::ContentBlock::ResourceLink(resource_link) = block
480 {
481 *self = ContentBlock::ResourceLink { resource_link };
482 return;
483 }
484
485 let new_content = self.block_string_contents(block);
486
487 match self {
488 ContentBlock::Empty => {
489 *self = Self::create_markdown_block(new_content, language_registry, cx);
490 }
491 ContentBlock::Markdown { markdown } => {
492 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
493 }
494 ContentBlock::ResourceLink { resource_link } => {
495 let existing_content = Self::resource_link_md(&resource_link.uri);
496 let combined = format!("{}\n{}", existing_content, new_content);
497
498 *self = Self::create_markdown_block(combined, language_registry, cx);
499 }
500 }
501 }
502
503 fn create_markdown_block(
504 content: String,
505 language_registry: &Arc<LanguageRegistry>,
506 cx: &mut App,
507 ) -> ContentBlock {
508 ContentBlock::Markdown {
509 markdown: cx
510 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
511 }
512 }
513
514 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
515 match block {
516 acp::ContentBlock::Text(text_content) => text_content.text,
517 acp::ContentBlock::ResourceLink(resource_link) => {
518 Self::resource_link_md(&resource_link.uri)
519 }
520 acp::ContentBlock::Resource(acp::EmbeddedResource {
521 resource:
522 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
523 uri,
524 ..
525 }),
526 ..
527 }) => Self::resource_link_md(&uri),
528 acp::ContentBlock::Image(image) => Self::image_md(&image),
529 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
530 }
531 }
532
533 fn resource_link_md(uri: &str) -> String {
534 if let Some(uri) = MentionUri::parse(uri).log_err() {
535 uri.as_link().to_string()
536 } else {
537 uri.to_string()
538 }
539 }
540
541 fn image_md(_image: &acp::ImageContent) -> String {
542 "`Image`".into()
543 }
544
545 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
546 match self {
547 ContentBlock::Empty => "",
548 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
549 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
550 }
551 }
552
553 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
554 match self {
555 ContentBlock::Empty => None,
556 ContentBlock::Markdown { markdown } => Some(markdown),
557 ContentBlock::ResourceLink { .. } => None,
558 }
559 }
560
561 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
562 match self {
563 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
564 _ => None,
565 }
566 }
567}
568
/// One content item attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Plain content (markdown or a resource link).
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity, looked up by id among the thread's terminals.
    Terminal(Entity<Terminal>),
}
575
576impl ToolCallContent {
577 pub fn from_acp(
578 content: acp::ToolCallContent,
579 language_registry: Arc<LanguageRegistry>,
580 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
581 cx: &mut App,
582 ) -> Result<Self> {
583 match content {
584 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
585 content,
586 &language_registry,
587 cx,
588 ))),
589 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
590 Diff::finalized(
591 diff.path.to_string_lossy().into_owned(),
592 diff.old_text,
593 diff.new_text,
594 language_registry,
595 cx,
596 )
597 }))),
598 acp::ToolCallContent::Terminal { terminal_id } => terminals
599 .get(&terminal_id)
600 .cloned()
601 .map(Self::Terminal)
602 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
603 }
604 }
605
606 pub fn update_from_acp(
607 &mut self,
608 new: acp::ToolCallContent,
609 language_registry: Arc<LanguageRegistry>,
610 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
611 cx: &mut App,
612 ) -> Result<()> {
613 let needs_update = match (&self, &new) {
614 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
615 old_diff.read(cx).needs_update(
616 new_diff.old_text.as_deref().unwrap_or(""),
617 &new_diff.new_text,
618 cx,
619 )
620 }
621 _ => true,
622 };
623
624 if needs_update {
625 *self = Self::from_acp(new, language_registry, terminals, cx)?;
626 }
627 Ok(())
628 }
629
630 pub fn to_markdown(&self, cx: &App) -> String {
631 match self {
632 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
633 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
634 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
635 }
636 }
637}
638
/// An update targeting an existing tool call, identified by the tool call id
/// carried in each variant's payload.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level update of the tool call's fields.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
645
646impl ToolCallUpdate {
647 fn id(&self) -> &acp::ToolCallId {
648 match self {
649 Self::UpdateFields(update) => &update.id,
650 Self::UpdateDiff(diff) => &diff.id,
651 Self::UpdateTerminal(terminal) => &terminal.id,
652 }
653 }
654}
655
656impl From<acp::ToolCallUpdate> for ToolCallUpdate {
657 fn from(update: acp::ToolCallUpdate) -> Self {
658 Self::UpdateFields(update)
659 }
660}
661
662impl From<ToolCallUpdateDiff> for ToolCallUpdate {
663 fn from(diff: ToolCallUpdateDiff) -> Self {
664 Self::UpdateDiff(diff)
665 }
666}
667
/// Payload of [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call the diff belongs to.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
673
674impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
675 fn from(terminal: ToolCallUpdateTerminal) -> Self {
676 Self::UpdateTerminal(terminal)
677 }
678}
679
/// Payload of [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call the terminal belongs to.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
685
/// The agent's plan: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan's entries, in order.
    pub entries: Vec<PlanEntry>,
}
690
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
697
698impl Plan {
699 pub fn is_empty(&self) -> bool {
700 self.entries.is_empty()
701 }
702
703 pub fn stats(&self) -> PlanStats<'_> {
704 let mut stats = PlanStats {
705 in_progress_entry: None,
706 pending: 0,
707 completed: 0,
708 };
709
710 for entry in &self.entries {
711 match &entry.status {
712 acp::PlanEntryStatus::Pending => {
713 stats.pending += 1;
714 }
715 acp::PlanEntryStatus::InProgress => {
716 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
717 }
718 acp::PlanEntryStatus::Completed => {
719 stats.completed += 1;
720 }
721 }
722 }
723
724 stats
725 }
726}
727
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown entity holding the entry's text.
    pub content: Entity<Markdown>,
    /// Priority as reported by the protocol entry.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported by the protocol entry.
    pub status: acp::PlanEntryStatus,
}
734
735impl PlanEntry {
736 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
737 Self {
738 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
739 priority: entry.priority,
740 status: entry.status,
741 }
742 }
743}
744
/// Token consumption of a thread relative to the available maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" by [`TokenUsage::ratio`].
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
750
751impl TokenUsage {
752 pub fn ratio(&self) -> TokenUsageRatio {
753 #[cfg(debug_assertions)]
754 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
755 .unwrap_or("0.8".to_string())
756 .parse()
757 .unwrap();
758 #[cfg(not(debug_assertions))]
759 let warning_threshold: f32 = 0.8;
760
761 // When the maximum is unknown because there is no selected model,
762 // avoid showing the token limit warning.
763 if self.max_tokens == 0 {
764 TokenUsageRatio::Normal
765 } else if self.used_tokens >= self.max_tokens {
766 TokenUsageRatio::Exceeded
767 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
768 TokenUsageRatio::Warning
769 } else {
770 TokenUsageRatio::Normal
771 }
772 }
773}
774
/// Classification of a [`TokenUsage`], as computed by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown (`0`).
    Normal,
    /// Usage has reached the warning threshold but is still below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
781
/// Retry information carried by [`AcpThreadEvent::Retry`].
// NOTE(review): the producer of these values is outside this part of the file; the
// field notes below are inferred from the field names and should be confirmed.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}
790
/// State of a single agent thread: its entries, plan, connection and terminals.
pub struct AcpThread {
    title: SharedString,
    /// The thread's entries, in order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While this is `Some`, `status()` reports `Generating`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities received from the capabilities channel.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that keeps `prompt_capabilities` up to date.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to this thread, by id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
    /// Output received for terminals that are not yet in `terminals`; replayed when
    /// the terminal's `Created` event arrives.
    pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
    /// Exit statuses received for terminals that are not yet in `terminals`.
    pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
}
808
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was pushed onto the thread.
    NewEntry,
    /// The thread's title changed.
    TitleUpdated,
    /// The thread's token usage was replaced.
    TokenUsageUpdated,
    /// The entry at the given index was updated.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    /// A retry status was reported.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// New prompt capabilities were received.
    PromptCapabilitiesUpdated,
    Refusal,
    /// The session reported a new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The session reported a new current mode.
    ModeUpdated(acp::SessionModeId),
}
826
// Marker impl: lets `AcpThread` emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
828
/// Events about terminals, handled by `AcpThread::on_terminal_provider_event`.
#[derive(Debug, Clone)]
pub enum TerminalProviderEvent {
    /// A terminal was created and should be registered with the thread.
    Created {
        terminal_id: acp::TerminalId,
        label: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        terminal: Entity<::terminal::Terminal>,
    },
    /// Output bytes for a terminal; buffered by the thread if the terminal is not
    /// registered yet.
    Output {
        terminal_id: acp::TerminalId,
        data: Vec<u8>,
    },
    /// The terminal's title changed; applied as the terminal's breadcrumb text.
    TitleChanged {
        terminal_id: acp::TerminalId,
        title: String,
    },
    /// The terminal exited; the status is stored by the thread if the terminal is
    /// not registered yet.
    Exit {
        terminal_id: acp::TerminalId,
        status: acp::TerminalExitStatus,
    },
}
851
/// Commands addressed to a terminal, each identifying its target by terminal id.
// NOTE(review): no consumer of these commands is visible in this part of the file;
// the exact semantics of each variant should be confirmed against the provider.
#[derive(Debug, Clone)]
pub enum TerminalProviderCommand {
    WriteInput {
        terminal_id: acp::TerminalId,
        bytes: Vec<u8>,
    },
    Resize {
        terminal_id: acp::TerminalId,
        cols: u16,
        rows: u16,
    },
    Close {
        terminal_id: acp::TerminalId,
    },
}
867
868impl AcpThread {
869 pub fn on_terminal_provider_event(
870 &mut self,
871 event: TerminalProviderEvent,
872 cx: &mut Context<Self>,
873 ) {
874 match event {
875 TerminalProviderEvent::Created {
876 terminal_id,
877 label,
878 cwd,
879 output_byte_limit,
880 terminal,
881 } => {
882 let entity = self.register_terminal_created(
883 terminal_id.clone(),
884 label,
885 cwd,
886 output_byte_limit,
887 terminal,
888 cx,
889 );
890
891 if let Some(mut chunks) = self.pending_terminal_output.remove(&terminal_id) {
892 for data in chunks.drain(..) {
893 entity.update(cx, |term, cx| {
894 term.inner().update(cx, |inner, cx| {
895 inner.write_output(&data, cx);
896 })
897 });
898 }
899 }
900
901 if let Some(_status) = self.pending_terminal_exit.remove(&terminal_id) {
902 entity.update(cx, |_term, cx| {
903 cx.notify();
904 });
905 }
906
907 cx.notify();
908 }
909 TerminalProviderEvent::Output { terminal_id, data } => {
910 if let Some(entity) = self.terminals.get(&terminal_id) {
911 entity.update(cx, |term, cx| {
912 term.inner().update(cx, |inner, cx| {
913 inner.write_output(&data, cx);
914 })
915 });
916 } else {
917 self.pending_terminal_output
918 .entry(terminal_id)
919 .or_default()
920 .push(data);
921 }
922 }
923 TerminalProviderEvent::TitleChanged { terminal_id, title } => {
924 if let Some(entity) = self.terminals.get(&terminal_id) {
925 entity.update(cx, |term, cx| {
926 term.inner().update(cx, |inner, cx| {
927 inner.breadcrumb_text = title;
928 cx.emit(::terminal::Event::BreadcrumbsChanged);
929 })
930 });
931 }
932 }
933 TerminalProviderEvent::Exit {
934 terminal_id,
935 status,
936 } => {
937 if let Some(entity) = self.terminals.get(&terminal_id) {
938 entity.update(cx, |_term, cx| {
939 cx.notify();
940 });
941 } else {
942 self.pending_terminal_exit.insert(terminal_id, status);
943 }
944 }
945 }
946 }
947}
948
/// Whether a thread is currently generating, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is active.
    Idle,
    /// A send task is active.
    Generating,
}
954
/// Errors carried by [`AcpThreadEvent::LoadError`]; see the `Display` impl for the
/// message produced by each variant.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed version is older than the minimum supported version.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; the payload is the failure message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other error, described by its message.
    Other(SharedString),
}
968
969impl Display for LoadError {
970 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
971 match self {
972 LoadError::Unsupported {
973 command: path,
974 current_version,
975 minimum_version,
976 } => {
977 write!(
978 f,
979 "version {current_version} from {path} is not supported (need at least {minimum_version})"
980 )
981 }
982 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
983 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
984 LoadError::Other(msg) => write!(f, "{msg}"),
985 }
986 }
987}
988
// Uses the trait's default methods; the message comes from the `Display` impl.
impl Error for LoadError {}
990
991impl AcpThread {
992 pub fn new(
993 title: impl Into<SharedString>,
994 connection: Rc<dyn AgentConnection>,
995 project: Entity<Project>,
996 action_log: Entity<ActionLog>,
997 session_id: acp::SessionId,
998 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
999 cx: &mut Context<Self>,
1000 ) -> Self {
1001 let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
1002 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
1003 loop {
1004 let caps = prompt_capabilities_rx.recv().await?;
1005 this.update(cx, |this, cx| {
1006 this.prompt_capabilities = caps;
1007 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
1008 })?;
1009 }
1010 });
1011
1012 Self {
1013 action_log,
1014 shared_buffers: Default::default(),
1015 entries: Default::default(),
1016 plan: Default::default(),
1017 title: title.into(),
1018 project,
1019 send_task: None,
1020 connection,
1021 session_id,
1022 token_usage: None,
1023 prompt_capabilities,
1024 _observe_prompt_capabilities: task,
1025 terminals: HashMap::default(),
1026 pending_terminal_output: HashMap::default(),
1027 pending_terminal_exit: HashMap::default(),
1028 }
1029 }
1030
    /// Returns a copy of the most recently received prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities.clone()
    }
1034
    /// Returns the agent connection this thread was created with.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
1038
    /// Returns the action log this thread was created with.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1042
    /// Returns the project this thread was created with.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1046
    /// Returns a copy of the thread's current title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
1050
    /// Returns the thread's entries, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
1054
    /// Returns the id of the session this thread belongs to.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
1058
1059 pub fn status(&self) -> ThreadStatus {
1060 if self.send_task.is_some() {
1061 ThreadStatus::Generating
1062 } else {
1063 ThreadStatus::Idle
1064 }
1065 }
1066
    /// Returns the thread's token usage, if one has been set.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
1070
1071 pub fn has_pending_edit_tool_calls(&self) -> bool {
1072 for entry in self.entries.iter().rev() {
1073 match entry {
1074 AgentThreadEntry::UserMessage(_) => return false,
1075 AgentThreadEntry::ToolCall(
1076 call @ ToolCall {
1077 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
1078 ..
1079 },
1080 ) if call.diffs().next().is_some() => {
1081 return true;
1082 }
1083 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
1084 }
1085 }
1086
1087 false
1088 }
1089
1090 pub fn used_tools_since_last_user_message(&self) -> bool {
1091 for entry in self.entries.iter().rev() {
1092 match entry {
1093 AgentThreadEntry::UserMessage(..) => return false,
1094 AgentThreadEntry::AssistantMessage(..) => continue,
1095 AgentThreadEntry::ToolCall(..) => return true,
1096 }
1097 }
1098
1099 false
1100 }
1101
1102 pub fn handle_session_update(
1103 &mut self,
1104 update: acp::SessionUpdate,
1105 cx: &mut Context<Self>,
1106 ) -> Result<(), acp::Error> {
1107 match update {
1108 acp::SessionUpdate::UserMessageChunk(acp::ContentChunk { content, .. }) => {
1109 self.push_user_content_block(None, content, cx);
1110 }
1111 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk { content, .. }) => {
1112 self.push_assistant_content_block(content, false, cx);
1113 }
1114 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk { content, .. }) => {
1115 self.push_assistant_content_block(content, true, cx);
1116 }
1117 acp::SessionUpdate::ToolCall(tool_call) => {
1118 self.upsert_tool_call(tool_call, cx)?;
1119 }
1120 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1121 self.update_tool_call(tool_call_update, cx)?;
1122 }
1123 acp::SessionUpdate::Plan(plan) => {
1124 self.update_plan(plan, cx);
1125 }
1126 acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
1127 available_commands,
1128 ..
1129 }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
1130 acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
1131 current_mode_id,
1132 ..
1133 }) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
1134 }
1135 Ok(())
1136 }
1137
1138 pub fn push_user_content_block(
1139 &mut self,
1140 message_id: Option<UserMessageId>,
1141 chunk: acp::ContentBlock,
1142 cx: &mut Context<Self>,
1143 ) {
1144 let language_registry = self.project.read(cx).languages().clone();
1145 let entries_len = self.entries.len();
1146
1147 if let Some(last_entry) = self.entries.last_mut()
1148 && let AgentThreadEntry::UserMessage(UserMessage {
1149 id,
1150 content,
1151 chunks,
1152 ..
1153 }) = last_entry
1154 {
1155 *id = message_id.or(id.take());
1156 content.append(chunk.clone(), &language_registry, cx);
1157 chunks.push(chunk);
1158 let idx = entries_len - 1;
1159 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1160 } else {
1161 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1162 self.push_entry(
1163 AgentThreadEntry::UserMessage(UserMessage {
1164 id: message_id,
1165 content,
1166 chunks: vec![chunk],
1167 checkpoint: None,
1168 }),
1169 cx,
1170 );
1171 }
1172 }
1173
1174 pub fn push_assistant_content_block(
1175 &mut self,
1176 chunk: acp::ContentBlock,
1177 is_thought: bool,
1178 cx: &mut Context<Self>,
1179 ) {
1180 let language_registry = self.project.read(cx).languages().clone();
1181 let entries_len = self.entries.len();
1182 if let Some(last_entry) = self.entries.last_mut()
1183 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1184 {
1185 let idx = entries_len - 1;
1186 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1187 match (chunks.last_mut(), is_thought) {
1188 (Some(AssistantMessageChunk::Message { block }), false)
1189 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1190 block.append(chunk, &language_registry, cx)
1191 }
1192 _ => {
1193 let block = ContentBlock::new(chunk, &language_registry, cx);
1194 if is_thought {
1195 chunks.push(AssistantMessageChunk::Thought { block })
1196 } else {
1197 chunks.push(AssistantMessageChunk::Message { block })
1198 }
1199 }
1200 }
1201 } else {
1202 let block = ContentBlock::new(chunk, &language_registry, cx);
1203 let chunk = if is_thought {
1204 AssistantMessageChunk::Thought { block }
1205 } else {
1206 AssistantMessageChunk::Message { block }
1207 };
1208
1209 self.push_entry(
1210 AgentThreadEntry::AssistantMessage(AssistantMessage {
1211 chunks: vec![chunk],
1212 }),
1213 cx,
1214 );
1215 }
1216 }
1217
1218 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1219 self.entries.push(entry);
1220 cx.emit(AcpThreadEvent::NewEntry);
1221 }
1222
1223 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1224 self.connection.set_title(&self.session_id, cx).is_some()
1225 }
1226
1227 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1228 if title != self.title {
1229 self.title = title.clone();
1230 cx.emit(AcpThreadEvent::TitleUpdated);
1231 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1232 return set_title.run(title, cx);
1233 }
1234 }
1235 Task::ready(Ok(()))
1236 }
1237
1238 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1239 self.token_usage = usage;
1240 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1241 }
1242
1243 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1244 cx.emit(AcpThreadEvent::Retry(status));
1245 }
1246
1247 pub fn update_tool_call(
1248 &mut self,
1249 update: impl Into<ToolCallUpdate>,
1250 cx: &mut Context<Self>,
1251 ) -> Result<()> {
1252 let update = update.into();
1253 let languages = self.project.read(cx).languages().clone();
1254
1255 let ix = match self.index_for_tool_call(update.id()) {
1256 Some(ix) => ix,
1257 None => {
1258 // Tool call not found - create a failed tool call entry
1259 let failed_tool_call = ToolCall {
1260 id: update.id().clone(),
1261 label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1262 kind: acp::ToolKind::Fetch,
1263 content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1264 acp::ContentBlock::Text(acp::TextContent {
1265 text: "Tool call not found".to_string(),
1266 annotations: None,
1267 meta: None,
1268 }),
1269 &languages,
1270 cx,
1271 ))],
1272 status: ToolCallStatus::Failed,
1273 locations: Vec::new(),
1274 resolved_locations: Vec::new(),
1275 raw_input: None,
1276 raw_output: None,
1277 };
1278 self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1279 return Ok(());
1280 }
1281 };
1282 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1283 unreachable!()
1284 };
1285
1286 match update {
1287 ToolCallUpdate::UpdateFields(update) => {
1288 let location_updated = update.fields.locations.is_some();
1289 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1290 if location_updated {
1291 self.resolve_locations(update.id, cx);
1292 }
1293 }
1294 ToolCallUpdate::UpdateDiff(update) => {
1295 call.content.clear();
1296 call.content.push(ToolCallContent::Diff(update.diff));
1297 }
1298 ToolCallUpdate::UpdateTerminal(update) => {
1299 call.content.clear();
1300 call.content
1301 .push(ToolCallContent::Terminal(update.terminal));
1302 }
1303 }
1304
1305 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1306
1307 Ok(())
1308 }
1309
1310 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1311 pub fn upsert_tool_call(
1312 &mut self,
1313 tool_call: acp::ToolCall,
1314 cx: &mut Context<Self>,
1315 ) -> Result<(), acp::Error> {
1316 let status = tool_call.status.into();
1317 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1318 }
1319
1320 /// Fails if id does not match an existing entry.
1321 pub fn upsert_tool_call_inner(
1322 &mut self,
1323 update: acp::ToolCallUpdate,
1324 status: ToolCallStatus,
1325 cx: &mut Context<Self>,
1326 ) -> Result<(), acp::Error> {
1327 let language_registry = self.project.read(cx).languages().clone();
1328 let id = update.id.clone();
1329
1330 if let Some(ix) = self.index_for_tool_call(&id) {
1331 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1332 unreachable!()
1333 };
1334
1335 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1336 call.status = status;
1337
1338 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1339 } else {
1340 let call = ToolCall::from_acp(
1341 update.try_into()?,
1342 status,
1343 language_registry,
1344 &self.terminals,
1345 cx,
1346 )?;
1347 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1348 };
1349
1350 self.resolve_locations(id, cx);
1351 Ok(())
1352 }
1353
1354 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1355 self.entries
1356 .iter()
1357 .enumerate()
1358 .rev()
1359 .find_map(|(index, entry)| {
1360 if let AgentThreadEntry::ToolCall(tool_call) = entry
1361 && &tool_call.id == id
1362 {
1363 Some(index)
1364 } else {
1365 None
1366 }
1367 })
1368 }
1369
1370 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1371 // The tool call we are looking for is typically the last one, or very close to the end.
1372 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1373 self.entries
1374 .iter_mut()
1375 .enumerate()
1376 .rev()
1377 .find_map(|(index, tool_call)| {
1378 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1379 && &tool_call.id == id
1380 {
1381 Some((index, tool_call))
1382 } else {
1383 None
1384 }
1385 })
1386 }
1387
1388 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1389 self.entries
1390 .iter()
1391 .enumerate()
1392 .rev()
1393 .find_map(|(index, tool_call)| {
1394 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1395 && &tool_call.id == id
1396 {
1397 Some((index, tool_call))
1398 } else {
1399 None
1400 }
1401 })
1402 }
1403
1404 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1405 let project = self.project.clone();
1406 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1407 return;
1408 };
1409 let task = tool_call.resolve_locations(project, cx);
1410 cx.spawn(async move |this, cx| {
1411 let resolved_locations = task.await;
1412
1413 this.update(cx, |this, cx| {
1414 let project = this.project.clone();
1415
1416 for location in resolved_locations.iter().flatten() {
1417 this.shared_buffers
1418 .insert(location.buffer.clone(), location.buffer.read(cx).snapshot());
1419 }
1420 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1421 return;
1422 };
1423
1424 if let Some(Some(location)) = resolved_locations.last() {
1425 project.update(cx, |project, cx| {
1426 let should_ignore = if let Some(agent_location) = project
1427 .agent_location()
1428 .filter(|agent_location| agent_location.buffer == location.buffer)
1429 {
1430 let snapshot = location.buffer.read(cx).snapshot();
1431 let old_position = agent_location.position.to_point(&snapshot);
1432 let new_position = location.position.to_point(&snapshot);
1433
1434 // ignore this so that when we get updates from the edit tool
1435 // the position doesn't reset to the startof line
1436 old_position.row == new_position.row
1437 && old_position.column > new_position.column
1438 } else {
1439 false
1440 };
1441 if !should_ignore {
1442 project.set_agent_location(Some(location.into()), cx);
1443 }
1444 });
1445 }
1446
1447 let resolved_locations = resolved_locations
1448 .iter()
1449 .map(|l| l.as_ref().map(|l| AgentLocation::from(l)))
1450 .collect::<Vec<_>>();
1451
1452 if tool_call.resolved_locations != resolved_locations {
1453 tool_call.resolved_locations = resolved_locations;
1454 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1455 }
1456 })
1457 })
1458 .detach();
1459 }
1460
1461 pub fn request_tool_call_authorization(
1462 &mut self,
1463 tool_call: acp::ToolCallUpdate,
1464 options: Vec<acp::PermissionOption>,
1465 respect_always_allow_setting: bool,
1466 cx: &mut Context<Self>,
1467 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1468 let (tx, rx) = oneshot::channel();
1469
1470 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1471 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1472 // some tools would (incorrectly) continue to auto-accept.
1473 if let Some(allow_once_option) = options.iter().find_map(|option| {
1474 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1475 Some(option.id.clone())
1476 } else {
1477 None
1478 }
1479 }) {
1480 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1481 return Ok(async {
1482 acp::RequestPermissionOutcome::Selected {
1483 option_id: allow_once_option,
1484 }
1485 }
1486 .boxed());
1487 }
1488 }
1489
1490 let status = ToolCallStatus::WaitingForConfirmation {
1491 options,
1492 respond_tx: tx,
1493 };
1494
1495 self.upsert_tool_call_inner(tool_call, status, cx)?;
1496 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1497
1498 let fut = async {
1499 match rx.await {
1500 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1501 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1502 }
1503 }
1504 .boxed();
1505
1506 Ok(fut)
1507 }
1508
1509 pub fn authorize_tool_call(
1510 &mut self,
1511 id: acp::ToolCallId,
1512 option_id: acp::PermissionOptionId,
1513 option_kind: acp::PermissionOptionKind,
1514 cx: &mut Context<Self>,
1515 ) {
1516 let Some((ix, call)) = self.tool_call_mut(&id) else {
1517 return;
1518 };
1519
1520 let new_status = match option_kind {
1521 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1522 ToolCallStatus::Rejected
1523 }
1524 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1525 ToolCallStatus::InProgress
1526 }
1527 };
1528
1529 let curr_status = mem::replace(&mut call.status, new_status);
1530
1531 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1532 respond_tx.send(option_id).log_err();
1533 } else if cfg!(debug_assertions) {
1534 panic!("tried to authorize an already authorized tool call");
1535 }
1536
1537 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1538 }
1539
1540 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1541 let mut first_tool_call = None;
1542
1543 for entry in self.entries.iter().rev() {
1544 match &entry {
1545 AgentThreadEntry::ToolCall(call) => {
1546 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1547 first_tool_call = Some(call);
1548 } else {
1549 continue;
1550 }
1551 }
1552 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1553 // Reached the beginning of the turn.
1554 // If we had pending permission requests in the previous turn, they have been cancelled.
1555 break;
1556 }
1557 }
1558 }
1559
1560 first_tool_call
1561 }
1562
1563 pub fn plan(&self) -> &Plan {
1564 &self.plan
1565 }
1566
1567 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1568 let new_entries_len = request.entries.len();
1569 let mut new_entries = request.entries.into_iter();
1570
1571 // Reuse existing markdown to prevent flickering
1572 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1573 let PlanEntry {
1574 content,
1575 priority,
1576 status,
1577 } = old;
1578 content.update(cx, |old, cx| {
1579 old.replace(new.content, cx);
1580 });
1581 *priority = new.priority;
1582 *status = new.status;
1583 }
1584 for new in new_entries {
1585 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1586 }
1587 self.plan.entries.truncate(new_entries_len);
1588
1589 cx.notify();
1590 }
1591
1592 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1593 self.plan
1594 .entries
1595 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1596 cx.notify();
1597 }
1598
1599 #[cfg(any(test, feature = "test-support"))]
1600 pub fn send_raw(
1601 &mut self,
1602 message: &str,
1603 cx: &mut Context<Self>,
1604 ) -> BoxFuture<'static, Result<()>> {
1605 self.send(
1606 vec![acp::ContentBlock::Text(acp::TextContent {
1607 text: message.to_string(),
1608 annotations: None,
1609 meta: None,
1610 })],
1611 cx,
1612 )
1613 }
1614
1615 pub fn send(
1616 &mut self,
1617 message: Vec<acp::ContentBlock>,
1618 cx: &mut Context<Self>,
1619 ) -> BoxFuture<'static, Result<()>> {
1620 let block = ContentBlock::new_combined(
1621 message.clone(),
1622 self.project.read(cx).languages().clone(),
1623 cx,
1624 );
1625 let request = acp::PromptRequest {
1626 prompt: message.clone(),
1627 session_id: self.session_id.clone(),
1628 meta: None,
1629 };
1630 let git_store = self.project.read(cx).git_store().clone();
1631
1632 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1633 Some(UserMessageId::new())
1634 } else {
1635 None
1636 };
1637
1638 self.run_turn(cx, async move |this, cx| {
1639 this.update(cx, |this, cx| {
1640 this.push_entry(
1641 AgentThreadEntry::UserMessage(UserMessage {
1642 id: message_id.clone(),
1643 content: block,
1644 chunks: message,
1645 checkpoint: None,
1646 }),
1647 cx,
1648 );
1649 })
1650 .ok();
1651
1652 let old_checkpoint = git_store
1653 .update(cx, |git, cx| git.checkpoint(cx))?
1654 .await
1655 .context("failed to get old checkpoint")
1656 .log_err();
1657 this.update(cx, |this, cx| {
1658 if let Some((_ix, message)) = this.last_user_message() {
1659 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1660 git_checkpoint,
1661 show: false,
1662 });
1663 }
1664 this.connection.prompt(message_id, request, cx)
1665 })?
1666 .await
1667 })
1668 }
1669
1670 pub fn can_resume(&self, cx: &App) -> bool {
1671 self.connection.resume(&self.session_id, cx).is_some()
1672 }
1673
1674 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1675 self.run_turn(cx, async move |this, cx| {
1676 this.update(cx, |this, cx| {
1677 this.connection
1678 .resume(&this.session_id, cx)
1679 .map(|resume| resume.run(cx))
1680 })?
1681 .context("resuming a session is not supported")?
1682 .await
1683 })
1684 }
1685
1686 fn run_turn(
1687 &mut self,
1688 cx: &mut Context<Self>,
1689 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1690 ) -> BoxFuture<'static, Result<()>> {
1691 self.clear_completed_plan_entries(cx);
1692
1693 let (tx, rx) = oneshot::channel();
1694 let cancel_task = self.cancel(cx);
1695
1696 self.send_task = Some(cx.spawn(async move |this, cx| {
1697 cancel_task.await;
1698 tx.send(f(this, cx).await).ok();
1699 }));
1700
1701 cx.spawn(async move |this, cx| {
1702 let response = rx.await;
1703
1704 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1705 .await?;
1706
1707 this.update(cx, |this, cx| {
1708 this.project
1709 .update(cx, |project, cx| project.set_agent_location(None, cx));
1710 match response {
1711 Ok(Err(e)) => {
1712 this.send_task.take();
1713 cx.emit(AcpThreadEvent::Error);
1714 Err(e)
1715 }
1716 result => {
1717 let canceled = matches!(
1718 result,
1719 Ok(Ok(acp::PromptResponse {
1720 stop_reason: acp::StopReason::Cancelled,
1721 meta: None,
1722 }))
1723 );
1724
1725 // We only take the task if the current prompt wasn't canceled.
1726 //
1727 // This prompt may have been canceled because another one was sent
1728 // while it was still generating. In these cases, dropping `send_task`
1729 // would cause the next generation to be canceled.
1730 if !canceled {
1731 this.send_task.take();
1732 }
1733
1734 // Handle refusal - distinguish between user prompt and tool call refusals
1735 if let Ok(Ok(acp::PromptResponse {
1736 stop_reason: acp::StopReason::Refusal,
1737 meta: _,
1738 })) = result
1739 {
1740 if let Some((user_msg_ix, _)) = this.last_user_message() {
1741 // Check if there's a completed tool call with results after the last user message
1742 // This indicates the refusal is in response to tool output, not the user's prompt
1743 let has_completed_tool_call_after_user_msg =
1744 this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1745 if let AgentThreadEntry::ToolCall(tool_call) = entry {
1746 // Check if the tool call has completed and has output
1747 matches!(tool_call.status, ToolCallStatus::Completed)
1748 && tool_call.raw_output.is_some()
1749 } else {
1750 false
1751 }
1752 });
1753
1754 if has_completed_tool_call_after_user_msg {
1755 // Refusal is due to tool output - don't truncate, just notify
1756 // The model refused based on what the tool returned
1757 cx.emit(AcpThreadEvent::Refusal);
1758 } else {
1759 // User prompt was refused - truncate back to before the user message
1760 let range = user_msg_ix..this.entries.len();
1761 if range.start < range.end {
1762 this.entries.truncate(user_msg_ix);
1763 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1764 }
1765 cx.emit(AcpThreadEvent::Refusal);
1766 }
1767 } else {
1768 // No user message found, treat as general refusal
1769 cx.emit(AcpThreadEvent::Refusal);
1770 }
1771 }
1772
1773 cx.emit(AcpThreadEvent::Stopped);
1774 Ok(())
1775 }
1776 }
1777 })?
1778 })
1779 .boxed()
1780 }
1781
1782 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1783 let Some(send_task) = self.send_task.take() else {
1784 return Task::ready(());
1785 };
1786
1787 for entry in self.entries.iter_mut() {
1788 if let AgentThreadEntry::ToolCall(call) = entry {
1789 let cancel = matches!(
1790 call.status,
1791 ToolCallStatus::Pending
1792 | ToolCallStatus::WaitingForConfirmation { .. }
1793 | ToolCallStatus::InProgress
1794 );
1795
1796 if cancel {
1797 call.status = ToolCallStatus::Canceled;
1798 }
1799 }
1800 }
1801
1802 self.connection.cancel(&self.session_id, cx);
1803
1804 // Wait for the send task to complete
1805 cx.foreground_executor().spawn(send_task)
1806 }
1807
1808 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1809 pub fn restore_checkpoint(
1810 &mut self,
1811 id: UserMessageId,
1812 cx: &mut Context<Self>,
1813 ) -> Task<Result<()>> {
1814 let Some((_, message)) = self.user_message_mut(&id) else {
1815 return Task::ready(Err(anyhow!("message not found")));
1816 };
1817
1818 let checkpoint = message
1819 .checkpoint
1820 .as_ref()
1821 .map(|c| c.git_checkpoint.clone());
1822 let rewind = self.rewind(id.clone(), cx);
1823 let git_store = self.project.read(cx).git_store().clone();
1824
1825 cx.spawn(async move |_, cx| {
1826 rewind.await?;
1827 if let Some(checkpoint) = checkpoint {
1828 git_store
1829 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1830 .await?;
1831 }
1832
1833 Ok(())
1834 })
1835 }
1836
1837 /// Rewinds this thread to before the entry at `index`, removing it and all
1838 /// subsequent entries while rejecting any action_log changes made from that point.
1839 /// Unlike `restore_checkpoint`, this method does not restore from git.
1840 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1841 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1842 return Task::ready(Err(anyhow!("not supported")));
1843 };
1844
1845 cx.spawn(async move |this, cx| {
1846 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1847 this.update(cx, |this, cx| {
1848 if let Some((ix, _)) = this.user_message_mut(&id) {
1849 let range = ix..this.entries.len();
1850 this.entries.truncate(ix);
1851 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1852 }
1853 this.action_log()
1854 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1855 })?
1856 .await;
1857 Ok(())
1858 })
1859 }
1860
1861 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1862 let git_store = self.project.read(cx).git_store().clone();
1863
1864 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1865 if let Some(checkpoint) = message.checkpoint.as_ref() {
1866 checkpoint.git_checkpoint.clone()
1867 } else {
1868 return Task::ready(Ok(()));
1869 }
1870 } else {
1871 return Task::ready(Ok(()));
1872 };
1873
1874 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1875 cx.spawn(async move |this, cx| {
1876 let new_checkpoint = new_checkpoint
1877 .await
1878 .context("failed to get new checkpoint")
1879 .log_err();
1880 if let Some(new_checkpoint) = new_checkpoint {
1881 let equal = git_store
1882 .update(cx, |git, cx| {
1883 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1884 })?
1885 .await
1886 .unwrap_or(true);
1887 this.update(cx, |this, cx| {
1888 let (ix, message) = this.last_user_message().context("no user message")?;
1889 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1890 checkpoint.show = !equal;
1891 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1892 anyhow::Ok(())
1893 })??;
1894 }
1895
1896 Ok(())
1897 })
1898 }
1899
1900 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1901 self.entries
1902 .iter_mut()
1903 .enumerate()
1904 .rev()
1905 .find_map(|(ix, entry)| {
1906 if let AgentThreadEntry::UserMessage(message) = entry {
1907 Some((ix, message))
1908 } else {
1909 None
1910 }
1911 })
1912 }
1913
1914 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1915 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1916 if let AgentThreadEntry::UserMessage(message) = entry {
1917 if message.id.as_ref() == Some(id) {
1918 Some((ix, message))
1919 } else {
1920 None
1921 }
1922 } else {
1923 None
1924 }
1925 })
1926 }
1927
1928 pub fn read_text_file(
1929 &self,
1930 path: PathBuf,
1931 line: Option<u32>,
1932 limit: Option<u32>,
1933 reuse_shared_snapshot: bool,
1934 cx: &mut Context<Self>,
1935 ) -> Task<Result<String, acp::Error>> {
1936 // Args are 1-based, move to 0-based
1937 let line = line.unwrap_or_default().saturating_sub(1);
1938 let limit = limit.unwrap_or(u32::MAX);
1939 let project = self.project.clone();
1940 let action_log = self.action_log.clone();
1941 cx.spawn(async move |this, cx| {
1942 let load = project
1943 .update(cx, |project, cx| {
1944 let path = project
1945 .project_path_for_absolute_path(&path, cx)
1946 .ok_or_else(|| {
1947 acp::Error::resource_not_found(Some(path.display().to_string()))
1948 })?;
1949 Ok(project.open_buffer(path, cx))
1950 })
1951 .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
1952 .flatten()?;
1953
1954 let buffer = load.await?;
1955
1956 let snapshot = if reuse_shared_snapshot {
1957 this.read_with(cx, |this, _| {
1958 this.shared_buffers.get(&buffer.clone()).cloned()
1959 })
1960 .log_err()
1961 .flatten()
1962 } else {
1963 None
1964 };
1965
1966 let snapshot = if let Some(snapshot) = snapshot {
1967 snapshot
1968 } else {
1969 action_log.update(cx, |action_log, cx| {
1970 action_log.buffer_read(buffer.clone(), cx);
1971 })?;
1972
1973 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1974 this.update(cx, |this, _| {
1975 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1976 })?;
1977 snapshot
1978 };
1979
1980 let max_point = snapshot.max_point();
1981 let start_position = Point::new(line, 0);
1982
1983 if start_position > max_point {
1984 return Err(acp::Error::invalid_params().with_data(format!(
1985 "Attempting to read beyond the end of the file, line {}:{}",
1986 max_point.row + 1,
1987 max_point.column
1988 )));
1989 }
1990
1991 let start = snapshot.anchor_before(start_position);
1992 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1993
1994 project.update(cx, |project, cx| {
1995 project.set_agent_location(
1996 Some(AgentLocation {
1997 buffer: buffer.downgrade(),
1998 position: start,
1999 }),
2000 cx,
2001 );
2002 })?;
2003
2004 Ok(snapshot.text_for_range(start..end).collect::<String>())
2005 })
2006 }
2007
2008 pub fn write_text_file(
2009 &self,
2010 path: PathBuf,
2011 content: String,
2012 cx: &mut Context<Self>,
2013 ) -> Task<Result<()>> {
2014 let project = self.project.clone();
2015 let action_log = self.action_log.clone();
2016 cx.spawn(async move |this, cx| {
2017 let load = project.update(cx, |project, cx| {
2018 let path = project
2019 .project_path_for_absolute_path(&path, cx)
2020 .context("invalid path")?;
2021 anyhow::Ok(project.open_buffer(path, cx))
2022 });
2023 let buffer = load??.await?;
2024 let snapshot = this.update(cx, |this, cx| {
2025 this.shared_buffers
2026 .get(&buffer)
2027 .cloned()
2028 .unwrap_or_else(|| buffer.read(cx).snapshot())
2029 })?;
2030 let edits = cx
2031 .background_executor()
2032 .spawn(async move {
2033 let old_text = snapshot.text();
2034 text_diff(old_text.as_str(), &content)
2035 .into_iter()
2036 .map(|(range, replacement)| {
2037 (
2038 snapshot.anchor_after(range.start)
2039 ..snapshot.anchor_before(range.end),
2040 replacement,
2041 )
2042 })
2043 .collect::<Vec<_>>()
2044 })
2045 .await;
2046
2047 project.update(cx, |project, cx| {
2048 project.set_agent_location(
2049 Some(AgentLocation {
2050 buffer: buffer.downgrade(),
2051 position: edits
2052 .last()
2053 .map(|(range, _)| range.end)
2054 .unwrap_or(Anchor::MIN),
2055 }),
2056 cx,
2057 );
2058 })?;
2059
2060 let format_on_save = cx.update(|cx| {
2061 action_log.update(cx, |action_log, cx| {
2062 action_log.buffer_read(buffer.clone(), cx);
2063 });
2064
2065 let format_on_save = buffer.update(cx, |buffer, cx| {
2066 buffer.edit(edits, None, cx);
2067
2068 let settings = language::language_settings::language_settings(
2069 buffer.language().map(|l| l.name()),
2070 buffer.file(),
2071 cx,
2072 );
2073
2074 settings.format_on_save != FormatOnSave::Off
2075 });
2076 action_log.update(cx, |action_log, cx| {
2077 action_log.buffer_edited(buffer.clone(), cx);
2078 });
2079 format_on_save
2080 })?;
2081
2082 if format_on_save {
2083 let format_task = project.update(cx, |project, cx| {
2084 project.format(
2085 HashSet::from_iter([buffer.clone()]),
2086 LspFormatTarget::Buffers,
2087 false,
2088 FormatTrigger::Save,
2089 cx,
2090 )
2091 })?;
2092 format_task.await.log_err();
2093
2094 action_log.update(cx, |action_log, cx| {
2095 action_log.buffer_edited(buffer.clone(), cx);
2096 })?;
2097 }
2098
2099 project
2100 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
2101 .await
2102 })
2103 }
2104
2105 pub fn create_terminal(
2106 &self,
2107 command: String,
2108 args: Vec<String>,
2109 extra_env: Vec<acp::EnvVariable>,
2110 cwd: Option<PathBuf>,
2111 output_byte_limit: Option<u64>,
2112 cx: &mut Context<Self>,
2113 ) -> Task<Result<Entity<Terminal>>> {
2114 let env = match &cwd {
2115 Some(dir) => self.project.update(cx, |project, cx| {
2116 let worktree = project.find_worktree(dir.as_path(), cx);
2117 let shell = TerminalSettings::get(
2118 worktree.as_ref().map(|(worktree, path)| SettingsLocation {
2119 worktree_id: worktree.read(cx).id(),
2120 path: &path,
2121 }),
2122 cx,
2123 )
2124 .shell
2125 .clone();
2126 project.directory_environment(&shell, dir.as_path().into(), cx)
2127 }),
2128 None => Task::ready(None).shared(),
2129 };
2130 let env = cx.spawn(async move |_, _| {
2131 let mut env = env.await.unwrap_or_default();
2132 // Disables paging for `git` and hopefully other commands
2133 env.insert("PAGER".into(), "".into());
2134 for var in extra_env {
2135 env.insert(var.name, var.value);
2136 }
2137 env
2138 });
2139
2140 let project = self.project.clone();
2141 let language_registry = project.read(cx).languages().clone();
2142 let is_windows = project.read(cx).path_style(cx).is_windows();
2143
2144 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
2145 let terminal_task = cx.spawn({
2146 let terminal_id = terminal_id.clone();
2147 async move |_this, cx| {
2148 let env = env.await;
2149 let shell = project
2150 .update(cx, |project, cx| {
2151 project
2152 .remote_client()
2153 .and_then(|r| r.read(cx).default_system_shell())
2154 })?
2155 .unwrap_or_else(|| get_default_system_shell_preferring_bash());
2156 let (task_command, task_args) =
2157 ShellBuilder::new(&Shell::Program(shell), is_windows)
2158 .redirect_stdin_to_dev_null()
2159 .build(Some(command.clone()), &args);
2160 let terminal = project
2161 .update(cx, |project, cx| {
2162 project.create_terminal_task(
2163 task::SpawnInTerminal {
2164 command: Some(task_command),
2165 args: task_args,
2166 cwd: cwd.clone(),
2167 env,
2168 ..Default::default()
2169 },
2170 cx,
2171 )
2172 })?
2173 .await?;
2174
2175 cx.new(|cx| {
2176 Terminal::new(
2177 terminal_id,
2178 &format!("{} {}", command, args.join(" ")),
2179 cwd,
2180 output_byte_limit.map(|l| l as usize),
2181 terminal,
2182 language_registry,
2183 cx,
2184 )
2185 })
2186 }
2187 });
2188
2189 cx.spawn(async move |this, cx| {
2190 let terminal = terminal_task.await?;
2191 this.update(cx, |this, _cx| {
2192 this.terminals.insert(terminal_id, terminal.clone());
2193 terminal
2194 })
2195 })
2196 }
2197
2198 pub fn kill_terminal(
2199 &mut self,
2200 terminal_id: acp::TerminalId,
2201 cx: &mut Context<Self>,
2202 ) -> Result<()> {
2203 self.terminals
2204 .get(&terminal_id)
2205 .context("Terminal not found")?
2206 .update(cx, |terminal, cx| {
2207 terminal.kill(cx);
2208 });
2209
2210 Ok(())
2211 }
2212
2213 pub fn release_terminal(
2214 &mut self,
2215 terminal_id: acp::TerminalId,
2216 cx: &mut Context<Self>,
2217 ) -> Result<()> {
2218 self.terminals
2219 .remove(&terminal_id)
2220 .context("Terminal not found")?
2221 .update(cx, |terminal, cx| {
2222 terminal.kill(cx);
2223 });
2224
2225 Ok(())
2226 }
2227
2228 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2229 self.terminals
2230 .get(&terminal_id)
2231 .context("Terminal not found")
2232 .cloned()
2233 }
2234
2235 pub fn to_markdown(&self, cx: &App) -> String {
2236 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2237 }
2238
2239 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2240 cx.emit(AcpThreadEvent::LoadError(error));
2241 }
2242
2243 pub fn register_terminal_created(
2244 &mut self,
2245 terminal_id: acp::TerminalId,
2246 command_label: String,
2247 working_dir: Option<PathBuf>,
2248 output_byte_limit: Option<u64>,
2249 terminal: Entity<::terminal::Terminal>,
2250 cx: &mut Context<Self>,
2251 ) -> Entity<Terminal> {
2252 let language_registry = self.project.read(cx).languages().clone();
2253
2254 let entity = cx.new(|cx| {
2255 Terminal::new(
2256 terminal_id.clone(),
2257 &command_label,
2258 working_dir.clone(),
2259 output_byte_limit.map(|l| l as usize),
2260 terminal,
2261 language_registry,
2262 cx,
2263 )
2264 });
2265 self.terminals.insert(terminal_id.clone(), entity.clone());
2266 entity
2267 }
2268}
2269
2270fn markdown_for_raw_output(
2271 raw_output: &serde_json::Value,
2272 language_registry: &Arc<LanguageRegistry>,
2273 cx: &mut App,
2274) -> Option<Entity<Markdown>> {
2275 match raw_output {
2276 serde_json::Value::Null => None,
2277 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2278 Markdown::new(
2279 value.to_string().into(),
2280 Some(language_registry.clone()),
2281 None,
2282 cx,
2283 )
2284 })),
2285 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2286 Markdown::new(
2287 value.to_string().into(),
2288 Some(language_registry.clone()),
2289 None,
2290 cx,
2291 )
2292 })),
2293 serde_json::Value::String(value) => Some(cx.new(|cx| {
2294 Markdown::new(
2295 value.clone().into(),
2296 Some(language_registry.clone()),
2297 None,
2298 cx,
2299 )
2300 })),
2301 value => Some(cx.new(|cx| {
2302 Markdown::new(
2303 format!("```json\n{}\n```", value).into(),
2304 Some(language_registry.clone()),
2305 None,
2306 cx,
2307 )
2308 })),
2309 }
2310}
2311
2312#[cfg(test)]
2313mod tests {
2314 use super::*;
2315 use anyhow::anyhow;
2316 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2317 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2318 use indoc::indoc;
2319 use project::{FakeFs, Fs};
2320 use rand::{distr, prelude::*};
2321 use serde_json::json;
2322 use settings::SettingsStore;
2323 use smol::stream::StreamExt as _;
2324 use std::{
2325 any::Any,
2326 cell::RefCell,
2327 path::Path,
2328 rc::Rc,
2329 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2330 time::Duration,
2331 };
2332 use util::path;
2333
2334 fn init_test(cx: &mut TestAppContext) {
2335 env_logger::try_init().ok();
2336 cx.update(|cx| {
2337 let settings_store = SettingsStore::test(cx);
2338 cx.set_global(settings_store);
2339 Project::init_settings(cx);
2340 language::init(cx);
2341 });
2342 }
2343
2344 #[gpui::test]
2345 async fn test_terminal_output_buffered_before_created_renders(cx: &mut gpui::TestAppContext) {
2346 init_test(cx);
2347
2348 let fs = FakeFs::new(cx.executor());
2349 let project = Project::test(fs, [], cx).await;
2350 let connection = Rc::new(FakeAgentConnection::new());
2351 let thread = cx
2352 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2353 .await
2354 .unwrap();
2355
2356 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2357
2358 // Send Output BEFORE Created - should be buffered by acp_thread
2359 thread.update(cx, |thread, cx| {
2360 thread.on_terminal_provider_event(
2361 TerminalProviderEvent::Output {
2362 terminal_id: terminal_id.clone(),
2363 data: b"hello buffered".to_vec(),
2364 },
2365 cx,
2366 );
2367 });
2368
2369 // Create a display-only terminal and then send Created
2370 let lower = cx.new(|cx| {
2371 let builder = ::terminal::TerminalBuilder::new_display_only(
2372 ::terminal::terminal_settings::CursorShape::default(),
2373 ::terminal::terminal_settings::AlternateScroll::On,
2374 None,
2375 0,
2376 )
2377 .unwrap();
2378 builder.subscribe(cx)
2379 });
2380
2381 thread.update(cx, |thread, cx| {
2382 thread.on_terminal_provider_event(
2383 TerminalProviderEvent::Created {
2384 terminal_id: terminal_id.clone(),
2385 label: "Buffered Test".to_string(),
2386 cwd: None,
2387 output_byte_limit: None,
2388 terminal: lower.clone(),
2389 },
2390 cx,
2391 );
2392 });
2393
2394 // After Created, buffered Output should have been flushed into the renderer
2395 let content = thread.read_with(cx, |thread, cx| {
2396 let term = thread.terminal(terminal_id.clone()).unwrap();
2397 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2398 });
2399
2400 assert!(
2401 content.contains("hello buffered"),
2402 "expected buffered output to render, got: {content}"
2403 );
2404 }
2405
2406 #[gpui::test]
2407 async fn test_terminal_output_and_exit_buffered_before_created(cx: &mut gpui::TestAppContext) {
2408 init_test(cx);
2409
2410 let fs = FakeFs::new(cx.executor());
2411 let project = Project::test(fs, [], cx).await;
2412 let connection = Rc::new(FakeAgentConnection::new());
2413 let thread = cx
2414 .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx))
2415 .await
2416 .unwrap();
2417
2418 let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
2419
2420 // Send Output BEFORE Created
2421 thread.update(cx, |thread, cx| {
2422 thread.on_terminal_provider_event(
2423 TerminalProviderEvent::Output {
2424 terminal_id: terminal_id.clone(),
2425 data: b"pre-exit data".to_vec(),
2426 },
2427 cx,
2428 );
2429 });
2430
2431 // Send Exit BEFORE Created
2432 thread.update(cx, |thread, cx| {
2433 thread.on_terminal_provider_event(
2434 TerminalProviderEvent::Exit {
2435 terminal_id: terminal_id.clone(),
2436 status: acp::TerminalExitStatus {
2437 exit_code: Some(0),
2438 signal: None,
2439 meta: None,
2440 },
2441 },
2442 cx,
2443 );
2444 });
2445
2446 // Now create a display-only lower-level terminal and send Created
2447 let lower = cx.new(|cx| {
2448 let builder = ::terminal::TerminalBuilder::new_display_only(
2449 ::terminal::terminal_settings::CursorShape::default(),
2450 ::terminal::terminal_settings::AlternateScroll::On,
2451 None,
2452 0,
2453 )
2454 .unwrap();
2455 builder.subscribe(cx)
2456 });
2457
2458 thread.update(cx, |thread, cx| {
2459 thread.on_terminal_provider_event(
2460 TerminalProviderEvent::Created {
2461 terminal_id: terminal_id.clone(),
2462 label: "Buffered Exit Test".to_string(),
2463 cwd: None,
2464 output_byte_limit: None,
2465 terminal: lower.clone(),
2466 },
2467 cx,
2468 );
2469 });
2470
2471 // Output should be present after Created (flushed from buffer)
2472 let content = thread.read_with(cx, |thread, cx| {
2473 let term = thread.terminal(terminal_id.clone()).unwrap();
2474 term.read_with(cx, |t, cx| t.inner().read(cx).get_content())
2475 });
2476
2477 assert!(
2478 content.contains("pre-exit data"),
2479 "expected pre-exit data to render, got: {content}"
2480 );
2481 }
2482
2483 #[gpui::test]
2484 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2485 init_test(cx);
2486
2487 let fs = FakeFs::new(cx.executor());
2488 let project = Project::test(fs, [], cx).await;
2489 let connection = Rc::new(FakeAgentConnection::new());
2490 let thread = cx
2491 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2492 .await
2493 .unwrap();
2494
2495 // Test creating a new user message
2496 thread.update(cx, |thread, cx| {
2497 thread.push_user_content_block(
2498 None,
2499 acp::ContentBlock::Text(acp::TextContent {
2500 annotations: None,
2501 text: "Hello, ".to_string(),
2502 meta: None,
2503 }),
2504 cx,
2505 );
2506 });
2507
2508 thread.update(cx, |thread, cx| {
2509 assert_eq!(thread.entries.len(), 1);
2510 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2511 assert_eq!(user_msg.id, None);
2512 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2513 } else {
2514 panic!("Expected UserMessage");
2515 }
2516 });
2517
2518 // Test appending to existing user message
2519 let message_1_id = UserMessageId::new();
2520 thread.update(cx, |thread, cx| {
2521 thread.push_user_content_block(
2522 Some(message_1_id.clone()),
2523 acp::ContentBlock::Text(acp::TextContent {
2524 annotations: None,
2525 text: "world!".to_string(),
2526 meta: None,
2527 }),
2528 cx,
2529 );
2530 });
2531
2532 thread.update(cx, |thread, cx| {
2533 assert_eq!(thread.entries.len(), 1);
2534 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2535 assert_eq!(user_msg.id, Some(message_1_id));
2536 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2537 } else {
2538 panic!("Expected UserMessage");
2539 }
2540 });
2541
2542 // Test creating new user message after assistant message
2543 thread.update(cx, |thread, cx| {
2544 thread.push_assistant_content_block(
2545 acp::ContentBlock::Text(acp::TextContent {
2546 annotations: None,
2547 text: "Assistant response".to_string(),
2548 meta: None,
2549 }),
2550 false,
2551 cx,
2552 );
2553 });
2554
2555 let message_2_id = UserMessageId::new();
2556 thread.update(cx, |thread, cx| {
2557 thread.push_user_content_block(
2558 Some(message_2_id.clone()),
2559 acp::ContentBlock::Text(acp::TextContent {
2560 annotations: None,
2561 text: "New user message".to_string(),
2562 meta: None,
2563 }),
2564 cx,
2565 );
2566 });
2567
2568 thread.update(cx, |thread, cx| {
2569 assert_eq!(thread.entries.len(), 3);
2570 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2571 assert_eq!(user_msg.id, Some(message_2_id));
2572 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2573 } else {
2574 panic!("Expected UserMessage at index 2");
2575 }
2576 });
2577 }
2578
2579 #[gpui::test]
2580 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2581 init_test(cx);
2582
2583 let fs = FakeFs::new(cx.executor());
2584 let project = Project::test(fs, [], cx).await;
2585 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2586 |_, thread, mut cx| {
2587 async move {
2588 thread.update(&mut cx, |thread, cx| {
2589 thread
2590 .handle_session_update(
2591 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk {
2592 content: "Thinking ".into(),
2593 meta: None,
2594 }),
2595 cx,
2596 )
2597 .unwrap();
2598 thread
2599 .handle_session_update(
2600 acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk {
2601 content: "hard!".into(),
2602 meta: None,
2603 }),
2604 cx,
2605 )
2606 .unwrap();
2607 })?;
2608 Ok(acp::PromptResponse {
2609 stop_reason: acp::StopReason::EndTurn,
2610 meta: None,
2611 })
2612 }
2613 .boxed_local()
2614 },
2615 ));
2616
2617 let thread = cx
2618 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2619 .await
2620 .unwrap();
2621
2622 thread
2623 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2624 .await
2625 .unwrap();
2626
2627 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2628 assert_eq!(
2629 output,
2630 indoc! {r#"
2631 ## User
2632
2633 Hello from Zed!
2634
2635 ## Assistant
2636
2637 <thinking>
2638 Thinking hard!
2639 </thinking>
2640
2641 "#}
2642 );
2643 }
2644
2645 #[gpui::test]
2646 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2647 init_test(cx);
2648
2649 let fs = FakeFs::new(cx.executor());
2650 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2651 .await;
2652 let project = Project::test(fs.clone(), [], cx).await;
2653 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2654 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2655 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2656 move |_, thread, mut cx| {
2657 let read_file_tx = read_file_tx.clone();
2658 async move {
2659 let content = thread
2660 .update(&mut cx, |thread, cx| {
2661 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2662 })
2663 .unwrap()
2664 .await
2665 .unwrap();
2666 assert_eq!(content, "one\ntwo\nthree\n");
2667 read_file_tx.take().unwrap().send(()).unwrap();
2668 thread
2669 .update(&mut cx, |thread, cx| {
2670 thread.write_text_file(
2671 path!("/tmp/foo").into(),
2672 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2673 cx,
2674 )
2675 })
2676 .unwrap()
2677 .await
2678 .unwrap();
2679 Ok(acp::PromptResponse {
2680 stop_reason: acp::StopReason::EndTurn,
2681 meta: None,
2682 })
2683 }
2684 .boxed_local()
2685 },
2686 ));
2687
2688 let (worktree, pathbuf) = project
2689 .update(cx, |project, cx| {
2690 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2691 })
2692 .await
2693 .unwrap();
2694 let buffer = project
2695 .update(cx, |project, cx| {
2696 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2697 })
2698 .await
2699 .unwrap();
2700
2701 let thread = cx
2702 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2703 .await
2704 .unwrap();
2705
2706 let request = thread.update(cx, |thread, cx| {
2707 thread.send_raw("Extend the count in /tmp/foo", cx)
2708 });
2709 read_file_rx.await.ok();
2710 buffer.update(cx, |buffer, cx| {
2711 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2712 });
2713 cx.run_until_parked();
2714 assert_eq!(
2715 buffer.read_with(cx, |buffer, _| buffer.text()),
2716 "zero\none\ntwo\nthree\nfour\nfive\n"
2717 );
2718 assert_eq!(
2719 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2720 "zero\none\ntwo\nthree\nfour\nfive\n"
2721 );
2722 request.await.unwrap();
2723 }
2724
2725 #[gpui::test]
2726 async fn test_reading_from_line(cx: &mut TestAppContext) {
2727 init_test(cx);
2728
2729 let fs = FakeFs::new(cx.executor());
2730 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2731 .await;
2732 let project = Project::test(fs.clone(), [], cx).await;
2733 project
2734 .update(cx, |project, cx| {
2735 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2736 })
2737 .await
2738 .unwrap();
2739
2740 let connection = Rc::new(FakeAgentConnection::new());
2741
2742 let thread = cx
2743 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2744 .await
2745 .unwrap();
2746
2747 // Whole file
2748 let content = thread
2749 .update(cx, |thread, cx| {
2750 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2751 })
2752 .await
2753 .unwrap();
2754
2755 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2756
2757 // Only start line
2758 let content = thread
2759 .update(cx, |thread, cx| {
2760 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2761 })
2762 .await
2763 .unwrap();
2764
2765 assert_eq!(content, "three\nfour\n");
2766
2767 // Only limit
2768 let content = thread
2769 .update(cx, |thread, cx| {
2770 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2771 })
2772 .await
2773 .unwrap();
2774
2775 assert_eq!(content, "one\ntwo\n");
2776
2777 // Range
2778 let content = thread
2779 .update(cx, |thread, cx| {
2780 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2781 })
2782 .await
2783 .unwrap();
2784
2785 assert_eq!(content, "two\nthree\n");
2786
2787 // Invalid
2788 let err = thread
2789 .update(cx, |thread, cx| {
2790 thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2791 })
2792 .await
2793 .unwrap_err();
2794
2795 assert_eq!(
2796 err.to_string(),
2797 "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2798 );
2799 }
2800
2801 #[gpui::test]
2802 async fn test_reading_empty_file(cx: &mut TestAppContext) {
2803 init_test(cx);
2804
2805 let fs = FakeFs::new(cx.executor());
2806 fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2807 let project = Project::test(fs.clone(), [], cx).await;
2808 project
2809 .update(cx, |project, cx| {
2810 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2811 })
2812 .await
2813 .unwrap();
2814
2815 let connection = Rc::new(FakeAgentConnection::new());
2816
2817 let thread = cx
2818 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2819 .await
2820 .unwrap();
2821
2822 // Whole file
2823 let content = thread
2824 .update(cx, |thread, cx| {
2825 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2826 })
2827 .await
2828 .unwrap();
2829
2830 assert_eq!(content, "");
2831
2832 // Only start line
2833 let content = thread
2834 .update(cx, |thread, cx| {
2835 thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2836 })
2837 .await
2838 .unwrap();
2839
2840 assert_eq!(content, "");
2841
2842 // Only limit
2843 let content = thread
2844 .update(cx, |thread, cx| {
2845 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2846 })
2847 .await
2848 .unwrap();
2849
2850 assert_eq!(content, "");
2851
2852 // Range
2853 let content = thread
2854 .update(cx, |thread, cx| {
2855 thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2856 })
2857 .await
2858 .unwrap();
2859
2860 assert_eq!(content, "");
2861
2862 // Invalid
2863 let err = thread
2864 .update(cx, |thread, cx| {
2865 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2866 })
2867 .await
2868 .unwrap_err();
2869
2870 assert_eq!(
2871 err.to_string(),
2872 "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2873 );
2874 }
2875 #[gpui::test]
2876 async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2877 init_test(cx);
2878
2879 let fs = FakeFs::new(cx.executor());
2880 fs.insert_tree(path!("/tmp"), json!({})).await;
2881 let project = Project::test(fs.clone(), [], cx).await;
2882 project
2883 .update(cx, |project, cx| {
2884 project.find_or_create_worktree(path!("/tmp"), true, cx)
2885 })
2886 .await
2887 .unwrap();
2888
2889 let connection = Rc::new(FakeAgentConnection::new());
2890
2891 let thread = cx
2892 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2893 .await
2894 .unwrap();
2895
2896 // Out of project file
2897 let err = thread
2898 .update(cx, |thread, cx| {
2899 thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2900 })
2901 .await
2902 .unwrap_err();
2903
2904 assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2905 }
2906
2907 #[gpui::test]
2908 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2909 init_test(cx);
2910
2911 let fs = FakeFs::new(cx.executor());
2912 let project = Project::test(fs, [], cx).await;
2913 let id = acp::ToolCallId("test".into());
2914
2915 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2916 let id = id.clone();
2917 move |_, thread, mut cx| {
2918 let id = id.clone();
2919 async move {
2920 thread
2921 .update(&mut cx, |thread, cx| {
2922 thread.handle_session_update(
2923 acp::SessionUpdate::ToolCall(acp::ToolCall {
2924 id: id.clone(),
2925 title: "Label".into(),
2926 kind: acp::ToolKind::Fetch,
2927 status: acp::ToolCallStatus::InProgress,
2928 content: vec![],
2929 locations: vec![],
2930 raw_input: None,
2931 raw_output: None,
2932 meta: None,
2933 }),
2934 cx,
2935 )
2936 })
2937 .unwrap()
2938 .unwrap();
2939 Ok(acp::PromptResponse {
2940 stop_reason: acp::StopReason::EndTurn,
2941 meta: None,
2942 })
2943 }
2944 .boxed_local()
2945 }
2946 }));
2947
2948 let thread = cx
2949 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2950 .await
2951 .unwrap();
2952
2953 let request = thread.update(cx, |thread, cx| {
2954 thread.send_raw("Fetch https://example.com", cx)
2955 });
2956
2957 run_until_first_tool_call(&thread, cx).await;
2958
2959 thread.read_with(cx, |thread, _| {
2960 assert!(matches!(
2961 thread.entries[1],
2962 AgentThreadEntry::ToolCall(ToolCall {
2963 status: ToolCallStatus::InProgress,
2964 ..
2965 })
2966 ));
2967 });
2968
2969 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2970
2971 thread.read_with(cx, |thread, _| {
2972 assert!(matches!(
2973 &thread.entries[1],
2974 AgentThreadEntry::ToolCall(ToolCall {
2975 status: ToolCallStatus::Canceled,
2976 ..
2977 })
2978 ));
2979 });
2980
2981 thread
2982 .update(cx, |thread, cx| {
2983 thread.handle_session_update(
2984 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2985 id,
2986 fields: acp::ToolCallUpdateFields {
2987 status: Some(acp::ToolCallStatus::Completed),
2988 ..Default::default()
2989 },
2990 meta: None,
2991 }),
2992 cx,
2993 )
2994 })
2995 .unwrap();
2996
2997 request.await.unwrap();
2998
2999 thread.read_with(cx, |thread, _| {
3000 assert!(matches!(
3001 thread.entries[1],
3002 AgentThreadEntry::ToolCall(ToolCall {
3003 status: ToolCallStatus::Completed,
3004 ..
3005 })
3006 ));
3007 });
3008 }
3009
3010 #[gpui::test]
3011 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
3012 init_test(cx);
3013 let fs = FakeFs::new(cx.background_executor.clone());
3014 fs.insert_tree(path!("/test"), json!({})).await;
3015 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3016
3017 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3018 move |_, thread, mut cx| {
3019 async move {
3020 thread
3021 .update(&mut cx, |thread, cx| {
3022 thread.handle_session_update(
3023 acp::SessionUpdate::ToolCall(acp::ToolCall {
3024 id: acp::ToolCallId("test".into()),
3025 title: "Label".into(),
3026 kind: acp::ToolKind::Edit,
3027 status: acp::ToolCallStatus::Completed,
3028 content: vec![acp::ToolCallContent::Diff {
3029 diff: acp::Diff {
3030 path: "/test/test.txt".into(),
3031 old_text: None,
3032 new_text: "foo".into(),
3033 meta: None,
3034 },
3035 }],
3036 locations: vec![],
3037 raw_input: None,
3038 raw_output: None,
3039 meta: None,
3040 }),
3041 cx,
3042 )
3043 })
3044 .unwrap()
3045 .unwrap();
3046 Ok(acp::PromptResponse {
3047 stop_reason: acp::StopReason::EndTurn,
3048 meta: None,
3049 })
3050 }
3051 .boxed_local()
3052 }
3053 }));
3054
3055 let thread = cx
3056 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3057 .await
3058 .unwrap();
3059
3060 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
3061 .await
3062 .unwrap();
3063
3064 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
3065 }
3066
3067 #[gpui::test(iterations = 10)]
3068 async fn test_checkpoints(cx: &mut TestAppContext) {
3069 init_test(cx);
3070 let fs = FakeFs::new(cx.background_executor.clone());
3071 fs.insert_tree(
3072 path!("/test"),
3073 json!({
3074 ".git": {}
3075 }),
3076 )
3077 .await;
3078 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3079
3080 let simulate_changes = Arc::new(AtomicBool::new(true));
3081 let next_filename = Arc::new(AtomicUsize::new(0));
3082 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3083 let simulate_changes = simulate_changes.clone();
3084 let next_filename = next_filename.clone();
3085 let fs = fs.clone();
3086 move |request, thread, mut cx| {
3087 let fs = fs.clone();
3088 let simulate_changes = simulate_changes.clone();
3089 let next_filename = next_filename.clone();
3090 async move {
3091 if simulate_changes.load(SeqCst) {
3092 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
3093 fs.write(Path::new(&filename), b"").await?;
3094 }
3095
3096 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3097 panic!("expected text content block");
3098 };
3099 thread.update(&mut cx, |thread, cx| {
3100 thread
3101 .handle_session_update(
3102 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk {
3103 content: content.text.to_uppercase().into(),
3104 meta: None,
3105 }),
3106 cx,
3107 )
3108 .unwrap();
3109 })?;
3110 Ok(acp::PromptResponse {
3111 stop_reason: acp::StopReason::EndTurn,
3112 meta: None,
3113 })
3114 }
3115 .boxed_local()
3116 }
3117 }));
3118 let thread = cx
3119 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3120 .await
3121 .unwrap();
3122
3123 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
3124 .await
3125 .unwrap();
3126 thread.read_with(cx, |thread, cx| {
3127 assert_eq!(
3128 thread.to_markdown(cx),
3129 indoc! {"
3130 ## User (checkpoint)
3131
3132 Lorem
3133
3134 ## Assistant
3135
3136 LOREM
3137
3138 "}
3139 );
3140 });
3141 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3142
3143 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
3144 .await
3145 .unwrap();
3146 thread.read_with(cx, |thread, cx| {
3147 assert_eq!(
3148 thread.to_markdown(cx),
3149 indoc! {"
3150 ## User (checkpoint)
3151
3152 Lorem
3153
3154 ## Assistant
3155
3156 LOREM
3157
3158 ## User (checkpoint)
3159
3160 ipsum
3161
3162 ## Assistant
3163
3164 IPSUM
3165
3166 "}
3167 );
3168 });
3169 assert_eq!(
3170 fs.files(),
3171 vec![
3172 Path::new(path!("/test/file-0")),
3173 Path::new(path!("/test/file-1"))
3174 ]
3175 );
3176
3177 // Checkpoint isn't stored when there are no changes.
3178 simulate_changes.store(false, SeqCst);
3179 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
3180 .await
3181 .unwrap();
3182 thread.read_with(cx, |thread, cx| {
3183 assert_eq!(
3184 thread.to_markdown(cx),
3185 indoc! {"
3186 ## User (checkpoint)
3187
3188 Lorem
3189
3190 ## Assistant
3191
3192 LOREM
3193
3194 ## User (checkpoint)
3195
3196 ipsum
3197
3198 ## Assistant
3199
3200 IPSUM
3201
3202 ## User
3203
3204 dolor
3205
3206 ## Assistant
3207
3208 DOLOR
3209
3210 "}
3211 );
3212 });
3213 assert_eq!(
3214 fs.files(),
3215 vec![
3216 Path::new(path!("/test/file-0")),
3217 Path::new(path!("/test/file-1"))
3218 ]
3219 );
3220
3221 // Rewinding the conversation truncates the history and restores the checkpoint.
3222 thread
3223 .update(cx, |thread, cx| {
3224 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
3225 panic!("unexpected entries {:?}", thread.entries)
3226 };
3227 thread.restore_checkpoint(message.id.clone().unwrap(), cx)
3228 })
3229 .await
3230 .unwrap();
3231 thread.read_with(cx, |thread, cx| {
3232 assert_eq!(
3233 thread.to_markdown(cx),
3234 indoc! {"
3235 ## User (checkpoint)
3236
3237 Lorem
3238
3239 ## Assistant
3240
3241 LOREM
3242
3243 "}
3244 );
3245 });
3246 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
3247 }
3248
3249 #[gpui::test]
3250 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
3251 use std::sync::atomic::AtomicUsize;
3252 init_test(cx);
3253
3254 let fs = FakeFs::new(cx.executor());
3255 let project = Project::test(fs, None, cx).await;
3256
3257 // Create a connection that simulates refusal after tool result
3258 let prompt_count = Arc::new(AtomicUsize::new(0));
3259 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3260 let prompt_count = prompt_count.clone();
3261 move |_request, thread, mut cx| {
3262 let count = prompt_count.fetch_add(1, SeqCst);
3263 async move {
3264 if count == 0 {
3265 // First prompt: Generate a tool call with result
3266 thread.update(&mut cx, |thread, cx| {
3267 thread
3268 .handle_session_update(
3269 acp::SessionUpdate::ToolCall(acp::ToolCall {
3270 id: acp::ToolCallId("tool1".into()),
3271 title: "Test Tool".into(),
3272 kind: acp::ToolKind::Fetch,
3273 status: acp::ToolCallStatus::Completed,
3274 content: vec![],
3275 locations: vec![],
3276 raw_input: Some(serde_json::json!({"query": "test"})),
3277 raw_output: Some(
3278 serde_json::json!({"result": "inappropriate content"}),
3279 ),
3280 meta: None,
3281 }),
3282 cx,
3283 )
3284 .unwrap();
3285 })?;
3286
3287 // Now return refusal because of the tool result
3288 Ok(acp::PromptResponse {
3289 stop_reason: acp::StopReason::Refusal,
3290 meta: None,
3291 })
3292 } else {
3293 Ok(acp::PromptResponse {
3294 stop_reason: acp::StopReason::EndTurn,
3295 meta: None,
3296 })
3297 }
3298 }
3299 .boxed_local()
3300 }
3301 }));
3302
3303 let thread = cx
3304 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3305 .await
3306 .unwrap();
3307
3308 // Track if we see a Refusal event
3309 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3310 let saw_refusal_event_captured = saw_refusal_event.clone();
3311 thread.update(cx, |_thread, cx| {
3312 cx.subscribe(
3313 &thread,
3314 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3315 if matches!(event, AcpThreadEvent::Refusal) {
3316 *saw_refusal_event_captured.lock().unwrap() = true;
3317 }
3318 },
3319 )
3320 .detach();
3321 });
3322
3323 // Send a user message - this will trigger tool call and then refusal
3324 let send_task = thread.update(cx, |thread, cx| {
3325 thread.send(
3326 vec![acp::ContentBlock::Text(acp::TextContent {
3327 text: "Hello".into(),
3328 annotations: None,
3329 meta: None,
3330 })],
3331 cx,
3332 )
3333 });
3334 cx.background_executor.spawn(send_task).detach();
3335 cx.run_until_parked();
3336
3337 // Verify that:
3338 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3339 // 2. The user message was NOT truncated
3340 assert!(
3341 *saw_refusal_event.lock().unwrap(),
3342 "Refusal event should be emitted for tool result refusals"
3343 );
3344
3345 thread.read_with(cx, |thread, _| {
3346 let entries = thread.entries();
3347 assert!(entries.len() >= 2, "Should have user message and tool call");
3348
3349 // Verify user message is still there
3350 assert!(
3351 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3352 "User message should not be truncated"
3353 );
3354
3355 // Verify tool call is there with result
3356 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3357 assert!(
3358 tool_call.raw_output.is_some(),
3359 "Tool call should have output"
3360 );
3361 } else {
3362 panic!("Expected tool call at index 1");
3363 }
3364 });
3365 }
3366
3367 #[gpui::test]
3368 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3369 init_test(cx);
3370
3371 let fs = FakeFs::new(cx.executor());
3372 let project = Project::test(fs, None, cx).await;
3373
3374 let refuse_next = Arc::new(AtomicBool::new(false));
3375 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3376 let refuse_next = refuse_next.clone();
3377 move |_request, _thread, _cx| {
3378 if refuse_next.load(SeqCst) {
3379 async move {
3380 Ok(acp::PromptResponse {
3381 stop_reason: acp::StopReason::Refusal,
3382 meta: None,
3383 })
3384 }
3385 .boxed_local()
3386 } else {
3387 async move {
3388 Ok(acp::PromptResponse {
3389 stop_reason: acp::StopReason::EndTurn,
3390 meta: None,
3391 })
3392 }
3393 .boxed_local()
3394 }
3395 }
3396 }));
3397
3398 let thread = cx
3399 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3400 .await
3401 .unwrap();
3402
3403 // Track if we see a Refusal event
3404 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3405 let saw_refusal_event_captured = saw_refusal_event.clone();
3406 thread.update(cx, |_thread, cx| {
3407 cx.subscribe(
3408 &thread,
3409 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3410 if matches!(event, AcpThreadEvent::Refusal) {
3411 *saw_refusal_event_captured.lock().unwrap() = true;
3412 }
3413 },
3414 )
3415 .detach();
3416 });
3417
3418 // Send a message that will be refused
3419 refuse_next.store(true, SeqCst);
3420 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3421 .await
3422 .unwrap();
3423
3424 // Verify that a Refusal event WAS emitted for user prompt refusal
3425 assert!(
3426 *saw_refusal_event.lock().unwrap(),
3427 "Refusal event should be emitted for user prompt refusals"
3428 );
3429
3430 // Verify the message was truncated (user prompt refusal)
3431 thread.read_with(cx, |thread, cx| {
3432 assert_eq!(thread.to_markdown(cx), "");
3433 });
3434 }
3435
3436 #[gpui::test]
3437 async fn test_refusal(cx: &mut TestAppContext) {
3438 init_test(cx);
3439 let fs = FakeFs::new(cx.background_executor.clone());
3440 fs.insert_tree(path!("/"), json!({})).await;
3441 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3442
3443 let refuse_next = Arc::new(AtomicBool::new(false));
3444 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3445 let refuse_next = refuse_next.clone();
3446 move |request, thread, mut cx| {
3447 let refuse_next = refuse_next.clone();
3448 async move {
3449 if refuse_next.load(SeqCst) {
3450 return Ok(acp::PromptResponse {
3451 stop_reason: acp::StopReason::Refusal,
3452 meta: None,
3453 });
3454 }
3455
3456 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3457 panic!("expected text content block");
3458 };
3459 thread.update(&mut cx, |thread, cx| {
3460 thread
3461 .handle_session_update(
3462 acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk {
3463 content: content.text.to_uppercase().into(),
3464 meta: None,
3465 }),
3466 cx,
3467 )
3468 .unwrap();
3469 })?;
3470 Ok(acp::PromptResponse {
3471 stop_reason: acp::StopReason::EndTurn,
3472 meta: None,
3473 })
3474 }
3475 .boxed_local()
3476 }
3477 }));
3478 let thread = cx
3479 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3480 .await
3481 .unwrap();
3482
3483 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3484 .await
3485 .unwrap();
3486 thread.read_with(cx, |thread, cx| {
3487 assert_eq!(
3488 thread.to_markdown(cx),
3489 indoc! {"
3490 ## User
3491
3492 hello
3493
3494 ## Assistant
3495
3496 HELLO
3497
3498 "}
3499 );
3500 });
3501
3502 // Simulate refusing the second message. The message should be truncated
3503 // when a user prompt is refused.
3504 refuse_next.store(true, SeqCst);
3505 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3506 .await
3507 .unwrap();
3508 thread.read_with(cx, |thread, cx| {
3509 assert_eq!(
3510 thread.to_markdown(cx),
3511 indoc! {"
3512 ## User
3513
3514 hello
3515
3516 ## Assistant
3517
3518 HELLO
3519
3520 "}
3521 );
3522 });
3523 }
3524
3525 async fn run_until_first_tool_call(
3526 thread: &Entity<AcpThread>,
3527 cx: &mut TestAppContext,
3528 ) -> usize {
3529 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3530
3531 let subscription = cx.update(|cx| {
3532 cx.subscribe(thread, move |thread, _, cx| {
3533 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3534 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3535 return tx.try_send(ix).unwrap();
3536 }
3537 }
3538 })
3539 });
3540
3541 select! {
3542 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3543 panic!("Timeout waiting for tool call")
3544 }
3545 ix = rx.next().fuse() => {
3546 drop(subscription);
3547 ix.unwrap()
3548 }
3549 }
3550 }
3551
    /// In-memory stand-in for an agent connection, used by the tests in this
    /// module to drive an `AcpThread` without a real agent.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods reported by `auth_methods()` and accepted by
        /// `authenticate()`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Weak handles to the threads created through `new_thread`, keyed by
        /// their session id. Shared behind `Arc<Mutex<..>>` so clones of the
        /// connection observe the same sessions.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional hook that `prompt` delegates to. It receives the prompt
        /// request, the session's thread, and an async app handle. When unset,
        /// `prompt` answers with `StopReason::EndTurn`.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
3567
3568 impl FakeAgentConnection {
3569 fn new() -> Self {
3570 Self {
3571 auth_methods: Vec::new(),
3572 on_user_message: None,
3573 sessions: Arc::default(),
3574 }
3575 }
3576
3577 #[expect(unused)]
3578 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3579 self.auth_methods = auth_methods;
3580 self
3581 }
3582
3583 fn on_user_message(
3584 mut self,
3585 handler: impl Fn(
3586 acp::PromptRequest,
3587 WeakEntity<AcpThread>,
3588 AsyncApp,
3589 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3590 + 'static,
3591 ) -> Self {
3592 self.on_user_message.replace(Rc::new(handler));
3593 self
3594 }
3595 }
3596
3597 impl AgentConnection for FakeAgentConnection {
3598 fn auth_methods(&self) -> &[acp::AuthMethod] {
3599 &self.auth_methods
3600 }
3601
3602 fn new_thread(
3603 self: Rc<Self>,
3604 project: Entity<Project>,
3605 _cwd: &Path,
3606 cx: &mut App,
3607 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3608 let session_id = acp::SessionId(
3609 rand::rng()
3610 .sample_iter(&distr::Alphanumeric)
3611 .take(7)
3612 .map(char::from)
3613 .collect::<String>()
3614 .into(),
3615 );
3616 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3617 let thread = cx.new(|cx| {
3618 AcpThread::new(
3619 "Test",
3620 self.clone(),
3621 project,
3622 action_log,
3623 session_id.clone(),
3624 watch::Receiver::constant(acp::PromptCapabilities {
3625 image: true,
3626 audio: true,
3627 embedded_context: true,
3628 meta: None,
3629 }),
3630 cx,
3631 )
3632 });
3633 self.sessions.lock().insert(session_id, thread.downgrade());
3634 Task::ready(Ok(thread))
3635 }
3636
3637 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3638 if self.auth_methods().iter().any(|m| m.id == method) {
3639 Task::ready(Ok(()))
3640 } else {
3641 Task::ready(Err(anyhow!("Invalid Auth Method")))
3642 }
3643 }
3644
3645 fn prompt(
3646 &self,
3647 _id: Option<UserMessageId>,
3648 params: acp::PromptRequest,
3649 cx: &mut App,
3650 ) -> Task<gpui::Result<acp::PromptResponse>> {
3651 let sessions = self.sessions.lock();
3652 let thread = sessions.get(¶ms.session_id).unwrap();
3653 if let Some(handler) = &self.on_user_message {
3654 let handler = handler.clone();
3655 let thread = thread.clone();
3656 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3657 } else {
3658 Task::ready(Ok(acp::PromptResponse {
3659 stop_reason: acp::StopReason::EndTurn,
3660 meta: None,
3661 }))
3662 }
3663 }
3664
3665 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3666 let sessions = self.sessions.lock();
3667 let thread = sessions.get(session_id).unwrap().clone();
3668
3669 cx.spawn(async move |cx| {
3670 thread
3671 .update(cx, |thread, cx| thread.cancel(cx))
3672 .unwrap()
3673 .await
3674 })
3675 .detach();
3676 }
3677
3678 fn truncate(
3679 &self,
3680 session_id: &acp::SessionId,
3681 _cx: &App,
3682 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3683 Some(Rc::new(FakeAgentSessionEditor {
3684 _session_id: session_id.clone(),
3685 }))
3686 }
3687
3688 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3689 self
3690 }
3691 }
3692
    /// Truncation handle returned by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        /// Session this editor was created for; stored but not read by the
        /// code in this module (hence the leading underscore).
        _session_id: acp::SessionId,
    }
3696
    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// Ignores `_message_id` and returns an already-completed task
        /// holding `Ok(())`.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3702
3703 #[gpui::test]
3704 async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3705 init_test(cx);
3706
3707 let fs = FakeFs::new(cx.executor());
3708 let project = Project::test(fs, [], cx).await;
3709 let connection = Rc::new(FakeAgentConnection::new());
3710 let thread = cx
3711 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3712 .await
3713 .unwrap();
3714
3715 // Try to update a tool call that doesn't exist
3716 let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3717 thread.update(cx, |thread, cx| {
3718 let result = thread.handle_session_update(
3719 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3720 id: nonexistent_id.clone(),
3721 fields: acp::ToolCallUpdateFields {
3722 status: Some(acp::ToolCallStatus::Completed),
3723 ..Default::default()
3724 },
3725 meta: None,
3726 }),
3727 cx,
3728 );
3729
3730 // The update should succeed (not return an error)
3731 assert!(result.is_ok());
3732
3733 // There should now be exactly one entry in the thread
3734 assert_eq!(thread.entries.len(), 1);
3735
3736 // The entry should be a failed tool call
3737 if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3738 assert_eq!(tool_call.id, nonexistent_id);
3739 assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3740 assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3741
3742 // Check that the content contains the error message
3743 assert_eq!(tool_call.content.len(), 1);
3744 if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3745 match content_block {
3746 ContentBlock::Markdown { markdown } => {
3747 let markdown_text = markdown.read(cx).source();
3748 assert!(markdown_text.contains("Tool call not found"));
3749 }
3750 ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3751 ContentBlock::ResourceLink { .. } => {
3752 panic!("Expected markdown content, got resource link")
3753 }
3754 }
3755 } else {
3756 panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3757 }
3758 } else {
3759 panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3760 }
3761 });
3762 }
3763}