1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use futures::future::Shared;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_system_shell};
38use uuid::Uuid;
39
/// A user-authored entry in the thread, accumulated from one or more ACP chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if known. `AcpThread::push_user_content_block`
    /// replaces it whenever a later chunk supplies an id.
    pub id: Option<UserMessageId>,
    /// Rendered content built by appending every chunk in `chunks`.
    pub content: ContentBlock,
    /// The raw ACP content blocks that make up this message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
47
/// A git checkpoint associated with a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git-store checkpoint.
    /// NOTE(review): presumably captured around the time the message was sent — confirm.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint should be surfaced; when `true`, the message's Markdown
    /// heading reads `## User (checkpoint)`.
    pub show: bool,
}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
113#[derive(Debug)]
114pub enum AgentThreadEntry {
115 UserMessage(UserMessage),
116 AssistantMessage(AssistantMessage),
117 ToolCall(ToolCall),
118}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Agent-assigned identifier used to match subsequent updates.
    pub id: acp::ToolCallId,
    /// Rendered title; multi-line titles are collapsed to their first line plus `…`.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content produced by the tool (text blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer-anchored counterparts of `locations`, indexed in parallel with it;
    /// `None` where a location could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input passed to the tool, if reported.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output of the tool, if reported.
    pub raw_output: Option<serde_json::Value>,
}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
221 fn update_fields(
222 &mut self,
223 fields: acp::ToolCallUpdateFields,
224 language_registry: Arc<LanguageRegistry>,
225 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
226 cx: &mut App,
227 ) -> Result<()> {
228 let acp::ToolCallUpdateFields {
229 kind,
230 status,
231 title,
232 content,
233 locations,
234 raw_input,
235 raw_output,
236 } = fields;
237
238 if let Some(kind) = kind {
239 self.kind = kind;
240 }
241
242 if let Some(status) = status {
243 self.status = status.into();
244 }
245
246 if let Some(title) = title {
247 self.label.update(cx, |label, cx| {
248 if let Some((first_line, _)) = title.split_once("\n") {
249 label.replace(first_line.to_owned() + "…", cx)
250 } else {
251 label.replace(title, cx);
252 }
253 });
254 }
255
256 if let Some(content) = content {
257 let new_content_len = content.len();
258 let mut content = content.into_iter();
259
260 // Reuse existing content if we can
261 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
262 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
263 }
264 for new in content {
265 self.content.push(ToolCallContent::from_acp(
266 new,
267 language_registry.clone(),
268 terminals,
269 cx,
270 )?)
271 }
272 self.content.truncate(new_content_len);
273 }
274
275 if let Some(locations) = locations {
276 self.locations = locations;
277 }
278
279 if let Some(raw_input) = raw_input {
280 self.raw_input = Some(raw_input);
281 }
282
283 if let Some(raw_output) = raw_output {
284 if self.content.is_empty()
285 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
286 {
287 self.content
288 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
289 markdown,
290 }));
291 }
292 self.raw_output = Some(raw_output);
293 }
294 Ok(())
295 }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
///
/// Extends the ACP statuses with client-side states (confirmation, rejection,
/// cancellation).
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission choices offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the chosen option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
/// Displayable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Rendered Markdown; further appends extend its source.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link, kept as-is until more content is appended (at which
    /// point it is converted to Markdown).
    ResourceLink { resource_link: acp::ResourceLink },
}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
457 pub fn append(
458 &mut self,
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) {
463 if matches!(self, ContentBlock::Empty)
464 && let acp::ContentBlock::ResourceLink(resource_link) = block
465 {
466 *self = ContentBlock::ResourceLink { resource_link };
467 return;
468 }
469
470 let new_content = self.block_string_contents(block);
471
472 match self {
473 ContentBlock::Empty => {
474 *self = Self::create_markdown_block(new_content, language_registry, cx);
475 }
476 ContentBlock::Markdown { markdown } => {
477 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
478 }
479 ContentBlock::ResourceLink { resource_link } => {
480 let existing_content = Self::resource_link_md(&resource_link.uri);
481 let combined = format!("{}\n{}", existing_content, new_content);
482
483 *self = Self::create_markdown_block(combined, language_registry, cx);
484 }
485 }
486 }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
499 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
500 match block {
501 acp::ContentBlock::Text(text_content) => text_content.text,
502 acp::ContentBlock::ResourceLink(resource_link) => {
503 Self::resource_link_md(&resource_link.uri)
504 }
505 acp::ContentBlock::Resource(acp::EmbeddedResource {
506 resource:
507 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
508 uri,
509 ..
510 }),
511 ..
512 }) => Self::resource_link_md(&uri),
513 acp::ContentBlock::Image(image) => Self::image_md(&image),
514 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
515 }
516 }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
/// One item of content produced by a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text/Markdown or resource-link content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal registered with the thread.
    Terminal(Entity<Terminal>),
}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path,
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
/// An update addressed to an existing tool call.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Partial field update received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
/// Replaces a tool call's content with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// The tool call to update.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content.
    pub diff: Entity<Diff>,
}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
/// Replaces a tool call's content with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// The tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content.
    pub terminal: Entity<Terminal>,
}
670
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}

/// Summary counts for a [`Plan`], produced by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
713#[derive(Debug)]
714pub struct PlanEntry {
715 pub content: Entity<Markdown>,
716 pub priority: acp::PlanEntryPriority,
717 pub status: acp::PlanEntryStatus,
718}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
/// Token accounting for the current thread.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Context-window size; `0` means the limit is unknown (e.g. no model selected).
    pub max_tokens: u64,
    /// Tokens consumed so far.
    pub used_tokens: u64,
}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
/// Coarse classification of token usage returned by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}

/// Information about a retry, carried by [`AcpThreadEvent::Retry`].
/// NOTE(review): field semantics inferred from names — confirm with the emitter.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number.
    pub attempt: usize,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// When the retry started.
    pub started_at: Instant,
    /// Delay/duration associated with this retry.
    pub duration: Duration,
}
775
/// A conversation with an agent over ACP: entries, plan, and session state.
pub struct AcpThread {
    /// Display title; changed through `set_title`.
    title: SharedString,
    /// Thread entries in chronological order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// NOTE(review): buffer snapshots shared with the agent; usage is outside this chunk.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Set while a prompt is being processed; `status()` reports `Generating` then.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest usage reported through `update_token_usage`.
    token_usage: Option<TokenUsage>,
    /// Current prompt capabilities, kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Background lookup of the shell name: `bash` if available (non-Windows),
    /// otherwise the system shell.
    determine_shell: Shared<Task<String>>,
    /// Terminals by id, used to resolve terminal content in tool calls.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
792
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended.
    NewEntry,
    /// The title changed.
    TitleUpdated,
    /// Token usage was updated.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// NOTE(review): presumably entries in this range were removed — emitted outside this chunk.
    EntriesRemoved(Range<usize>),
    /// NOTE(review): presumably a tool call awaits user authorization — emitted outside this chunk.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The agent's prompt capabilities changed.
    PromptCapabilitiesUpdated,
    Refusal,
    /// The agent advertised a new set of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The session's current mode changed.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
812
/// Whether the thread is currently generating, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    Idle,
    Generating,
}

/// Reasons an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The installed agent version is older than required.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installing the agent failed.
    FailedToInstall(SharedString),
    /// The agent server process exited.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, with a message.
    Other(SharedString),
}
832
833impl Display for LoadError {
834 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
835 match self {
836 LoadError::Unsupported {
837 command: path,
838 current_version,
839 minimum_version,
840 } => {
841 write!(
842 f,
843 "version {current_version} from {path} is not supported (need at least {minimum_version})"
844 )
845 }
846 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
847 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
848 LoadError::Other(msg) => write!(f, "{msg}"),
849 }
850 }
851}
852
853impl Error for LoadError {}
854
855impl AcpThread {
856 pub fn new(
857 title: impl Into<SharedString>,
858 connection: Rc<dyn AgentConnection>,
859 project: Entity<Project>,
860 action_log: Entity<ActionLog>,
861 session_id: acp::SessionId,
862 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
863 cx: &mut Context<Self>,
864 ) -> Self {
865 let prompt_capabilities = *prompt_capabilities_rx.borrow();
866 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
867 loop {
868 let caps = prompt_capabilities_rx.recv().await?;
869 this.update(cx, |this, cx| {
870 this.prompt_capabilities = caps;
871 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
872 })?;
873 }
874 });
875
876 let determine_shell = cx
877 .background_spawn(async move {
878 if cfg!(windows) {
879 return get_system_shell();
880 }
881
882 if which::which("bash").is_ok() {
883 "bash".into()
884 } else {
885 get_system_shell()
886 }
887 })
888 .shared();
889
890 Self {
891 action_log,
892 shared_buffers: Default::default(),
893 entries: Default::default(),
894 plan: Default::default(),
895 title: title.into(),
896 project,
897 send_task: None,
898 connection,
899 session_id,
900 token_usage: None,
901 prompt_capabilities,
902 _observe_prompt_capabilities: task,
903 terminals: HashMap::default(),
904 determine_shell,
905 }
906 }
907
908 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
909 self.prompt_capabilities
910 }
911
912 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
913 &self.connection
914 }
915
916 pub fn action_log(&self) -> &Entity<ActionLog> {
917 &self.action_log
918 }
919
920 pub fn project(&self) -> &Entity<Project> {
921 &self.project
922 }
923
924 pub fn title(&self) -> SharedString {
925 self.title.clone()
926 }
927
928 pub fn entries(&self) -> &[AgentThreadEntry] {
929 &self.entries
930 }
931
932 pub fn session_id(&self) -> &acp::SessionId {
933 &self.session_id
934 }
935
936 pub fn status(&self) -> ThreadStatus {
937 if self.send_task.is_some() {
938 ThreadStatus::Generating
939 } else {
940 ThreadStatus::Idle
941 }
942 }
943
944 pub fn token_usage(&self) -> Option<&TokenUsage> {
945 self.token_usage.as_ref()
946 }
947
948 pub fn has_pending_edit_tool_calls(&self) -> bool {
949 for entry in self.entries.iter().rev() {
950 match entry {
951 AgentThreadEntry::UserMessage(_) => return false,
952 AgentThreadEntry::ToolCall(
953 call @ ToolCall {
954 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
955 ..
956 },
957 ) if call.diffs().next().is_some() => {
958 return true;
959 }
960 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
961 }
962 }
963
964 false
965 }
966
967 pub fn used_tools_since_last_user_message(&self) -> bool {
968 for entry in self.entries.iter().rev() {
969 match entry {
970 AgentThreadEntry::UserMessage(..) => return false,
971 AgentThreadEntry::AssistantMessage(..) => continue,
972 AgentThreadEntry::ToolCall(..) => return true,
973 }
974 }
975
976 false
977 }
978
979 pub fn handle_session_update(
980 &mut self,
981 update: acp::SessionUpdate,
982 cx: &mut Context<Self>,
983 ) -> Result<(), acp::Error> {
984 match update {
985 acp::SessionUpdate::UserMessageChunk { content } => {
986 self.push_user_content_block(None, content, cx);
987 }
988 acp::SessionUpdate::AgentMessageChunk { content } => {
989 self.push_assistant_content_block(content, false, cx);
990 }
991 acp::SessionUpdate::AgentThoughtChunk { content } => {
992 self.push_assistant_content_block(content, true, cx);
993 }
994 acp::SessionUpdate::ToolCall(tool_call) => {
995 self.upsert_tool_call(tool_call, cx)?;
996 }
997 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
998 self.update_tool_call(tool_call_update, cx)?;
999 }
1000 acp::SessionUpdate::Plan(plan) => {
1001 self.update_plan(plan, cx);
1002 }
1003 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
1004 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
1005 }
1006 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
1007 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
1008 }
1009 }
1010 Ok(())
1011 }
1012
1013 pub fn push_user_content_block(
1014 &mut self,
1015 message_id: Option<UserMessageId>,
1016 chunk: acp::ContentBlock,
1017 cx: &mut Context<Self>,
1018 ) {
1019 let language_registry = self.project.read(cx).languages().clone();
1020 let entries_len = self.entries.len();
1021
1022 if let Some(last_entry) = self.entries.last_mut()
1023 && let AgentThreadEntry::UserMessage(UserMessage {
1024 id,
1025 content,
1026 chunks,
1027 ..
1028 }) = last_entry
1029 {
1030 *id = message_id.or(id.take());
1031 content.append(chunk.clone(), &language_registry, cx);
1032 chunks.push(chunk);
1033 let idx = entries_len - 1;
1034 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1035 } else {
1036 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1037 self.push_entry(
1038 AgentThreadEntry::UserMessage(UserMessage {
1039 id: message_id,
1040 content,
1041 chunks: vec![chunk],
1042 checkpoint: None,
1043 }),
1044 cx,
1045 );
1046 }
1047 }
1048
1049 pub fn push_assistant_content_block(
1050 &mut self,
1051 chunk: acp::ContentBlock,
1052 is_thought: bool,
1053 cx: &mut Context<Self>,
1054 ) {
1055 let language_registry = self.project.read(cx).languages().clone();
1056 let entries_len = self.entries.len();
1057 if let Some(last_entry) = self.entries.last_mut()
1058 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1059 {
1060 let idx = entries_len - 1;
1061 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1062 match (chunks.last_mut(), is_thought) {
1063 (Some(AssistantMessageChunk::Message { block }), false)
1064 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1065 block.append(chunk, &language_registry, cx)
1066 }
1067 _ => {
1068 let block = ContentBlock::new(chunk, &language_registry, cx);
1069 if is_thought {
1070 chunks.push(AssistantMessageChunk::Thought { block })
1071 } else {
1072 chunks.push(AssistantMessageChunk::Message { block })
1073 }
1074 }
1075 }
1076 } else {
1077 let block = ContentBlock::new(chunk, &language_registry, cx);
1078 let chunk = if is_thought {
1079 AssistantMessageChunk::Thought { block }
1080 } else {
1081 AssistantMessageChunk::Message { block }
1082 };
1083
1084 self.push_entry(
1085 AgentThreadEntry::AssistantMessage(AssistantMessage {
1086 chunks: vec![chunk],
1087 }),
1088 cx,
1089 );
1090 }
1091 }
1092
1093 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1094 self.entries.push(entry);
1095 cx.emit(AcpThreadEvent::NewEntry);
1096 }
1097
1098 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1099 self.connection.set_title(&self.session_id, cx).is_some()
1100 }
1101
1102 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1103 if title != self.title {
1104 self.title = title.clone();
1105 cx.emit(AcpThreadEvent::TitleUpdated);
1106 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1107 return set_title.run(title, cx);
1108 }
1109 }
1110 Task::ready(Ok(()))
1111 }
1112
1113 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1114 self.token_usage = usage;
1115 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1116 }
1117
1118 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1119 cx.emit(AcpThreadEvent::Retry(status));
1120 }
1121
1122 pub fn update_tool_call(
1123 &mut self,
1124 update: impl Into<ToolCallUpdate>,
1125 cx: &mut Context<Self>,
1126 ) -> Result<()> {
1127 let update = update.into();
1128 let languages = self.project.read(cx).languages().clone();
1129
1130 let ix = self
1131 .index_for_tool_call(update.id())
1132 .context("Tool call not found")?;
1133 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1134 unreachable!()
1135 };
1136
1137 match update {
1138 ToolCallUpdate::UpdateFields(update) => {
1139 let location_updated = update.fields.locations.is_some();
1140 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1141 if location_updated {
1142 self.resolve_locations(update.id, cx);
1143 }
1144 }
1145 ToolCallUpdate::UpdateDiff(update) => {
1146 call.content.clear();
1147 call.content.push(ToolCallContent::Diff(update.diff));
1148 }
1149 ToolCallUpdate::UpdateTerminal(update) => {
1150 call.content.clear();
1151 call.content
1152 .push(ToolCallContent::Terminal(update.terminal));
1153 }
1154 }
1155
1156 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1157
1158 Ok(())
1159 }
1160
1161 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1162 pub fn upsert_tool_call(
1163 &mut self,
1164 tool_call: acp::ToolCall,
1165 cx: &mut Context<Self>,
1166 ) -> Result<(), acp::Error> {
1167 let status = tool_call.status.into();
1168 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1169 }
1170
1171 /// Fails if id does not match an existing entry.
1172 pub fn upsert_tool_call_inner(
1173 &mut self,
1174 update: acp::ToolCallUpdate,
1175 status: ToolCallStatus,
1176 cx: &mut Context<Self>,
1177 ) -> Result<(), acp::Error> {
1178 let language_registry = self.project.read(cx).languages().clone();
1179 let id = update.id.clone();
1180
1181 if let Some(ix) = self.index_for_tool_call(&id) {
1182 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1183 unreachable!()
1184 };
1185
1186 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1187 call.status = status;
1188
1189 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1190 } else {
1191 let call = ToolCall::from_acp(
1192 update.try_into()?,
1193 status,
1194 language_registry,
1195 &self.terminals,
1196 cx,
1197 )?;
1198 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1199 };
1200
1201 self.resolve_locations(id, cx);
1202 Ok(())
1203 }
1204
1205 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1206 self.entries
1207 .iter()
1208 .enumerate()
1209 .rev()
1210 .find_map(|(index, entry)| {
1211 if let AgentThreadEntry::ToolCall(tool_call) = entry
1212 && &tool_call.id == id
1213 {
1214 Some(index)
1215 } else {
1216 None
1217 }
1218 })
1219 }
1220
1221 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1222 // The tool call we are looking for is typically the last one, or very close to the end.
1223 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1224 self.entries
1225 .iter_mut()
1226 .enumerate()
1227 .rev()
1228 .find_map(|(index, tool_call)| {
1229 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1230 && &tool_call.id == id
1231 {
1232 Some((index, tool_call))
1233 } else {
1234 None
1235 }
1236 })
1237 }
1238
1239 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1240 self.entries
1241 .iter()
1242 .enumerate()
1243 .rev()
1244 .find_map(|(index, tool_call)| {
1245 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1246 && &tool_call.id == id
1247 {
1248 Some((index, tool_call))
1249 } else {
1250 None
1251 }
1252 })
1253 }
1254
1255 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1256 let project = self.project.clone();
1257 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1258 return;
1259 };
1260 let task = tool_call.resolve_locations(project, cx);
1261 cx.spawn(async move |this, cx| {
1262 let resolved_locations = task.await;
1263 this.update(cx, |this, cx| {
1264 let project = this.project.clone();
1265 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1266 return;
1267 };
1268 if let Some(Some(location)) = resolved_locations.last() {
1269 project.update(cx, |project, cx| {
1270 if let Some(agent_location) = project.agent_location() {
1271 let should_ignore = agent_location.buffer == location.buffer
1272 && location
1273 .buffer
1274 .update(cx, |buffer, _| {
1275 let snapshot = buffer.snapshot();
1276 let old_position =
1277 agent_location.position.to_point(&snapshot);
1278 let new_position = location.position.to_point(&snapshot);
1279 // ignore this so that when we get updates from the edit tool
1280 // the position doesn't reset to the startof line
1281 old_position.row == new_position.row
1282 && old_position.column > new_position.column
1283 })
1284 .ok()
1285 .unwrap_or_default();
1286 if !should_ignore {
1287 project.set_agent_location(Some(location.clone()), cx);
1288 }
1289 }
1290 });
1291 }
1292 if tool_call.resolved_locations != resolved_locations {
1293 tool_call.resolved_locations = resolved_locations;
1294 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1295 }
1296 })
1297 })
1298 .detach();
1299 }
1300
1301 pub fn request_tool_call_authorization(
1302 &mut self,
1303 tool_call: acp::ToolCallUpdate,
1304 options: Vec<acp::PermissionOption>,
1305 respect_always_allow_setting: bool,
1306 cx: &mut Context<Self>,
1307 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1308 let (tx, rx) = oneshot::channel();
1309
1310 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1311 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1312 // some tools would (incorrectly) continue to auto-accept.
1313 if let Some(allow_once_option) = options.iter().find_map(|option| {
1314 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1315 Some(option.id.clone())
1316 } else {
1317 None
1318 }
1319 }) {
1320 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1321 return Ok(async {
1322 acp::RequestPermissionOutcome::Selected {
1323 option_id: allow_once_option,
1324 }
1325 }
1326 .boxed());
1327 }
1328 }
1329
1330 let status = ToolCallStatus::WaitingForConfirmation {
1331 options,
1332 respond_tx: tx,
1333 };
1334
1335 self.upsert_tool_call_inner(tool_call, status, cx)?;
1336 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1337
1338 let fut = async {
1339 match rx.await {
1340 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1341 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1342 }
1343 }
1344 .boxed();
1345
1346 Ok(fut)
1347 }
1348
1349 pub fn authorize_tool_call(
1350 &mut self,
1351 id: acp::ToolCallId,
1352 option_id: acp::PermissionOptionId,
1353 option_kind: acp::PermissionOptionKind,
1354 cx: &mut Context<Self>,
1355 ) {
1356 let Some((ix, call)) = self.tool_call_mut(&id) else {
1357 return;
1358 };
1359
1360 let new_status = match option_kind {
1361 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1362 ToolCallStatus::Rejected
1363 }
1364 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1365 ToolCallStatus::InProgress
1366 }
1367 };
1368
1369 let curr_status = mem::replace(&mut call.status, new_status);
1370
1371 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1372 respond_tx.send(option_id).log_err();
1373 } else if cfg!(debug_assertions) {
1374 panic!("tried to authorize an already authorized tool call");
1375 }
1376
1377 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1378 }
1379
1380 pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1381 let mut first_tool_call = None;
1382
1383 for entry in self.entries.iter().rev() {
1384 match &entry {
1385 AgentThreadEntry::ToolCall(call) => {
1386 if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1387 first_tool_call = Some(call);
1388 } else {
1389 continue;
1390 }
1391 }
1392 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1393 // Reached the beginning of the turn.
1394 // If we had pending permission requests in the previous turn, they have been cancelled.
1395 break;
1396 }
1397 }
1398 }
1399
1400 first_tool_call
1401 }
1402
/// Returns the thread's current plan, as last set by `update_plan`.
pub fn plan(&self) -> &Plan {
    &self.plan
}
1406
1407 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1408 let new_entries_len = request.entries.len();
1409 let mut new_entries = request.entries.into_iter();
1410
1411 // Reuse existing markdown to prevent flickering
1412 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1413 let PlanEntry {
1414 content,
1415 priority,
1416 status,
1417 } = old;
1418 content.update(cx, |old, cx| {
1419 old.replace(new.content, cx);
1420 });
1421 *priority = new.priority;
1422 *status = new.status;
1423 }
1424 for new in new_entries {
1425 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1426 }
1427 self.plan.entries.truncate(new_entries_len);
1428
1429 cx.notify();
1430 }
1431
1432 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1433 self.plan
1434 .entries
1435 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1436 cx.notify();
1437 }
1438
/// Test helper: sends `message` as a single plain-text content block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_string(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1453
1454 pub fn send(
1455 &mut self,
1456 message: Vec<acp::ContentBlock>,
1457 cx: &mut Context<Self>,
1458 ) -> BoxFuture<'static, Result<()>> {
1459 let block = ContentBlock::new_combined(
1460 message.clone(),
1461 self.project.read(cx).languages().clone(),
1462 cx,
1463 );
1464 let request = acp::PromptRequest {
1465 prompt: message.clone(),
1466 session_id: self.session_id.clone(),
1467 };
1468 let git_store = self.project.read(cx).git_store().clone();
1469
1470 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1471 Some(UserMessageId::new())
1472 } else {
1473 None
1474 };
1475
1476 self.run_turn(cx, async move |this, cx| {
1477 this.update(cx, |this, cx| {
1478 this.push_entry(
1479 AgentThreadEntry::UserMessage(UserMessage {
1480 id: message_id.clone(),
1481 content: block,
1482 chunks: message,
1483 checkpoint: None,
1484 }),
1485 cx,
1486 );
1487 })
1488 .ok();
1489
1490 let old_checkpoint = git_store
1491 .update(cx, |git, cx| git.checkpoint(cx))?
1492 .await
1493 .context("failed to get old checkpoint")
1494 .log_err();
1495 this.update(cx, |this, cx| {
1496 if let Some((_ix, message)) = this.last_user_message() {
1497 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1498 git_checkpoint,
1499 show: false,
1500 });
1501 }
1502 this.connection.prompt(message_id, request, cx)
1503 })?
1504 .await
1505 })
1506 }
1507
/// Returns whether the connection supports resuming this session. When it
/// does not, `resume` fails with "resuming a session is not supported".
pub fn can_resume(&self, cx: &App) -> bool {
    self.connection.resume(&self.session_id, cx).is_some()
}
1511
1512 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1513 self.run_turn(cx, async move |this, cx| {
1514 this.update(cx, |this, cx| {
1515 this.connection
1516 .resume(&this.session_id, cx)
1517 .map(|resume| resume.run(cx))
1518 })?
1519 .context("resuming a session is not supported")?
1520 .await
1521 })
1522 }
1523
/// Starts a new turn: cancels any in-flight turn, then runs `f` as the new
/// `send_task`.
///
/// The returned future resolves after `f` finishes (or its task is dropped)
/// and the post-turn bookkeeping is done:
/// - the last user message's checkpoint is refreshed;
/// - the agent location is cleared;
/// - `Error` is emitted on failure, otherwise `Stopped` is emitted, preceded
///   by `Refusal` when the agent refused.
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
    f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
    // Completed plan entries from the previous turn are dropped before starting.
    self.clear_completed_plan_entries(cx);

    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);

    self.send_task = Some(cx.spawn(async move |this, cx| {
        // Let the previous turn finish winding down before this one starts.
        cancel_task.await;
        tx.send(f(this, cx).await).ok();
    }));

    cx.spawn(async move |this, cx| {
        // `Err(Canceled)` here means `send_task` was dropped before `f` completed.
        let response = rx.await;

        this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
            .await?;

        this.update(cx, |this, cx| {
            this.project
                .update(cx, |project, cx| project.set_agent_location(None, cx));
            match response {
                Ok(Err(e)) => {
                    // NOTE(review): this takes `send_task` unconditionally. If a newer
                    // turn has already replaced it, that newer task would be dropped
                    // (and thus canceled). TODO confirm this cannot happen.
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                result => {
                    let canceled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Cancelled
                        }))
                    );

                    // We only take the task if the current prompt wasn't canceled.
                    //
                    // This prompt may have been canceled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be canceled.
                    if !canceled {
                        this.send_task.take();
                    }

                    // Handle refusal - distinguish between user prompt and tool call refusals
                    if let Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Refusal,
                    })) = result
                    {
                        if let Some((user_msg_ix, _)) = this.last_user_message() {
                            // Check if there's a completed tool call with results after the last user message
                            // This indicates the refusal is in response to tool output, not the user's prompt
                            let has_completed_tool_call_after_user_msg =
                                this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                    if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                        // Check if the tool call has completed and has output
                                        matches!(tool_call.status, ToolCallStatus::Completed)
                                            && tool_call.raw_output.is_some()
                                    } else {
                                        false
                                    }
                                });

                            if has_completed_tool_call_after_user_msg {
                                // Refusal is due to tool output - don't truncate, just notify
                                // The model refused based on what the tool returned
                                cx.emit(AcpThreadEvent::Refusal);
                            } else {
                                // User prompt was refused - truncate back to before the user message
                                let range = user_msg_ix..this.entries.len();
                                if range.start < range.end {
                                    this.entries.truncate(user_msg_ix);
                                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                }
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        } else {
                            // No user message found, treat as general refusal
                            cx.emit(AcpThreadEvent::Refusal);
                        }
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1617
1618 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1619 let Some(send_task) = self.send_task.take() else {
1620 return Task::ready(());
1621 };
1622
1623 for entry in self.entries.iter_mut() {
1624 if let AgentThreadEntry::ToolCall(call) = entry {
1625 let cancel = matches!(
1626 call.status,
1627 ToolCallStatus::Pending
1628 | ToolCallStatus::WaitingForConfirmation { .. }
1629 | ToolCallStatus::InProgress
1630 );
1631
1632 if cancel {
1633 call.status = ToolCallStatus::Canceled;
1634 }
1635 }
1636 }
1637
1638 self.connection.cancel(&self.session_id, cx);
1639
1640 // Wait for the send task to complete
1641 cx.foreground_executor().spawn(send_task)
1642 }
1643
1644 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1645 pub fn restore_checkpoint(
1646 &mut self,
1647 id: UserMessageId,
1648 cx: &mut Context<Self>,
1649 ) -> Task<Result<()>> {
1650 let Some((_, message)) = self.user_message_mut(&id) else {
1651 return Task::ready(Err(anyhow!("message not found")));
1652 };
1653
1654 let checkpoint = message
1655 .checkpoint
1656 .as_ref()
1657 .map(|c| c.git_checkpoint.clone());
1658 let rewind = self.rewind(id.clone(), cx);
1659 let git_store = self.project.read(cx).git_store().clone();
1660
1661 cx.spawn(async move |_, cx| {
1662 rewind.await?;
1663 if let Some(checkpoint) = checkpoint {
1664 git_store
1665 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1666 .await?;
1667 }
1668
1669 Ok(())
1670 })
1671 }
1672
1673 /// Rewinds this thread to before the entry at `index`, removing it and all
1674 /// subsequent entries while rejecting any action_log changes made from that point.
1675 /// Unlike `restore_checkpoint`, this method does not restore from git.
1676 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1677 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1678 return Task::ready(Err(anyhow!("not supported")));
1679 };
1680
1681 cx.spawn(async move |this, cx| {
1682 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1683 this.update(cx, |this, cx| {
1684 if let Some((ix, _)) = this.user_message_mut(&id) {
1685 let range = ix..this.entries.len();
1686 this.entries.truncate(ix);
1687 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1688 }
1689 this.action_log()
1690 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1691 })?
1692 .await;
1693 Ok(())
1694 })
1695 }
1696
1697 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1698 let git_store = self.project.read(cx).git_store().clone();
1699
1700 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1701 if let Some(checkpoint) = message.checkpoint.as_ref() {
1702 checkpoint.git_checkpoint.clone()
1703 } else {
1704 return Task::ready(Ok(()));
1705 }
1706 } else {
1707 return Task::ready(Ok(()));
1708 };
1709
1710 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1711 cx.spawn(async move |this, cx| {
1712 let new_checkpoint = new_checkpoint
1713 .await
1714 .context("failed to get new checkpoint")
1715 .log_err();
1716 if let Some(new_checkpoint) = new_checkpoint {
1717 let equal = git_store
1718 .update(cx, |git, cx| {
1719 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1720 })?
1721 .await
1722 .unwrap_or(true);
1723 this.update(cx, |this, cx| {
1724 let (ix, message) = this.last_user_message().context("no user message")?;
1725 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1726 checkpoint.show = !equal;
1727 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1728 anyhow::Ok(())
1729 })??;
1730 }
1731
1732 Ok(())
1733 })
1734 }
1735
1736 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1737 self.entries
1738 .iter_mut()
1739 .enumerate()
1740 .rev()
1741 .find_map(|(ix, entry)| {
1742 if let AgentThreadEntry::UserMessage(message) = entry {
1743 Some((ix, message))
1744 } else {
1745 None
1746 }
1747 })
1748 }
1749
1750 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1751 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1752 if let AgentThreadEntry::UserMessage(message) = entry {
1753 if message.id.as_ref() == Some(id) {
1754 Some((ix, message))
1755 } else {
1756 None
1757 }
1758 } else {
1759 None
1760 }
1761 })
1762 }
1763
1764 pub fn read_text_file(
1765 &self,
1766 path: PathBuf,
1767 line: Option<u32>,
1768 limit: Option<u32>,
1769 reuse_shared_snapshot: bool,
1770 cx: &mut Context<Self>,
1771 ) -> Task<Result<String>> {
1772 // Args are 1-based, move to 0-based
1773 let line = line.unwrap_or_default().saturating_sub(1);
1774 let limit = limit.unwrap_or(u32::MAX);
1775 let project = self.project.clone();
1776 let action_log = self.action_log.clone();
1777 cx.spawn(async move |this, cx| {
1778 let load = project.update(cx, |project, cx| {
1779 let path = project
1780 .project_path_for_absolute_path(&path, cx)
1781 .context("invalid path")?;
1782 anyhow::Ok(project.open_buffer(path, cx))
1783 });
1784 let buffer = load??.await?;
1785
1786 let snapshot = if reuse_shared_snapshot {
1787 this.read_with(cx, |this, _| {
1788 this.shared_buffers.get(&buffer.clone()).cloned()
1789 })
1790 .log_err()
1791 .flatten()
1792 } else {
1793 None
1794 };
1795
1796 let snapshot = if let Some(snapshot) = snapshot {
1797 snapshot
1798 } else {
1799 action_log.update(cx, |action_log, cx| {
1800 action_log.buffer_read(buffer.clone(), cx);
1801 })?;
1802
1803 let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1804 this.update(cx, |this, _| {
1805 this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1806 })?;
1807 snapshot
1808 };
1809
1810 let max_point = snapshot.max_point();
1811 if line >= max_point.row {
1812 anyhow::bail!(
1813 "Attempting to read beyond the end of the file, line {}:{}",
1814 max_point.row + 1,
1815 max_point.column
1816 );
1817 }
1818
1819 let start = snapshot.anchor_before(Point::new(line, 0));
1820 let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1821
1822 project.update(cx, |project, cx| {
1823 project.set_agent_location(
1824 Some(AgentLocation {
1825 buffer: buffer.downgrade(),
1826 position: start,
1827 }),
1828 cx,
1829 );
1830 })?;
1831
1832 Ok(snapshot.text_for_range(start..end).collect::<String>())
1833 })
1834 }
1835
/// Replaces the contents of the file at `path` with `content` and saves it.
///
/// The change is applied as a diff against the snapshot previously shared
/// with the agent (or the current buffer when none was shared). The edits are
/// anchored, so concurrent edits made since that snapshot are preserved. The
/// write is recorded in the action log, the buffer is formatted when its
/// `format_on_save` setting is not `Off`, and the buffer is then saved.
///
/// # Errors
///
/// Fails if `path` is not in the project, the buffer cannot be opened, or
/// saving fails. Formatting errors are only logged.
pub fn write_text_file(
    &self,
    path: PathBuf,
    content: String,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let project = self.project.clone();
    let action_log = self.action_log.clone();
    cx.spawn(async move |this, cx| {
        let load = project.update(cx, |project, cx| {
            let path = project
                .project_path_for_absolute_path(&path, cx)
                .context("invalid path")?;
            anyhow::Ok(project.open_buffer(path, cx))
        });
        let buffer = load??.await?;
        // Diff against the snapshot the agent last saw, if one was shared.
        let snapshot = this.update(cx, |this, cx| {
            this.shared_buffers
                .get(&buffer)
                .cloned()
                .unwrap_or_else(|| buffer.read(cx).snapshot())
        })?;
        // Compute the diff off the main thread; each edit range is turned into
        // anchors so it still applies if the buffer changed in the meantime.
        let edits = cx
            .background_executor()
            .spawn(async move {
                let old_text = snapshot.text();
                text_diff(old_text.as_str(), &content)
                    .into_iter()
                    .map(|(range, replacement)| {
                        (
                            snapshot.anchor_after(range.start)
                                ..snapshot.anchor_before(range.end),
                            replacement,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .await;

        // Point the agent at the end of the last edit, or at the buffer start
        // when nothing changed.
        project.update(cx, |project, cx| {
            project.set_agent_location(
                Some(AgentLocation {
                    buffer: buffer.downgrade(),
                    position: edits
                        .last()
                        .map(|(range, _)| range.end)
                        .unwrap_or(Anchor::MIN),
                }),
                cx,
            );
        })?;

        let format_on_save = cx.update(|cx| {
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            });

            let format_on_save = buffer.update(cx, |buffer, cx| {
                buffer.edit(edits, None, cx);

                let settings = language::language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );

                settings.format_on_save != FormatOnSave::Off
            });
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            });
            format_on_save
        })?;

        if format_on_save {
            let format_task = project.update(cx, |project, cx| {
                project.format(
                    HashSet::from_iter([buffer.clone()]),
                    LspFormatTarget::Buffers,
                    false,
                    FormatTrigger::Save,
                    cx,
                )
            })?;
            // Best-effort: a formatting failure should not prevent the save.
            format_task.await.log_err();

            // Record the formatter's changes as agent edits too.
            action_log.update(cx, |action_log, cx| {
                action_log.buffer_edited(buffer.clone(), cx);
            })?;
        }

        project
            .update(cx, |project, cx| project.save_buffer(buffer, cx))?
            .await
    })
}
1932
1933 pub fn create_terminal(
1934 &self,
1935 mut command: String,
1936 args: Vec<String>,
1937 extra_env: Vec<acp::EnvVariable>,
1938 cwd: Option<PathBuf>,
1939 output_byte_limit: Option<u64>,
1940 cx: &mut Context<Self>,
1941 ) -> Task<Result<Entity<Terminal>>> {
1942 for arg in args {
1943 command.push(' ');
1944 command.push_str(&arg);
1945 }
1946
1947 let shell_command = if cfg!(windows) {
1948 format!("$null | & {{{}}}", command.replace("\"", "'"))
1949 } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1950 // Make sure once we're *inside* the shell, we cd into `cwd`
1951 format!("(cd {cwd}; {}) </dev/null", command)
1952 } else {
1953 format!("({}) </dev/null", command)
1954 };
1955 let args = vec!["-c".into(), shell_command];
1956
1957 let env = match &cwd {
1958 Some(dir) => self.project.update(cx, |project, cx| {
1959 project.directory_environment(dir.as_path().into(), cx)
1960 }),
1961 None => Task::ready(None).shared(),
1962 };
1963
1964 let env = cx.spawn(async move |_, _| {
1965 let mut env = env.await.unwrap_or_default();
1966 if cfg!(unix) {
1967 env.insert("PAGER".into(), "cat".into());
1968 }
1969 for var in extra_env {
1970 env.insert(var.name, var.value);
1971 }
1972 env
1973 });
1974
1975 let project = self.project.clone();
1976 let language_registry = project.read(cx).languages().clone();
1977 let determine_shell = self.determine_shell.clone();
1978
1979 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1980 let terminal_task = cx.spawn({
1981 let terminal_id = terminal_id.clone();
1982 async move |_this, cx| {
1983 let program = determine_shell.await;
1984 let env = env.await;
1985 let terminal = project
1986 .update(cx, |project, cx| {
1987 project.create_terminal_task(
1988 task::SpawnInTerminal {
1989 command: Some(program),
1990 args,
1991 cwd: cwd.clone(),
1992 env,
1993 ..Default::default()
1994 },
1995 cx,
1996 )
1997 })?
1998 .await?;
1999
2000 cx.new(|cx| {
2001 Terminal::new(
2002 terminal_id,
2003 command,
2004 cwd,
2005 output_byte_limit.map(|l| l as usize),
2006 terminal,
2007 language_registry,
2008 cx,
2009 )
2010 })
2011 }
2012 });
2013
2014 cx.spawn(async move |this, cx| {
2015 let terminal = terminal_task.await?;
2016 this.update(cx, |this, _cx| {
2017 this.terminals.insert(terminal_id, terminal.clone());
2018 terminal
2019 })
2020 })
2021 }
2022
2023 pub fn kill_terminal(
2024 &mut self,
2025 terminal_id: acp::TerminalId,
2026 cx: &mut Context<Self>,
2027 ) -> Result<()> {
2028 self.terminals
2029 .get(&terminal_id)
2030 .context("Terminal not found")?
2031 .update(cx, |terminal, cx| {
2032 terminal.kill(cx);
2033 });
2034
2035 Ok(())
2036 }
2037
2038 pub fn release_terminal(
2039 &mut self,
2040 terminal_id: acp::TerminalId,
2041 cx: &mut Context<Self>,
2042 ) -> Result<()> {
2043 self.terminals
2044 .remove(&terminal_id)
2045 .context("Terminal not found")?
2046 .update(cx, |terminal, cx| {
2047 terminal.kill(cx);
2048 });
2049
2050 Ok(())
2051 }
2052
2053 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2054 self.terminals
2055 .get(&terminal_id)
2056 .context("Terminal not found")
2057 .cloned()
2058 }
2059
2060 pub fn to_markdown(&self, cx: &App) -> String {
2061 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2062 }
2063
/// Broadcasts `error` to subscribers of this thread as an
/// `AcpThreadEvent::LoadError`.
pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
    cx.emit(AcpThreadEvent::LoadError(error));
}
2067}
2068
2069fn markdown_for_raw_output(
2070 raw_output: &serde_json::Value,
2071 language_registry: &Arc<LanguageRegistry>,
2072 cx: &mut App,
2073) -> Option<Entity<Markdown>> {
2074 match raw_output {
2075 serde_json::Value::Null => None,
2076 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2077 Markdown::new(
2078 value.to_string().into(),
2079 Some(language_registry.clone()),
2080 None,
2081 cx,
2082 )
2083 })),
2084 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2085 Markdown::new(
2086 value.to_string().into(),
2087 Some(language_registry.clone()),
2088 None,
2089 cx,
2090 )
2091 })),
2092 serde_json::Value::String(value) => Some(cx.new(|cx| {
2093 Markdown::new(
2094 value.clone().into(),
2095 Some(language_registry.clone()),
2096 None,
2097 cx,
2098 )
2099 })),
2100 value => Some(cx.new(|cx| {
2101 Markdown::new(
2102 format!("```json\n{}\n```", value).into(),
2103 Some(language_registry.clone()),
2104 None,
2105 cx,
2106 )
2107 })),
2108 }
2109}
2110
2111#[cfg(test)]
2112mod tests {
2113 use super::*;
2114 use anyhow::anyhow;
2115 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2116 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2117 use indoc::indoc;
2118 use project::{FakeFs, Fs};
2119 use rand::{distr, prelude::*};
2120 use serde_json::json;
2121 use settings::SettingsStore;
2122 use smol::stream::StreamExt as _;
2123 use std::{
2124 any::Any,
2125 cell::RefCell,
2126 path::Path,
2127 rc::Rc,
2128 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2129 time::Duration,
2130 };
2131 use util::path;
2132
2133 fn init_test(cx: &mut TestAppContext) {
2134 env_logger::try_init().ok();
2135 cx.update(|cx| {
2136 let settings_store = SettingsStore::test(cx);
2137 cx.set_global(settings_store);
2138 Project::init_settings(cx);
2139 language::init(cx);
2140 });
2141 }
2142
2143 #[gpui::test]
2144 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2145 init_test(cx);
2146
2147 let fs = FakeFs::new(cx.executor());
2148 let project = Project::test(fs, [], cx).await;
2149 let connection = Rc::new(FakeAgentConnection::new());
2150 let thread = cx
2151 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2152 .await
2153 .unwrap();
2154
2155 // Test creating a new user message
2156 thread.update(cx, |thread, cx| {
2157 thread.push_user_content_block(
2158 None,
2159 acp::ContentBlock::Text(acp::TextContent {
2160 annotations: None,
2161 text: "Hello, ".to_string(),
2162 }),
2163 cx,
2164 );
2165 });
2166
2167 thread.update(cx, |thread, cx| {
2168 assert_eq!(thread.entries.len(), 1);
2169 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2170 assert_eq!(user_msg.id, None);
2171 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2172 } else {
2173 panic!("Expected UserMessage");
2174 }
2175 });
2176
2177 // Test appending to existing user message
2178 let message_1_id = UserMessageId::new();
2179 thread.update(cx, |thread, cx| {
2180 thread.push_user_content_block(
2181 Some(message_1_id.clone()),
2182 acp::ContentBlock::Text(acp::TextContent {
2183 annotations: None,
2184 text: "world!".to_string(),
2185 }),
2186 cx,
2187 );
2188 });
2189
2190 thread.update(cx, |thread, cx| {
2191 assert_eq!(thread.entries.len(), 1);
2192 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2193 assert_eq!(user_msg.id, Some(message_1_id));
2194 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2195 } else {
2196 panic!("Expected UserMessage");
2197 }
2198 });
2199
2200 // Test creating new user message after assistant message
2201 thread.update(cx, |thread, cx| {
2202 thread.push_assistant_content_block(
2203 acp::ContentBlock::Text(acp::TextContent {
2204 annotations: None,
2205 text: "Assistant response".to_string(),
2206 }),
2207 false,
2208 cx,
2209 );
2210 });
2211
2212 let message_2_id = UserMessageId::new();
2213 thread.update(cx, |thread, cx| {
2214 thread.push_user_content_block(
2215 Some(message_2_id.clone()),
2216 acp::ContentBlock::Text(acp::TextContent {
2217 annotations: None,
2218 text: "New user message".to_string(),
2219 }),
2220 cx,
2221 );
2222 });
2223
2224 thread.update(cx, |thread, cx| {
2225 assert_eq!(thread.entries.len(), 3);
2226 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2227 assert_eq!(user_msg.id, Some(message_2_id));
2228 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2229 } else {
2230 panic!("Expected UserMessage at index 2");
2231 }
2232 });
2233 }
2234
2235 #[gpui::test]
2236 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2237 init_test(cx);
2238
2239 let fs = FakeFs::new(cx.executor());
2240 let project = Project::test(fs, [], cx).await;
2241 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2242 |_, thread, mut cx| {
2243 async move {
2244 thread.update(&mut cx, |thread, cx| {
2245 thread
2246 .handle_session_update(
2247 acp::SessionUpdate::AgentThoughtChunk {
2248 content: "Thinking ".into(),
2249 },
2250 cx,
2251 )
2252 .unwrap();
2253 thread
2254 .handle_session_update(
2255 acp::SessionUpdate::AgentThoughtChunk {
2256 content: "hard!".into(),
2257 },
2258 cx,
2259 )
2260 .unwrap();
2261 })?;
2262 Ok(acp::PromptResponse {
2263 stop_reason: acp::StopReason::EndTurn,
2264 })
2265 }
2266 .boxed_local()
2267 },
2268 ));
2269
2270 let thread = cx
2271 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2272 .await
2273 .unwrap();
2274
2275 thread
2276 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2277 .await
2278 .unwrap();
2279
2280 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2281 assert_eq!(
2282 output,
2283 indoc! {r#"
2284 ## User
2285
2286 Hello from Zed!
2287
2288 ## Assistant
2289
2290 <thinking>
2291 Thinking hard!
2292 </thinking>
2293
2294 "#}
2295 );
2296 }
2297
/// Checks that an agent write landing while the user edits the same buffer
/// keeps the user's edit.
///
/// The fake agent reads `/tmp/foo`, signals the test through a oneshot
/// channel, then writes an extended version of the file. Once signaled, the
/// test inserts `zero\n` at the start of the buffer. Afterwards both the
/// buffer and the file on disk must contain the user's line followed by the
/// agent's extended contents.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    // Fired by the agent once it has read the original file contents. The
    // handler is an `Fn` closure, so the sender lives in a shared `Option`
    // and is taken on first use.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        move |_, thread, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                // Read the file as the agent and confirm it is still untouched.
                let content = thread
                    .update(&mut cx, |thread, cx| {
                        thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                    })
                    .unwrap()
                    .await
                    .unwrap();
                assert_eq!(content, "one\ntwo\nthree\n");
                // Let the test body make its concurrent edit.
                read_file_tx.take().unwrap().send(()).unwrap();
                // Write the agent's extended contents (which do not include `zero`).
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.write_text_file(
                            path!("/tmp/foo").into(),
                            "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            cx,
                        )
                    })
                    .unwrap()
                    .await
                    .unwrap();
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        },
    ));

    // Open the file in a buffer up front so the user edit below goes through it.
    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    let thread = cx
        .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    // Wait for the agent's read, then simulate the user typing at the top.
    read_file_rx.await.ok();
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    cx.run_until_parked();
    // The user's line survives both in the buffer and on disk.
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
2376
2377 #[gpui::test]
2378 async fn test_reading_from_line(cx: &mut TestAppContext) {
2379 init_test(cx);
2380
2381 let fs = FakeFs::new(cx.executor());
2382 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2383 .await;
2384 let project = Project::test(fs.clone(), [], cx).await;
2385 project
2386 .update(cx, |project, cx| {
2387 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2388 })
2389 .await
2390 .unwrap();
2391
2392 let connection = Rc::new(FakeAgentConnection::new());
2393
2394 let thread = cx
2395 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2396 .await
2397 .unwrap();
2398
2399 // Whole file
2400 let content = thread
2401 .update(cx, |thread, cx| {
2402 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2403 })
2404 .await
2405 .unwrap();
2406
2407 assert_eq!(content, "one\ntwo\nthree\nfour\n");
2408
2409 // Only start line
2410 let content = thread
2411 .update(cx, |thread, cx| {
2412 thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2413 })
2414 .await
2415 .unwrap();
2416
2417 assert_eq!(content, "three\nfour\n");
2418
2419 // Only limit
2420 let content = thread
2421 .update(cx, |thread, cx| {
2422 thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2423 })
2424 .await
2425 .unwrap();
2426
2427 assert_eq!(content, "one\ntwo\n");
2428
2429 // Range
2430 let content = thread
2431 .update(cx, |thread, cx| {
2432 thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2433 })
2434 .await
2435 .unwrap();
2436
2437 assert_eq!(content, "two\nthree\n");
2438
2439 // Invalid
2440 let err = thread
2441 .update(cx, |thread, cx| {
2442 thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2443 })
2444 .await
2445 .unwrap_err();
2446
2447 assert_eq!(
2448 err.to_string(),
2449 "Attempting to read beyond the end of the file, line 5:0"
2450 );
2451 }
2452
2453 #[gpui::test]
2454 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2455 init_test(cx);
2456
2457 let fs = FakeFs::new(cx.executor());
2458 let project = Project::test(fs, [], cx).await;
2459 let id = acp::ToolCallId("test".into());
2460
2461 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2462 let id = id.clone();
2463 move |_, thread, mut cx| {
2464 let id = id.clone();
2465 async move {
2466 thread
2467 .update(&mut cx, |thread, cx| {
2468 thread.handle_session_update(
2469 acp::SessionUpdate::ToolCall(acp::ToolCall {
2470 id: id.clone(),
2471 title: "Label".into(),
2472 kind: acp::ToolKind::Fetch,
2473 status: acp::ToolCallStatus::InProgress,
2474 content: vec![],
2475 locations: vec![],
2476 raw_input: None,
2477 raw_output: None,
2478 }),
2479 cx,
2480 )
2481 })
2482 .unwrap()
2483 .unwrap();
2484 Ok(acp::PromptResponse {
2485 stop_reason: acp::StopReason::EndTurn,
2486 })
2487 }
2488 .boxed_local()
2489 }
2490 }));
2491
2492 let thread = cx
2493 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2494 .await
2495 .unwrap();
2496
2497 let request = thread.update(cx, |thread, cx| {
2498 thread.send_raw("Fetch https://example.com", cx)
2499 });
2500
2501 run_until_first_tool_call(&thread, cx).await;
2502
2503 thread.read_with(cx, |thread, _| {
2504 assert!(matches!(
2505 thread.entries[1],
2506 AgentThreadEntry::ToolCall(ToolCall {
2507 status: ToolCallStatus::InProgress,
2508 ..
2509 })
2510 ));
2511 });
2512
2513 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2514
2515 thread.read_with(cx, |thread, _| {
2516 assert!(matches!(
2517 &thread.entries[1],
2518 AgentThreadEntry::ToolCall(ToolCall {
2519 status: ToolCallStatus::Canceled,
2520 ..
2521 })
2522 ));
2523 });
2524
2525 thread
2526 .update(cx, |thread, cx| {
2527 thread.handle_session_update(
2528 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2529 id,
2530 fields: acp::ToolCallUpdateFields {
2531 status: Some(acp::ToolCallStatus::Completed),
2532 ..Default::default()
2533 },
2534 }),
2535 cx,
2536 )
2537 })
2538 .unwrap();
2539
2540 request.await.unwrap();
2541
2542 thread.read_with(cx, |thread, _| {
2543 assert!(matches!(
2544 thread.entries[1],
2545 AgentThreadEntry::ToolCall(ToolCall {
2546 status: ToolCallStatus::Completed,
2547 ..
2548 })
2549 ));
2550 });
2551 }
2552
2553 #[gpui::test]
2554 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2555 init_test(cx);
2556 let fs = FakeFs::new(cx.background_executor.clone());
2557 fs.insert_tree(path!("/test"), json!({})).await;
2558 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2559
2560 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2561 move |_, thread, mut cx| {
2562 async move {
2563 thread
2564 .update(&mut cx, |thread, cx| {
2565 thread.handle_session_update(
2566 acp::SessionUpdate::ToolCall(acp::ToolCall {
2567 id: acp::ToolCallId("test".into()),
2568 title: "Label".into(),
2569 kind: acp::ToolKind::Edit,
2570 status: acp::ToolCallStatus::Completed,
2571 content: vec![acp::ToolCallContent::Diff {
2572 diff: acp::Diff {
2573 path: "/test/test.txt".into(),
2574 old_text: None,
2575 new_text: "foo".into(),
2576 },
2577 }],
2578 locations: vec![],
2579 raw_input: None,
2580 raw_output: None,
2581 }),
2582 cx,
2583 )
2584 })
2585 .unwrap()
2586 .unwrap();
2587 Ok(acp::PromptResponse {
2588 stop_reason: acp::StopReason::EndTurn,
2589 })
2590 }
2591 .boxed_local()
2592 }
2593 }));
2594
2595 let thread = cx
2596 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2597 .await
2598 .unwrap();
2599
2600 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2601 .await
2602 .unwrap();
2603
2604 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2605 }
2606
/// Exercises per-message git checkpoints.
///
/// Each agent turn optionally creates a new file (controlled by
/// `simulate_changes`). A user message renders as `## User (checkpoint)`
/// when its turn changed files, and as plain `## User` when it did not.
/// Restoring the checkpoint of a later message truncates the thread from that
/// message onward and removes the files created from that point on.
///
/// Runs for 10 iterations (presumably to vary task scheduling; confirm
/// against the `gpui::test` macro docs).
#[gpui::test(iterations = 10)]
async fn test_checkpoints(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.background_executor.clone());
    // Presumably needed so the project has a git repository to checkpoint.
    fs.insert_tree(
        path!("/test"),
        json!({
            ".git": {}
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // While set, each agent turn writes a new `/test/file-N`, giving the turn
    // changes for its checkpoint to capture.
    let simulate_changes = Arc::new(AtomicBool::new(true));
    // Source of `N` for the files created above.
    let next_filename = Arc::new(AtomicUsize::new(0));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        let simulate_changes = simulate_changes.clone();
        let next_filename = next_filename.clone();
        let fs = fs.clone();
        move |request, thread, mut cx| {
            let fs = fs.clone();
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            async move {
                if simulate_changes.load(SeqCst) {
                    let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                    fs.write(Path::new(&filename), b"").await?;
                }

                // Reply by echoing the prompt text in upper case.
                let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                    panic!("expected text content block");
                };
                thread.update(&mut cx, |thread, cx| {
                    thread
                        .handle_session_update(
                            acp::SessionUpdate::AgentMessageChunk {
                                content: content.text.to_uppercase().into(),
                            },
                            cx,
                        )
                        .unwrap();
                })?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        }
    }));
    let thread = cx
        .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // First turn creates file-0, so its message shows a checkpoint.
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

            "}
        );
    });
    assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

    // Second turn creates file-1 and also shows a checkpoint.
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

            "}
        );
    });
    assert_eq!(
        fs.files(),
        vec![
            Path::new(path!("/test/file-0")),
            Path::new(path!("/test/file-1"))
        ]
    );

    // Checkpoint isn't stored when there are no changes.
    simulate_changes.store(false, SeqCst);
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

            "}
        );
    });
    assert_eq!(
        fs.files(),
        vec![
            Path::new(path!("/test/file-0")),
            Path::new(path!("/test/file-1"))
        ]
    );

    // Rewinding the conversation truncates the history and restores the checkpoint.
    thread
        .update(cx, |thread, cx| {
            // Entry 2 is the second user message ("ipsum").
            let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                panic!("unexpected entries {:?}", thread.entries)
            };
            thread.restore_checkpoint(message.id.clone().unwrap(), cx)
        })
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

            "}
        );
    });
    assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
}
2786
2787 #[gpui::test]
2788 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2789 use std::sync::atomic::AtomicUsize;
2790 init_test(cx);
2791
2792 let fs = FakeFs::new(cx.executor());
2793 let project = Project::test(fs, None, cx).await;
2794
2795 // Create a connection that simulates refusal after tool result
2796 let prompt_count = Arc::new(AtomicUsize::new(0));
2797 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2798 let prompt_count = prompt_count.clone();
2799 move |_request, thread, mut cx| {
2800 let count = prompt_count.fetch_add(1, SeqCst);
2801 async move {
2802 if count == 0 {
2803 // First prompt: Generate a tool call with result
2804 thread.update(&mut cx, |thread, cx| {
2805 thread
2806 .handle_session_update(
2807 acp::SessionUpdate::ToolCall(acp::ToolCall {
2808 id: acp::ToolCallId("tool1".into()),
2809 title: "Test Tool".into(),
2810 kind: acp::ToolKind::Fetch,
2811 status: acp::ToolCallStatus::Completed,
2812 content: vec![],
2813 locations: vec![],
2814 raw_input: Some(serde_json::json!({"query": "test"})),
2815 raw_output: Some(
2816 serde_json::json!({"result": "inappropriate content"}),
2817 ),
2818 }),
2819 cx,
2820 )
2821 .unwrap();
2822 })?;
2823
2824 // Now return refusal because of the tool result
2825 Ok(acp::PromptResponse {
2826 stop_reason: acp::StopReason::Refusal,
2827 })
2828 } else {
2829 Ok(acp::PromptResponse {
2830 stop_reason: acp::StopReason::EndTurn,
2831 })
2832 }
2833 }
2834 .boxed_local()
2835 }
2836 }));
2837
2838 let thread = cx
2839 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2840 .await
2841 .unwrap();
2842
2843 // Track if we see a Refusal event
2844 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2845 let saw_refusal_event_captured = saw_refusal_event.clone();
2846 thread.update(cx, |_thread, cx| {
2847 cx.subscribe(
2848 &thread,
2849 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2850 if matches!(event, AcpThreadEvent::Refusal) {
2851 *saw_refusal_event_captured.lock().unwrap() = true;
2852 }
2853 },
2854 )
2855 .detach();
2856 });
2857
2858 // Send a user message - this will trigger tool call and then refusal
2859 let send_task = thread.update(cx, |thread, cx| {
2860 thread.send(
2861 vec![acp::ContentBlock::Text(acp::TextContent {
2862 text: "Hello".into(),
2863 annotations: None,
2864 })],
2865 cx,
2866 )
2867 });
2868 cx.background_executor.spawn(send_task).detach();
2869 cx.run_until_parked();
2870
2871 // Verify that:
2872 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2873 // 2. The user message was NOT truncated
2874 assert!(
2875 *saw_refusal_event.lock().unwrap(),
2876 "Refusal event should be emitted for tool result refusals"
2877 );
2878
2879 thread.read_with(cx, |thread, _| {
2880 let entries = thread.entries();
2881 assert!(entries.len() >= 2, "Should have user message and tool call");
2882
2883 // Verify user message is still there
2884 assert!(
2885 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2886 "User message should not be truncated"
2887 );
2888
2889 // Verify tool call is there with result
2890 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2891 assert!(
2892 tool_call.raw_output.is_some(),
2893 "Tool call should have output"
2894 );
2895 } else {
2896 panic!("Expected tool call at index 1");
2897 }
2898 });
2899 }
2900
2901 #[gpui::test]
2902 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2903 init_test(cx);
2904
2905 let fs = FakeFs::new(cx.executor());
2906 let project = Project::test(fs, None, cx).await;
2907
2908 let refuse_next = Arc::new(AtomicBool::new(false));
2909 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2910 let refuse_next = refuse_next.clone();
2911 move |_request, _thread, _cx| {
2912 if refuse_next.load(SeqCst) {
2913 async move {
2914 Ok(acp::PromptResponse {
2915 stop_reason: acp::StopReason::Refusal,
2916 })
2917 }
2918 .boxed_local()
2919 } else {
2920 async move {
2921 Ok(acp::PromptResponse {
2922 stop_reason: acp::StopReason::EndTurn,
2923 })
2924 }
2925 .boxed_local()
2926 }
2927 }
2928 }));
2929
2930 let thread = cx
2931 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2932 .await
2933 .unwrap();
2934
2935 // Track if we see a Refusal event
2936 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2937 let saw_refusal_event_captured = saw_refusal_event.clone();
2938 thread.update(cx, |_thread, cx| {
2939 cx.subscribe(
2940 &thread,
2941 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2942 if matches!(event, AcpThreadEvent::Refusal) {
2943 *saw_refusal_event_captured.lock().unwrap() = true;
2944 }
2945 },
2946 )
2947 .detach();
2948 });
2949
2950 // Send a message that will be refused
2951 refuse_next.store(true, SeqCst);
2952 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2953 .await
2954 .unwrap();
2955
2956 // Verify that a Refusal event WAS emitted for user prompt refusal
2957 assert!(
2958 *saw_refusal_event.lock().unwrap(),
2959 "Refusal event should be emitted for user prompt refusals"
2960 );
2961
2962 // Verify the message was truncated (user prompt refusal)
2963 thread.read_with(cx, |thread, cx| {
2964 assert_eq!(thread.to_markdown(cx), "");
2965 });
2966 }
2967
2968 #[gpui::test]
2969 async fn test_refusal(cx: &mut TestAppContext) {
2970 init_test(cx);
2971 let fs = FakeFs::new(cx.background_executor.clone());
2972 fs.insert_tree(path!("/"), json!({})).await;
2973 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2974
2975 let refuse_next = Arc::new(AtomicBool::new(false));
2976 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2977 let refuse_next = refuse_next.clone();
2978 move |request, thread, mut cx| {
2979 let refuse_next = refuse_next.clone();
2980 async move {
2981 if refuse_next.load(SeqCst) {
2982 return Ok(acp::PromptResponse {
2983 stop_reason: acp::StopReason::Refusal,
2984 });
2985 }
2986
2987 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2988 panic!("expected text content block");
2989 };
2990 thread.update(&mut cx, |thread, cx| {
2991 thread
2992 .handle_session_update(
2993 acp::SessionUpdate::AgentMessageChunk {
2994 content: content.text.to_uppercase().into(),
2995 },
2996 cx,
2997 )
2998 .unwrap();
2999 })?;
3000 Ok(acp::PromptResponse {
3001 stop_reason: acp::StopReason::EndTurn,
3002 })
3003 }
3004 .boxed_local()
3005 }
3006 }));
3007 let thread = cx
3008 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3009 .await
3010 .unwrap();
3011
3012 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3013 .await
3014 .unwrap();
3015 thread.read_with(cx, |thread, cx| {
3016 assert_eq!(
3017 thread.to_markdown(cx),
3018 indoc! {"
3019 ## User
3020
3021 hello
3022
3023 ## Assistant
3024
3025 HELLO
3026
3027 "}
3028 );
3029 });
3030
3031 // Simulate refusing the second message. The message should be truncated
3032 // when a user prompt is refused.
3033 refuse_next.store(true, SeqCst);
3034 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3035 .await
3036 .unwrap();
3037 thread.read_with(cx, |thread, cx| {
3038 assert_eq!(
3039 thread.to_markdown(cx),
3040 indoc! {"
3041 ## User
3042
3043 hello
3044
3045 ## Assistant
3046
3047 HELLO
3048
3049 "}
3050 );
3051 });
3052 }
3053
3054 async fn run_until_first_tool_call(
3055 thread: &Entity<AcpThread>,
3056 cx: &mut TestAppContext,
3057 ) -> usize {
3058 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3059
3060 let subscription = cx.update(|cx| {
3061 cx.subscribe(thread, move |thread, _, cx| {
3062 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3063 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3064 return tx.try_send(ix).unwrap();
3065 }
3066 }
3067 })
3068 });
3069
3070 select! {
3071 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3072 panic!("Timeout waiting for tool call")
3073 }
3074 ix = rx.next().fuse() => {
3075 drop(subscription);
3076 ix.unwrap()
3077 }
3078 }
3079 }
3080
3081 #[derive(Clone, Default)]
3082 struct FakeAgentConnection {
3083 auth_methods: Vec<acp::AuthMethod>,
3084 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3085 on_user_message: Option<
3086 Rc<
3087 dyn Fn(
3088 acp::PromptRequest,
3089 WeakEntity<AcpThread>,
3090 AsyncApp,
3091 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3092 + 'static,
3093 >,
3094 >,
3095 }
3096
3097 impl FakeAgentConnection {
3098 fn new() -> Self {
3099 Self {
3100 auth_methods: Vec::new(),
3101 on_user_message: None,
3102 sessions: Arc::default(),
3103 }
3104 }
3105
3106 #[expect(unused)]
3107 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3108 self.auth_methods = auth_methods;
3109 self
3110 }
3111
3112 fn on_user_message(
3113 mut self,
3114 handler: impl Fn(
3115 acp::PromptRequest,
3116 WeakEntity<AcpThread>,
3117 AsyncApp,
3118 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3119 + 'static,
3120 ) -> Self {
3121 self.on_user_message.replace(Rc::new(handler));
3122 self
3123 }
3124 }
3125
3126 impl AgentConnection for FakeAgentConnection {
3127 fn auth_methods(&self) -> &[acp::AuthMethod] {
3128 &self.auth_methods
3129 }
3130
3131 fn new_thread(
3132 self: Rc<Self>,
3133 project: Entity<Project>,
3134 _cwd: &Path,
3135 cx: &mut App,
3136 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3137 let session_id = acp::SessionId(
3138 rand::rng()
3139 .sample_iter(&distr::Alphanumeric)
3140 .take(7)
3141 .map(char::from)
3142 .collect::<String>()
3143 .into(),
3144 );
3145 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3146 let thread = cx.new(|cx| {
3147 AcpThread::new(
3148 "Test",
3149 self.clone(),
3150 project,
3151 action_log,
3152 session_id.clone(),
3153 watch::Receiver::constant(acp::PromptCapabilities {
3154 image: true,
3155 audio: true,
3156 embedded_context: true,
3157 }),
3158 cx,
3159 )
3160 });
3161 self.sessions.lock().insert(session_id, thread.downgrade());
3162 Task::ready(Ok(thread))
3163 }
3164
3165 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3166 if self.auth_methods().iter().any(|m| m.id == method) {
3167 Task::ready(Ok(()))
3168 } else {
3169 Task::ready(Err(anyhow!("Invalid Auth Method")))
3170 }
3171 }
3172
3173 fn prompt(
3174 &self,
3175 _id: Option<UserMessageId>,
3176 params: acp::PromptRequest,
3177 cx: &mut App,
3178 ) -> Task<gpui::Result<acp::PromptResponse>> {
3179 let sessions = self.sessions.lock();
3180 let thread = sessions.get(¶ms.session_id).unwrap();
3181 if let Some(handler) = &self.on_user_message {
3182 let handler = handler.clone();
3183 let thread = thread.clone();
3184 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3185 } else {
3186 Task::ready(Ok(acp::PromptResponse {
3187 stop_reason: acp::StopReason::EndTurn,
3188 }))
3189 }
3190 }
3191
3192 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3193 let sessions = self.sessions.lock();
3194 let thread = sessions.get(session_id).unwrap().clone();
3195
3196 cx.spawn(async move |cx| {
3197 thread
3198 .update(cx, |thread, cx| thread.cancel(cx))
3199 .unwrap()
3200 .await
3201 })
3202 .detach();
3203 }
3204
3205 fn truncate(
3206 &self,
3207 session_id: &acp::SessionId,
3208 _cx: &App,
3209 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3210 Some(Rc::new(FakeAgentSessionEditor {
3211 _session_id: session_id.clone(),
3212 }))
3213 }
3214
3215 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3216 self
3217 }
3218 }
3219
/// Truncation handle returned by `FakeAgentConnection::truncate`; it does nothing.
struct FakeAgentSessionEditor {
    // Session this handle was created for; stored but never read.
    _session_id: acp::SessionId,
}

impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// Accepts any truncation request and immediately reports success.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
3229}