1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use futures::future::Shared;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_system_shell};
38use uuid::Uuid;
39
/// A user turn in the thread.
///
/// `content` is the merged, renderable form of every chunk received for this
/// message, while `chunks` keeps the raw protocol blocks in the order they
/// were appended.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one has been provided.
    pub id: Option<UserMessageId>,
    /// All received chunks merged into a single renderable block.
    pub content: ContentBlock,
    /// Raw protocol content blocks, in the order they were appended.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
47
/// A git checkpoint associated with a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; when set, the message's markdown
    /// export uses a `## User (checkpoint)` heading.
    pub show: bool,
}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
/// An agent response made up of message and thought chunks, kept in the order
/// they were pushed.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    pub chunks: Vec<AssistantMessageChunk>,
}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
/// One contiguous run of assistant output.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular response text.
    Message { block: ContentBlock },
    /// Thought output; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
/// A single entry in a thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A response from the agent, made of message and thought chunks.
    AssistantMessage(AssistantMessage),
    /// A tool call reported by the agent.
    ToolCall(ToolCall),
}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
/// A tool call shown in the thread, identified by its protocol id.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// Single-line title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Renderable output of the call: content blocks, diffs and terminals.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the protocol.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer locations resolved from `locations`, index-aligned with it;
    /// `None` where a location could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input reported for the call.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output reported for the call.
    pub raw_output: Option<serde_json::Value>,
}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
221 fn update_fields(
222 &mut self,
223 fields: acp::ToolCallUpdateFields,
224 language_registry: Arc<LanguageRegistry>,
225 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
226 cx: &mut App,
227 ) -> Result<()> {
228 let acp::ToolCallUpdateFields {
229 kind,
230 status,
231 title,
232 content,
233 locations,
234 raw_input,
235 raw_output,
236 } = fields;
237
238 if let Some(kind) = kind {
239 self.kind = kind;
240 }
241
242 if let Some(status) = status {
243 self.status = status.into();
244 }
245
246 if let Some(title) = title {
247 self.label.update(cx, |label, cx| {
248 if let Some((first_line, _)) = title.split_once("\n") {
249 label.replace(first_line.to_owned() + "…", cx)
250 } else {
251 label.replace(title, cx);
252 }
253 });
254 }
255
256 if let Some(content) = content {
257 let new_content_len = content.len();
258 let mut content = content.into_iter();
259
260 // Reuse existing content if we can
261 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
262 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
263 }
264 for new in content {
265 self.content.push(ToolCallContent::from_acp(
266 new,
267 language_registry.clone(),
268 terminals,
269 cx,
270 )?)
271 }
272 self.content.truncate(new_content_len);
273 }
274
275 if let Some(locations) = locations {
276 self.locations = locations;
277 }
278
279 if let Some(raw_input) = raw_input {
280 self.raw_input = Some(raw_input);
281 }
282
283 if let Some(raw_output) = raw_output {
284 if self.content.is_empty()
285 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
286 {
287 self.content
288 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
289 markdown,
290 }));
291 }
292 self.raw_output = Some(raw_output);
293 }
294 Ok(())
295 }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
/// Lifecycle state of a tool call as tracked by the thread.
///
/// `Pending`, `InProgress`, `Completed` and `Failed` are what the
/// `From<acp::ToolCallStatus>` conversion produces. The remaining variants
/// have no counterpart in that conversion.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but it is already shown to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// One-shot sender for delivering a `PermissionOptionId` response.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation, so the tool call was canceled.
    Canceled,
}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
/// Renderable content accumulated from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Appended content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A resource link that was the first block appended; any later append
    /// turns the block into `Markdown`.
    ResourceLink { resource_link: acp::ResourceLink },
}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
457 pub fn append(
458 &mut self,
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) {
463 if matches!(self, ContentBlock::Empty)
464 && let acp::ContentBlock::ResourceLink(resource_link) = block
465 {
466 *self = ContentBlock::ResourceLink { resource_link };
467 return;
468 }
469
470 let new_content = self.block_string_contents(block);
471
472 match self {
473 ContentBlock::Empty => {
474 *self = Self::create_markdown_block(new_content, language_registry, cx);
475 }
476 ContentBlock::Markdown { markdown } => {
477 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
478 }
479 ContentBlock::ResourceLink { resource_link } => {
480 let existing_content = Self::resource_link_md(&resource_link.uri);
481 let combined = format!("{}\n{}", existing_content, new_content);
482
483 *self = Self::create_markdown_block(combined, language_registry, cx);
484 }
485 }
486 }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
499 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
500 match block {
501 acp::ContentBlock::Text(text_content) => text_content.text,
502 acp::ContentBlock::ResourceLink(resource_link) => {
503 Self::resource_link_md(&resource_link.uri)
504 }
505 acp::ContentBlock::Resource(acp::EmbeddedResource {
506 resource:
507 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
508 uri,
509 ..
510 }),
511 ..
512 }) => Self::resource_link_md(&uri),
513 acp::ContentBlock::Image(image) => Self::image_md(&image),
514 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
515 }
516 }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
/// A single item of tool call output.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Generic content such as text, images or links.
    ContentBlock(ContentBlock),
    /// A diff of a file.
    Diff(Entity<Diff>),
    /// A terminal registered with the thread, looked up by id.
    Terminal(Entity<Terminal>),
}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path,
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
/// An update aimed at an existing tool call, applied by
/// `AcpThread::update_tool_call`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Partial field update from the protocol.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
/// Replaces the content of tool call `id` with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
/// Replaces the content of tool call `id` with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
670
/// The agent's plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
675
/// Summary of a [`Plan`], produced by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
/// A single plan step.
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description, rendered as markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
/// Token consumption for a thread, measured against a maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` when unknown (e.g. no model is selected).
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
/// How close token usage is to its maximum; see [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown.
    Normal,
    /// At or above the warning threshold, but below the maximum.
    Warning,
    /// At or above the maximum.
    Exceeded,
}
766
/// Retry progress, reported through [`AcpThreadEvent::Retry`].
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the most recent error.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm.
    pub duration: Duration,
}
775
/// State of a single agent session.
///
/// Holds the transcript entries, the plan, token usage, and the connection
/// used to talk to the agent.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries, in the order they were pushed.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in this chunk; presumably snapshots of buffers
    // shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest capabilities received on the watch channel passed to `new`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Handle for the task that mirrors capability updates, held for the
    /// lifetime of the thread.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Shared background task that resolves which shell to use.
    determine_shell: Shared<Task<String>>,
    /// Terminals looked up when tool call content refers to a terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
792
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// Token usage was replaced.
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    /// A retry is in progress.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// New prompt capabilities arrived on the watch channel.
    PromptCapabilitiesUpdated,
    Refusal,
    /// The agent reported a new set of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// The agent reported a change of the current session mode.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
812
/// Coarse activity state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and `waiting_for_tool_confirmation()` is true.
    WaitingForToolConfirmation,
    /// A send task is running and nothing is waiting for confirmation.
    Generating,
}
819
/// Reasons an agent could not be loaded, each with a human-readable
/// `Display` message.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is below the required minimum.
    Unsupported {
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; carries the error message.
    FailedToInstall(SharedString),
    /// The server exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by its message.
    Other(SharedString),
}
833
834impl Display for LoadError {
835 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
836 match self {
837 LoadError::Unsupported {
838 command: path,
839 current_version,
840 minimum_version,
841 } => {
842 write!(
843 f,
844 "version {current_version} from {path} is not supported (need at least {minimum_version})"
845 )
846 }
847 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
848 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
849 LoadError::Other(msg) => write!(f, "{msg}"),
850 }
851 }
852}
853
854impl Error for LoadError {}
855
856impl AcpThread {
    /// Creates a thread for `session_id` that talks to the agent over
    /// `connection`.
    ///
    /// The initial prompt capabilities are read from `prompt_capabilities_rx`.
    /// A task is spawned that stores later values and emits
    /// [`AcpThreadEvent::PromptCapabilitiesUpdated`] for each one. Shell
    /// detection runs on a background task whose result is shared.
    pub fn new(
        title: impl Into<SharedString>,
        connection: Rc<dyn AgentConnection>,
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        session_id: acp::SessionId,
        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
        cx: &mut Context<Self>,
    ) -> Self {
        let prompt_capabilities = *prompt_capabilities_rx.borrow();
        // Mirror capability updates into the thread. The loop only exits via
        // `?`: when `recv` fails or the thread can no longer be updated.
        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
            loop {
                let caps = prompt_capabilities_rx.recv().await?;
                this.update(cx, |this, cx| {
                    this.prompt_capabilities = caps;
                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
                })?;
            }
        });

        // Choose a shell: the system shell on Windows; elsewhere `bash` when
        // `which` finds it, falling back to the system shell.
        let determine_shell = cx
            .background_spawn(async move {
                if cfg!(windows) {
                    return get_system_shell();
                }

                if which::which("bash").is_ok() {
                    "bash".into()
                } else {
                    get_system_shell()
                }
            })
            .shared();

        Self {
            action_log,
            shared_buffers: Default::default(),
            entries: Default::default(),
            plan: Default::default(),
            title: title.into(),
            project,
            send_task: None,
            connection,
            session_id,
            token_usage: None,
            prompt_capabilities,
            _observe_prompt_capabilities: task,
            terminals: HashMap::default(),
            determine_shell,
        }
    }
908
909 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
910 self.prompt_capabilities
911 }
912
913 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
914 &self.connection
915 }
916
917 pub fn action_log(&self) -> &Entity<ActionLog> {
918 &self.action_log
919 }
920
921 pub fn project(&self) -> &Entity<Project> {
922 &self.project
923 }
924
925 pub fn title(&self) -> SharedString {
926 self.title.clone()
927 }
928
929 pub fn entries(&self) -> &[AgentThreadEntry] {
930 &self.entries
931 }
932
933 pub fn session_id(&self) -> &acp::SessionId {
934 &self.session_id
935 }
936
937 pub fn status(&self) -> ThreadStatus {
938 if self.send_task.is_some() {
939 if self.waiting_for_tool_confirmation() {
940 ThreadStatus::WaitingForToolConfirmation
941 } else {
942 ThreadStatus::Generating
943 }
944 } else {
945 ThreadStatus::Idle
946 }
947 }
948
949 pub fn token_usage(&self) -> Option<&TokenUsage> {
950 self.token_usage.as_ref()
951 }
952
953 pub fn has_pending_edit_tool_calls(&self) -> bool {
954 for entry in self.entries.iter().rev() {
955 match entry {
956 AgentThreadEntry::UserMessage(_) => return false,
957 AgentThreadEntry::ToolCall(
958 call @ ToolCall {
959 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
960 ..
961 },
962 ) if call.diffs().next().is_some() => {
963 return true;
964 }
965 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
966 }
967 }
968
969 false
970 }
971
972 pub fn used_tools_since_last_user_message(&self) -> bool {
973 for entry in self.entries.iter().rev() {
974 match entry {
975 AgentThreadEntry::UserMessage(..) => return false,
976 AgentThreadEntry::AssistantMessage(..) => continue,
977 AgentThreadEntry::ToolCall(..) => return true,
978 }
979 }
980
981 false
982 }
983
984 pub fn handle_session_update(
985 &mut self,
986 update: acp::SessionUpdate,
987 cx: &mut Context<Self>,
988 ) -> Result<(), acp::Error> {
989 match update {
990 acp::SessionUpdate::UserMessageChunk { content } => {
991 self.push_user_content_block(None, content, cx);
992 }
993 acp::SessionUpdate::AgentMessageChunk { content } => {
994 self.push_assistant_content_block(content, false, cx);
995 }
996 acp::SessionUpdate::AgentThoughtChunk { content } => {
997 self.push_assistant_content_block(content, true, cx);
998 }
999 acp::SessionUpdate::ToolCall(tool_call) => {
1000 self.upsert_tool_call(tool_call, cx)?;
1001 }
1002 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1003 self.update_tool_call(tool_call_update, cx)?;
1004 }
1005 acp::SessionUpdate::Plan(plan) => {
1006 self.update_plan(plan, cx);
1007 }
1008 acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
1009 cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
1010 }
1011 acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
1012 cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
1013 }
1014 }
1015 Ok(())
1016 }
1017
1018 pub fn push_user_content_block(
1019 &mut self,
1020 message_id: Option<UserMessageId>,
1021 chunk: acp::ContentBlock,
1022 cx: &mut Context<Self>,
1023 ) {
1024 let language_registry = self.project.read(cx).languages().clone();
1025 let entries_len = self.entries.len();
1026
1027 if let Some(last_entry) = self.entries.last_mut()
1028 && let AgentThreadEntry::UserMessage(UserMessage {
1029 id,
1030 content,
1031 chunks,
1032 ..
1033 }) = last_entry
1034 {
1035 *id = message_id.or(id.take());
1036 content.append(chunk.clone(), &language_registry, cx);
1037 chunks.push(chunk);
1038 let idx = entries_len - 1;
1039 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1040 } else {
1041 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1042 self.push_entry(
1043 AgentThreadEntry::UserMessage(UserMessage {
1044 id: message_id,
1045 content,
1046 chunks: vec![chunk],
1047 checkpoint: None,
1048 }),
1049 cx,
1050 );
1051 }
1052 }
1053
1054 pub fn push_assistant_content_block(
1055 &mut self,
1056 chunk: acp::ContentBlock,
1057 is_thought: bool,
1058 cx: &mut Context<Self>,
1059 ) {
1060 let language_registry = self.project.read(cx).languages().clone();
1061 let entries_len = self.entries.len();
1062 if let Some(last_entry) = self.entries.last_mut()
1063 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1064 {
1065 let idx = entries_len - 1;
1066 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1067 match (chunks.last_mut(), is_thought) {
1068 (Some(AssistantMessageChunk::Message { block }), false)
1069 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1070 block.append(chunk, &language_registry, cx)
1071 }
1072 _ => {
1073 let block = ContentBlock::new(chunk, &language_registry, cx);
1074 if is_thought {
1075 chunks.push(AssistantMessageChunk::Thought { block })
1076 } else {
1077 chunks.push(AssistantMessageChunk::Message { block })
1078 }
1079 }
1080 }
1081 } else {
1082 let block = ContentBlock::new(chunk, &language_registry, cx);
1083 let chunk = if is_thought {
1084 AssistantMessageChunk::Thought { block }
1085 } else {
1086 AssistantMessageChunk::Message { block }
1087 };
1088
1089 self.push_entry(
1090 AgentThreadEntry::AssistantMessage(AssistantMessage {
1091 chunks: vec![chunk],
1092 }),
1093 cx,
1094 );
1095 }
1096 }
1097
1098 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1099 self.entries.push(entry);
1100 cx.emit(AcpThreadEvent::NewEntry);
1101 }
1102
1103 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1104 self.connection.set_title(&self.session_id, cx).is_some()
1105 }
1106
1107 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1108 if title != self.title {
1109 self.title = title.clone();
1110 cx.emit(AcpThreadEvent::TitleUpdated);
1111 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1112 return set_title.run(title, cx);
1113 }
1114 }
1115 Task::ready(Ok(()))
1116 }
1117
1118 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1119 self.token_usage = usage;
1120 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1121 }
1122
1123 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1124 cx.emit(AcpThreadEvent::Retry(status));
1125 }
1126
1127 pub fn update_tool_call(
1128 &mut self,
1129 update: impl Into<ToolCallUpdate>,
1130 cx: &mut Context<Self>,
1131 ) -> Result<()> {
1132 let update = update.into();
1133 let languages = self.project.read(cx).languages().clone();
1134
1135 let ix = self
1136 .index_for_tool_call(update.id())
1137 .context("Tool call not found")?;
1138 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1139 unreachable!()
1140 };
1141
1142 match update {
1143 ToolCallUpdate::UpdateFields(update) => {
1144 let location_updated = update.fields.locations.is_some();
1145 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1146 if location_updated {
1147 self.resolve_locations(update.id, cx);
1148 }
1149 }
1150 ToolCallUpdate::UpdateDiff(update) => {
1151 call.content.clear();
1152 call.content.push(ToolCallContent::Diff(update.diff));
1153 }
1154 ToolCallUpdate::UpdateTerminal(update) => {
1155 call.content.clear();
1156 call.content
1157 .push(ToolCallContent::Terminal(update.terminal));
1158 }
1159 }
1160
1161 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1162
1163 Ok(())
1164 }
1165
1166 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1167 pub fn upsert_tool_call(
1168 &mut self,
1169 tool_call: acp::ToolCall,
1170 cx: &mut Context<Self>,
1171 ) -> Result<(), acp::Error> {
1172 let status = tool_call.status.into();
1173 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1174 }
1175
1176 /// Fails if id does not match an existing entry.
1177 pub fn upsert_tool_call_inner(
1178 &mut self,
1179 update: acp::ToolCallUpdate,
1180 status: ToolCallStatus,
1181 cx: &mut Context<Self>,
1182 ) -> Result<(), acp::Error> {
1183 let language_registry = self.project.read(cx).languages().clone();
1184 let id = update.id.clone();
1185
1186 if let Some(ix) = self.index_for_tool_call(&id) {
1187 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1188 unreachable!()
1189 };
1190
1191 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1192 call.status = status;
1193
1194 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1195 } else {
1196 let call = ToolCall::from_acp(
1197 update.try_into()?,
1198 status,
1199 language_registry,
1200 &self.terminals,
1201 cx,
1202 )?;
1203 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1204 };
1205
1206 self.resolve_locations(id, cx);
1207 Ok(())
1208 }
1209
1210 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1211 self.entries
1212 .iter()
1213 .enumerate()
1214 .rev()
1215 .find_map(|(index, entry)| {
1216 if let AgentThreadEntry::ToolCall(tool_call) = entry
1217 && &tool_call.id == id
1218 {
1219 Some(index)
1220 } else {
1221 None
1222 }
1223 })
1224 }
1225
1226 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1227 // The tool call we are looking for is typically the last one, or very close to the end.
1228 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1229 self.entries
1230 .iter_mut()
1231 .enumerate()
1232 .rev()
1233 .find_map(|(index, tool_call)| {
1234 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1235 && &tool_call.id == id
1236 {
1237 Some((index, tool_call))
1238 } else {
1239 None
1240 }
1241 })
1242 }
1243
1244 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1245 self.entries
1246 .iter()
1247 .enumerate()
1248 .rev()
1249 .find_map(|(index, tool_call)| {
1250 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1251 && &tool_call.id == id
1252 {
1253 Some((index, tool_call))
1254 } else {
1255 None
1256 }
1257 })
1258 }
1259
    /// Re-resolves the locations of the tool call with `id` in the background.
    ///
    /// Returns immediately if no such tool call exists. When the resolution task
    /// finishes, the tool call is looked up again.
    ///
    /// - If the last resolved location is `Some` and the project already has an agent
    ///   location, the agent location is moved to it. The exception is a backwards jump
    ///   within the same row of the same buffer.
    /// - The resolved locations are stored on the tool call, and `EntryUpdated` is emitted
    ///   only when they changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again: the entry may have moved or been removed
                // while locations were resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        // The agent location is only moved when one is already set.
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of the line
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    // If the buffer is gone, don't ignore the new location.
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                // Only notify observers when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1305
1306 pub fn request_tool_call_authorization(
1307 &mut self,
1308 tool_call: acp::ToolCallUpdate,
1309 options: Vec<acp::PermissionOption>,
1310 respect_always_allow_setting: bool,
1311 cx: &mut Context<Self>,
1312 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1313 let (tx, rx) = oneshot::channel();
1314
1315 if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1316 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1317 // some tools would (incorrectly) continue to auto-accept.
1318 if let Some(allow_once_option) = options.iter().find_map(|option| {
1319 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1320 Some(option.id.clone())
1321 } else {
1322 None
1323 }
1324 }) {
1325 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1326 return Ok(async {
1327 acp::RequestPermissionOutcome::Selected {
1328 option_id: allow_once_option,
1329 }
1330 }
1331 .boxed());
1332 }
1333 }
1334
1335 let status = ToolCallStatus::WaitingForConfirmation {
1336 options,
1337 respond_tx: tx,
1338 };
1339
1340 self.upsert_tool_call_inner(tool_call, status, cx)?;
1341 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1342
1343 let fut = async {
1344 match rx.await {
1345 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1346 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1347 }
1348 }
1349 .boxed();
1350
1351 Ok(fut)
1352 }
1353
1354 pub fn authorize_tool_call(
1355 &mut self,
1356 id: acp::ToolCallId,
1357 option_id: acp::PermissionOptionId,
1358 option_kind: acp::PermissionOptionKind,
1359 cx: &mut Context<Self>,
1360 ) {
1361 let Some((ix, call)) = self.tool_call_mut(&id) else {
1362 return;
1363 };
1364
1365 let new_status = match option_kind {
1366 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1367 ToolCallStatus::Rejected
1368 }
1369 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1370 ToolCallStatus::InProgress
1371 }
1372 };
1373
1374 let curr_status = mem::replace(&mut call.status, new_status);
1375
1376 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1377 respond_tx.send(option_id).log_err();
1378 } else if cfg!(debug_assertions) {
1379 panic!("tried to authorize an already authorized tool call");
1380 }
1381
1382 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1383 }
1384
1385 /// Returns true if the last turn is awaiting tool authorization
1386 pub fn waiting_for_tool_confirmation(&self) -> bool {
1387 for entry in self.entries.iter().rev() {
1388 match &entry {
1389 AgentThreadEntry::ToolCall(call) => match call.status {
1390 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1391 ToolCallStatus::Pending
1392 | ToolCallStatus::InProgress
1393 | ToolCallStatus::Completed
1394 | ToolCallStatus::Failed
1395 | ToolCallStatus::Rejected
1396 | ToolCallStatus::Canceled => continue,
1397 },
1398 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1399 // Reached the beginning of the turn
1400 return false;
1401 }
1402 }
1403 }
1404 false
1405 }
1406
1407 pub fn plan(&self) -> &Plan {
1408 &self.plan
1409 }
1410
1411 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1412 let new_entries_len = request.entries.len();
1413 let mut new_entries = request.entries.into_iter();
1414
1415 // Reuse existing markdown to prevent flickering
1416 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1417 let PlanEntry {
1418 content,
1419 priority,
1420 status,
1421 } = old;
1422 content.update(cx, |old, cx| {
1423 old.replace(new.content, cx);
1424 });
1425 *priority = new.priority;
1426 *status = new.status;
1427 }
1428 for new in new_entries {
1429 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1430 }
1431 self.plan.entries.truncate(new_entries_len);
1432
1433 cx.notify();
1434 }
1435
1436 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1437 self.plan
1438 .entries
1439 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1440 cx.notify();
1441 }
1442
1443 #[cfg(any(test, feature = "test-support"))]
1444 pub fn send_raw(
1445 &mut self,
1446 message: &str,
1447 cx: &mut Context<Self>,
1448 ) -> BoxFuture<'static, Result<()>> {
1449 self.send(
1450 vec![acp::ContentBlock::Text(acp::TextContent {
1451 text: message.to_string(),
1452 annotations: None,
1453 })],
1454 cx,
1455 )
1456 }
1457
    /// Sends `message` as a new user prompt and runs the resulting turn via `run_turn`,
    /// which first cancels any in-flight turn.
    ///
    /// Inside the turn the user message is pushed, and a git checkpoint of the project is
    /// taken and attached to the last user message with `show: false`. The prompt is then
    /// forwarded to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        // Combined content stored as the user message's `content`.
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only assigned when the connection supports truncation, which is
        // what rewinding to a user message relies on.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // If the thread is already gone this is ignored; the `this.update` below fails.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the working tree before the agent acts. A failure is logged and the
            // turn proceeds without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1511
1512 pub fn can_resume(&self, cx: &App) -> bool {
1513 self.connection.resume(&self.session_id, cx).is_some()
1514 }
1515
1516 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1517 self.run_turn(cx, async move |this, cx| {
1518 this.update(cx, |this, cx| {
1519 this.connection
1520 .resume(&this.session_id, cx)
1521 .map(|resume| resume.run(cx))
1522 })?
1523 .context("resuming a session is not supported")?
1524 .await
1525 })
1526 }
1527
    /// Runs a single turn driven by `f`. Returns a future that resolves once the turn's
    /// outcome has been applied to the thread.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight turn is
    /// canceled and awaited. `f` runs inside `send_task`, and its result comes back over a
    /// oneshot channel. Afterwards:
    ///
    /// - The last user message's checkpoint visibility is refreshed.
    /// - The project's agent location is cleared.
    /// - An agent error emits `Error` and is returned.
    /// - Otherwise `Stopped` is emitted, preceded by refusal handling when the stop reason
    ///   is `Refusal`.
    ///
    /// NOTE(review): if `update_last_checkpoint` fails, this returns early without taking
    /// `send_task` or emitting `Stopped`/`Error`. Confirm that is intended.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        // Cancel the previous turn (if any); its completion is awaited below.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start this turn until the previous one has wound down.
            cancel_task.await;
            // The receiver may already be gone; nothing to do in that case.
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` means `send_task` was dropped before `f` produced a result.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1621
1622 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1623 let Some(send_task) = self.send_task.take() else {
1624 return Task::ready(());
1625 };
1626
1627 for entry in self.entries.iter_mut() {
1628 if let AgentThreadEntry::ToolCall(call) = entry {
1629 let cancel = matches!(
1630 call.status,
1631 ToolCallStatus::Pending
1632 | ToolCallStatus::WaitingForConfirmation { .. }
1633 | ToolCallStatus::InProgress
1634 );
1635
1636 if cancel {
1637 call.status = ToolCallStatus::Canceled;
1638 }
1639 }
1640 }
1641
1642 self.connection.cancel(&self.session_id, cx);
1643
1644 // Wait for the send task to complete
1645 cx.foreground_executor().spawn(send_task)
1646 }
1647
1648 /// Restores the git working tree to the state at the given checkpoint (if one exists)
1649 pub fn restore_checkpoint(
1650 &mut self,
1651 id: UserMessageId,
1652 cx: &mut Context<Self>,
1653 ) -> Task<Result<()>> {
1654 let Some((_, message)) = self.user_message_mut(&id) else {
1655 return Task::ready(Err(anyhow!("message not found")));
1656 };
1657
1658 let checkpoint = message
1659 .checkpoint
1660 .as_ref()
1661 .map(|c| c.git_checkpoint.clone());
1662 let rewind = self.rewind(id.clone(), cx);
1663 let git_store = self.project.read(cx).git_store().clone();
1664
1665 cx.spawn(async move |_, cx| {
1666 rewind.await?;
1667 if let Some(checkpoint) = checkpoint {
1668 git_store
1669 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1670 .await?;
1671 }
1672
1673 Ok(())
1674 })
1675 }
1676
1677 /// Rewinds this thread to before the entry at `index`, removing it and all
1678 /// subsequent entries while rejecting any action_log changes made from that point.
1679 /// Unlike `restore_checkpoint`, this method does not restore from git.
1680 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1681 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1682 return Task::ready(Err(anyhow!("not supported")));
1683 };
1684
1685 cx.spawn(async move |this, cx| {
1686 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1687 this.update(cx, |this, cx| {
1688 if let Some((ix, _)) = this.user_message_mut(&id) {
1689 let range = ix..this.entries.len();
1690 this.entries.truncate(ix);
1691 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1692 }
1693 this.action_log()
1694 .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1695 })?
1696 .await;
1697 Ok(())
1698 })
1699 }
1700
1701 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1702 let git_store = self.project.read(cx).git_store().clone();
1703
1704 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1705 if let Some(checkpoint) = message.checkpoint.as_ref() {
1706 checkpoint.git_checkpoint.clone()
1707 } else {
1708 return Task::ready(Ok(()));
1709 }
1710 } else {
1711 return Task::ready(Ok(()));
1712 };
1713
1714 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1715 cx.spawn(async move |this, cx| {
1716 let new_checkpoint = new_checkpoint
1717 .await
1718 .context("failed to get new checkpoint")
1719 .log_err();
1720 if let Some(new_checkpoint) = new_checkpoint {
1721 let equal = git_store
1722 .update(cx, |git, cx| {
1723 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1724 })?
1725 .await
1726 .unwrap_or(true);
1727 this.update(cx, |this, cx| {
1728 let (ix, message) = this.last_user_message().context("no user message")?;
1729 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1730 checkpoint.show = !equal;
1731 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1732 anyhow::Ok(())
1733 })??;
1734 }
1735
1736 Ok(())
1737 })
1738 }
1739
1740 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1741 self.entries
1742 .iter_mut()
1743 .enumerate()
1744 .rev()
1745 .find_map(|(ix, entry)| {
1746 if let AgentThreadEntry::UserMessage(message) = entry {
1747 Some((ix, message))
1748 } else {
1749 None
1750 }
1751 })
1752 }
1753
1754 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1755 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1756 if let AgentThreadEntry::UserMessage(message) = entry {
1757 if message.id.as_ref() == Some(id) {
1758 Some((ix, message))
1759 } else {
1760 None
1761 }
1762 } else {
1763 None
1764 }
1765 })
1766 }
1767
1768 pub fn read_text_file(
1769 &self,
1770 path: PathBuf,
1771 line: Option<u32>,
1772 limit: Option<u32>,
1773 reuse_shared_snapshot: bool,
1774 cx: &mut Context<Self>,
1775 ) -> Task<Result<String>> {
1776 let project = self.project.clone();
1777 let action_log = self.action_log.clone();
1778 cx.spawn(async move |this, cx| {
1779 let load = project.update(cx, |project, cx| {
1780 let path = project
1781 .project_path_for_absolute_path(&path, cx)
1782 .context("invalid path")?;
1783 anyhow::Ok(project.open_buffer(path, cx))
1784 });
1785 let buffer = load??.await?;
1786
1787 let snapshot = if reuse_shared_snapshot {
1788 this.read_with(cx, |this, _| {
1789 this.shared_buffers.get(&buffer.clone()).cloned()
1790 })
1791 .log_err()
1792 .flatten()
1793 } else {
1794 None
1795 };
1796
1797 let snapshot = if let Some(snapshot) = snapshot {
1798 snapshot
1799 } else {
1800 action_log.update(cx, |action_log, cx| {
1801 action_log.buffer_read(buffer.clone(), cx);
1802 })?;
1803 project.update(cx, |project, cx| {
1804 let position = buffer
1805 .read(cx)
1806 .snapshot()
1807 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1808 project.set_agent_location(
1809 Some(AgentLocation {
1810 buffer: buffer.downgrade(),
1811 position,
1812 }),
1813 cx,
1814 );
1815 })?;
1816
1817 buffer.update(cx, |buffer, _| buffer.snapshot())?
1818 };
1819
1820 this.update(cx, |this, _| {
1821 let text = snapshot.text();
1822 this.shared_buffers.insert(buffer.clone(), snapshot);
1823 if line.is_none() && limit.is_none() {
1824 return Ok(text);
1825 }
1826 let limit = limit.unwrap_or(u32::MAX) as usize;
1827 let Some(line) = line else {
1828 return Ok(text.lines().take(limit).collect::<String>());
1829 };
1830
1831 let count = text.lines().count();
1832 if count < line as usize {
1833 anyhow::bail!("There are only {} lines", count);
1834 }
1835 Ok(text
1836 .lines()
1837 .skip(line as usize + 1)
1838 .take(limit)
1839 .collect::<String>())
1840 })?
1841 })
1842 }
1843
    /// Replaces the contents of the file at the absolute `path` with `content` and saves it.
    ///
    /// The edit is computed as a text diff against the snapshot last stored in
    /// `shared_buffers` for this buffer, falling back to the current buffer snapshot. Edit
    /// ranges are anchored in that snapshot, so edits made to the buffer after it was taken
    /// are preserved.
    ///
    /// Steps, in order:
    /// 1. The agent location is moved to the end of the last edit, or the buffer start when
    ///    there are no edits.
    /// 2. The action log records a read before the edit is applied and an edit afterwards.
    /// 3. If format-on-save is not `Off` for the buffer, the buffer is formatted (errors
    ///    are logged) and recorded as edited again.
    /// 4. The buffer is saved.
    ///
    /// # Errors
    /// Fails if `path` is not inside the project ("invalid path"), the buffer cannot be
    /// opened, the thread or project is gone, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, not necessarily the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute anchored edits off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Record a read first so the action log tracks the edit that follows.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged and the save still happens.
                format_task.await.log_err();

                // Formatting changed the buffer again; record that edit too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1940
1941 pub fn create_terminal(
1942 &self,
1943 mut command: String,
1944 args: Vec<String>,
1945 extra_env: Vec<acp::EnvVariable>,
1946 cwd: Option<PathBuf>,
1947 output_byte_limit: Option<u64>,
1948 cx: &mut Context<Self>,
1949 ) -> Task<Result<Entity<Terminal>>> {
1950 for arg in args {
1951 command.push(' ');
1952 command.push_str(&arg);
1953 }
1954
1955 let shell_command = if cfg!(windows) {
1956 format!("$null | & {{{}}}", command.replace("\"", "'"))
1957 } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1958 // Make sure once we're *inside* the shell, we cd into `cwd`
1959 format!("(cd {cwd}; {}) </dev/null", command)
1960 } else {
1961 format!("({}) </dev/null", command)
1962 };
1963 let args = vec!["-c".into(), shell_command];
1964
1965 let env = match &cwd {
1966 Some(dir) => self.project.update(cx, |project, cx| {
1967 project.directory_environment(dir.as_path().into(), cx)
1968 }),
1969 None => Task::ready(None).shared(),
1970 };
1971
1972 let env = cx.spawn(async move |_, _| {
1973 let mut env = env.await.unwrap_or_default();
1974 if cfg!(unix) {
1975 env.insert("PAGER".into(), "cat".into());
1976 }
1977 for var in extra_env {
1978 env.insert(var.name, var.value);
1979 }
1980 env
1981 });
1982
1983 let project = self.project.clone();
1984 let language_registry = project.read(cx).languages().clone();
1985 let determine_shell = self.determine_shell.clone();
1986
1987 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1988 let terminal_task = cx.spawn({
1989 let terminal_id = terminal_id.clone();
1990 async move |_this, cx| {
1991 let program = determine_shell.await;
1992 let env = env.await;
1993 let terminal = project
1994 .update(cx, |project, cx| {
1995 project.create_terminal_task(
1996 task::SpawnInTerminal {
1997 command: Some(program),
1998 args,
1999 cwd: cwd.clone(),
2000 env,
2001 ..Default::default()
2002 },
2003 cx,
2004 )
2005 })?
2006 .await?;
2007
2008 cx.new(|cx| {
2009 Terminal::new(
2010 terminal_id,
2011 command,
2012 cwd,
2013 output_byte_limit.map(|l| l as usize),
2014 terminal,
2015 language_registry,
2016 cx,
2017 )
2018 })
2019 }
2020 });
2021
2022 cx.spawn(async move |this, cx| {
2023 let terminal = terminal_task.await?;
2024 this.update(cx, |this, _cx| {
2025 this.terminals.insert(terminal_id, terminal.clone());
2026 terminal
2027 })
2028 })
2029 }
2030
2031 pub fn kill_terminal(
2032 &mut self,
2033 terminal_id: acp::TerminalId,
2034 cx: &mut Context<Self>,
2035 ) -> Result<()> {
2036 self.terminals
2037 .get(&terminal_id)
2038 .context("Terminal not found")?
2039 .update(cx, |terminal, cx| {
2040 terminal.kill(cx);
2041 });
2042
2043 Ok(())
2044 }
2045
2046 pub fn release_terminal(
2047 &mut self,
2048 terminal_id: acp::TerminalId,
2049 cx: &mut Context<Self>,
2050 ) -> Result<()> {
2051 self.terminals
2052 .remove(&terminal_id)
2053 .context("Terminal not found")?
2054 .update(cx, |terminal, cx| {
2055 terminal.kill(cx);
2056 });
2057
2058 Ok(())
2059 }
2060
2061 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2062 self.terminals
2063 .get(&terminal_id)
2064 .context("Terminal not found")
2065 .cloned()
2066 }
2067
2068 pub fn to_markdown(&self, cx: &App) -> String {
2069 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2070 }
2071
2072 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2073 cx.emit(AcpThreadEvent::LoadError(error));
2074 }
2075}
2076
2077fn markdown_for_raw_output(
2078 raw_output: &serde_json::Value,
2079 language_registry: &Arc<LanguageRegistry>,
2080 cx: &mut App,
2081) -> Option<Entity<Markdown>> {
2082 match raw_output {
2083 serde_json::Value::Null => None,
2084 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2085 Markdown::new(
2086 value.to_string().into(),
2087 Some(language_registry.clone()),
2088 None,
2089 cx,
2090 )
2091 })),
2092 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2093 Markdown::new(
2094 value.to_string().into(),
2095 Some(language_registry.clone()),
2096 None,
2097 cx,
2098 )
2099 })),
2100 serde_json::Value::String(value) => Some(cx.new(|cx| {
2101 Markdown::new(
2102 value.clone().into(),
2103 Some(language_registry.clone()),
2104 None,
2105 cx,
2106 )
2107 })),
2108 value => Some(cx.new(|cx| {
2109 Markdown::new(
2110 format!("```json\n{}\n```", value).into(),
2111 Some(language_registry.clone()),
2112 None,
2113 cx,
2114 )
2115 })),
2116 }
2117}
2118
2119#[cfg(test)]
2120mod tests {
2121 use super::*;
2122 use anyhow::anyhow;
2123 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2124 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2125 use indoc::indoc;
2126 use project::{FakeFs, Fs};
2127 use rand::{distr, prelude::*};
2128 use serde_json::json;
2129 use settings::SettingsStore;
2130 use smol::stream::StreamExt as _;
2131 use std::{
2132 any::Any,
2133 cell::RefCell,
2134 path::Path,
2135 rc::Rc,
2136 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2137 time::Duration,
2138 };
2139 use util::path;
2140
2141 fn init_test(cx: &mut TestAppContext) {
2142 env_logger::try_init().ok();
2143 cx.update(|cx| {
2144 let settings_store = SettingsStore::test(cx);
2145 cx.set_global(settings_store);
2146 Project::init_settings(cx);
2147 language::init(cx);
2148 });
2149 }
2150
2151 #[gpui::test]
2152 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2153 init_test(cx);
2154
2155 let fs = FakeFs::new(cx.executor());
2156 let project = Project::test(fs, [], cx).await;
2157 let connection = Rc::new(FakeAgentConnection::new());
2158 let thread = cx
2159 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2160 .await
2161 .unwrap();
2162
2163 // Test creating a new user message
2164 thread.update(cx, |thread, cx| {
2165 thread.push_user_content_block(
2166 None,
2167 acp::ContentBlock::Text(acp::TextContent {
2168 annotations: None,
2169 text: "Hello, ".to_string(),
2170 }),
2171 cx,
2172 );
2173 });
2174
2175 thread.update(cx, |thread, cx| {
2176 assert_eq!(thread.entries.len(), 1);
2177 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2178 assert_eq!(user_msg.id, None);
2179 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2180 } else {
2181 panic!("Expected UserMessage");
2182 }
2183 });
2184
2185 // Test appending to existing user message
2186 let message_1_id = UserMessageId::new();
2187 thread.update(cx, |thread, cx| {
2188 thread.push_user_content_block(
2189 Some(message_1_id.clone()),
2190 acp::ContentBlock::Text(acp::TextContent {
2191 annotations: None,
2192 text: "world!".to_string(),
2193 }),
2194 cx,
2195 );
2196 });
2197
2198 thread.update(cx, |thread, cx| {
2199 assert_eq!(thread.entries.len(), 1);
2200 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2201 assert_eq!(user_msg.id, Some(message_1_id));
2202 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2203 } else {
2204 panic!("Expected UserMessage");
2205 }
2206 });
2207
2208 // Test creating new user message after assistant message
2209 thread.update(cx, |thread, cx| {
2210 thread.push_assistant_content_block(
2211 acp::ContentBlock::Text(acp::TextContent {
2212 annotations: None,
2213 text: "Assistant response".to_string(),
2214 }),
2215 false,
2216 cx,
2217 );
2218 });
2219
2220 let message_2_id = UserMessageId::new();
2221 thread.update(cx, |thread, cx| {
2222 thread.push_user_content_block(
2223 Some(message_2_id.clone()),
2224 acp::ContentBlock::Text(acp::TextContent {
2225 annotations: None,
2226 text: "New user message".to_string(),
2227 }),
2228 cx,
2229 );
2230 });
2231
2232 thread.update(cx, |thread, cx| {
2233 assert_eq!(thread.entries.len(), 3);
2234 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2235 assert_eq!(user_msg.id, Some(message_2_id));
2236 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2237 } else {
2238 panic!("Expected UserMessage at index 2");
2239 }
2240 });
2241 }
2242
2243 #[gpui::test]
2244 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2245 init_test(cx);
2246
2247 let fs = FakeFs::new(cx.executor());
2248 let project = Project::test(fs, [], cx).await;
2249 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2250 |_, thread, mut cx| {
2251 async move {
2252 thread.update(&mut cx, |thread, cx| {
2253 thread
2254 .handle_session_update(
2255 acp::SessionUpdate::AgentThoughtChunk {
2256 content: "Thinking ".into(),
2257 },
2258 cx,
2259 )
2260 .unwrap();
2261 thread
2262 .handle_session_update(
2263 acp::SessionUpdate::AgentThoughtChunk {
2264 content: "hard!".into(),
2265 },
2266 cx,
2267 )
2268 .unwrap();
2269 })?;
2270 Ok(acp::PromptResponse {
2271 stop_reason: acp::StopReason::EndTurn,
2272 })
2273 }
2274 .boxed_local()
2275 },
2276 ));
2277
2278 let thread = cx
2279 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2280 .await
2281 .unwrap();
2282
2283 thread
2284 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2285 .await
2286 .unwrap();
2287
2288 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2289 assert_eq!(
2290 output,
2291 indoc! {r#"
2292 ## User
2293
2294 Hello from Zed!
2295
2296 ## Assistant
2297
2298 <thinking>
2299 Thinking hard!
2300 </thinking>
2301
2302 "#}
2303 );
2304 }
2305
    /// Checks that an edit the user makes while the agent is mid-turn survives the
    /// agent's subsequent `write_text_file`.
    ///
    /// The fake agent reads `/tmp/foo`, signals the test over a oneshot channel, and
    /// then writes the file back with `four` and `five` appended. Once the signal
    /// arrives, the test inserts `zero\n` at the start of the open buffer. Both the
    /// buffer and the file on the fake filesystem are expected to end up containing
    /// the user's line as well as the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the agent right after its read, so the test knows when to make
        // its concurrent edit.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is an `Fn`, so the one-shot sender is stashed in a
        // `RefCell<Option<_>>` and taken out the first time it is used.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file as a buffer up front so the test can edit it directly.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit: prepend a line after the agent has read the file.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2384
    /// Checks that a tool call marked `Canceled` by `AcpThread::cancel` can still be
    /// moved to `Completed` when the agent reports completion afterwards.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The agent reports a single in-progress fetch tool call and ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // Simulate a completion update arriving from the agent after the cancel.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2484
2485 #[gpui::test]
2486 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2487 init_test(cx);
2488 let fs = FakeFs::new(cx.background_executor.clone());
2489 fs.insert_tree(path!("/test"), json!({})).await;
2490 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2491
2492 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2493 move |_, thread, mut cx| {
2494 async move {
2495 thread
2496 .update(&mut cx, |thread, cx| {
2497 thread.handle_session_update(
2498 acp::SessionUpdate::ToolCall(acp::ToolCall {
2499 id: acp::ToolCallId("test".into()),
2500 title: "Label".into(),
2501 kind: acp::ToolKind::Edit,
2502 status: acp::ToolCallStatus::Completed,
2503 content: vec![acp::ToolCallContent::Diff {
2504 diff: acp::Diff {
2505 path: "/test/test.txt".into(),
2506 old_text: None,
2507 new_text: "foo".into(),
2508 },
2509 }],
2510 locations: vec![],
2511 raw_input: None,
2512 raw_output: None,
2513 }),
2514 cx,
2515 )
2516 })
2517 .unwrap()
2518 .unwrap();
2519 Ok(acp::PromptResponse {
2520 stop_reason: acp::StopReason::EndTurn,
2521 })
2522 }
2523 .boxed_local()
2524 }
2525 }));
2526
2527 let thread = cx
2528 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2529 .await
2530 .unwrap();
2531
2532 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2533 .await
2534 .unwrap();
2535
2536 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2537 }
2538
    /// Exercises git checkpoints attached to user messages, and restoring them.
    ///
    /// On each turn the fake agent optionally creates `/test/file-N` (controlled by
    /// `simulate_changes`) and replies with the prompt text upper-cased. Turns that
    /// created a file render their user message as `## User (checkpoint)`. The turn
    /// with no file change renders a plain `## User`. Restoring the checkpoint of the
    /// second user message truncates the thread back to the first exchange and
    /// removes `file-1` from the fake filesystem.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // NOTE(review): the empty `.git` directory is presumably what gives the
        // project a repository to checkpoint — confirm against the git store.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // When true, each turn writes a new `/test/file-N`. `next_filename` supplies N.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the first text block of the prompt back in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn writes file-0.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn writes file-1.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2718
2719 #[gpui::test]
2720 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2721 use std::sync::atomic::AtomicUsize;
2722 init_test(cx);
2723
2724 let fs = FakeFs::new(cx.executor());
2725 let project = Project::test(fs, None, cx).await;
2726
2727 // Create a connection that simulates refusal after tool result
2728 let prompt_count = Arc::new(AtomicUsize::new(0));
2729 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2730 let prompt_count = prompt_count.clone();
2731 move |_request, thread, mut cx| {
2732 let count = prompt_count.fetch_add(1, SeqCst);
2733 async move {
2734 if count == 0 {
2735 // First prompt: Generate a tool call with result
2736 thread.update(&mut cx, |thread, cx| {
2737 thread
2738 .handle_session_update(
2739 acp::SessionUpdate::ToolCall(acp::ToolCall {
2740 id: acp::ToolCallId("tool1".into()),
2741 title: "Test Tool".into(),
2742 kind: acp::ToolKind::Fetch,
2743 status: acp::ToolCallStatus::Completed,
2744 content: vec![],
2745 locations: vec![],
2746 raw_input: Some(serde_json::json!({"query": "test"})),
2747 raw_output: Some(
2748 serde_json::json!({"result": "inappropriate content"}),
2749 ),
2750 }),
2751 cx,
2752 )
2753 .unwrap();
2754 })?;
2755
2756 // Now return refusal because of the tool result
2757 Ok(acp::PromptResponse {
2758 stop_reason: acp::StopReason::Refusal,
2759 })
2760 } else {
2761 Ok(acp::PromptResponse {
2762 stop_reason: acp::StopReason::EndTurn,
2763 })
2764 }
2765 }
2766 .boxed_local()
2767 }
2768 }));
2769
2770 let thread = cx
2771 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2772 .await
2773 .unwrap();
2774
2775 // Track if we see a Refusal event
2776 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2777 let saw_refusal_event_captured = saw_refusal_event.clone();
2778 thread.update(cx, |_thread, cx| {
2779 cx.subscribe(
2780 &thread,
2781 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2782 if matches!(event, AcpThreadEvent::Refusal) {
2783 *saw_refusal_event_captured.lock().unwrap() = true;
2784 }
2785 },
2786 )
2787 .detach();
2788 });
2789
2790 // Send a user message - this will trigger tool call and then refusal
2791 let send_task = thread.update(cx, |thread, cx| {
2792 thread.send(
2793 vec![acp::ContentBlock::Text(acp::TextContent {
2794 text: "Hello".into(),
2795 annotations: None,
2796 })],
2797 cx,
2798 )
2799 });
2800 cx.background_executor.spawn(send_task).detach();
2801 cx.run_until_parked();
2802
2803 // Verify that:
2804 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2805 // 2. The user message was NOT truncated
2806 assert!(
2807 *saw_refusal_event.lock().unwrap(),
2808 "Refusal event should be emitted for tool result refusals"
2809 );
2810
2811 thread.read_with(cx, |thread, _| {
2812 let entries = thread.entries();
2813 assert!(entries.len() >= 2, "Should have user message and tool call");
2814
2815 // Verify user message is still there
2816 assert!(
2817 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2818 "User message should not be truncated"
2819 );
2820
2821 // Verify tool call is there with result
2822 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2823 assert!(
2824 tool_call.raw_output.is_some(),
2825 "Tool call should have output"
2826 );
2827 } else {
2828 panic!("Expected tool call at index 1");
2829 }
2830 });
2831 }
2832
2833 #[gpui::test]
2834 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2835 init_test(cx);
2836
2837 let fs = FakeFs::new(cx.executor());
2838 let project = Project::test(fs, None, cx).await;
2839
2840 let refuse_next = Arc::new(AtomicBool::new(false));
2841 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2842 let refuse_next = refuse_next.clone();
2843 move |_request, _thread, _cx| {
2844 if refuse_next.load(SeqCst) {
2845 async move {
2846 Ok(acp::PromptResponse {
2847 stop_reason: acp::StopReason::Refusal,
2848 })
2849 }
2850 .boxed_local()
2851 } else {
2852 async move {
2853 Ok(acp::PromptResponse {
2854 stop_reason: acp::StopReason::EndTurn,
2855 })
2856 }
2857 .boxed_local()
2858 }
2859 }
2860 }));
2861
2862 let thread = cx
2863 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2864 .await
2865 .unwrap();
2866
2867 // Track if we see a Refusal event
2868 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2869 let saw_refusal_event_captured = saw_refusal_event.clone();
2870 thread.update(cx, |_thread, cx| {
2871 cx.subscribe(
2872 &thread,
2873 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2874 if matches!(event, AcpThreadEvent::Refusal) {
2875 *saw_refusal_event_captured.lock().unwrap() = true;
2876 }
2877 },
2878 )
2879 .detach();
2880 });
2881
2882 // Send a message that will be refused
2883 refuse_next.store(true, SeqCst);
2884 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2885 .await
2886 .unwrap();
2887
2888 // Verify that a Refusal event WAS emitted for user prompt refusal
2889 assert!(
2890 *saw_refusal_event.lock().unwrap(),
2891 "Refusal event should be emitted for user prompt refusals"
2892 );
2893
2894 // Verify the message was truncated (user prompt refusal)
2895 thread.read_with(cx, |thread, cx| {
2896 assert_eq!(thread.to_markdown(cx), "");
2897 });
2898 }
2899
    /// Checks that when the agent refuses a prompt, the refused user message is
    /// removed from the thread while the earlier, completed exchange is kept.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While set, the agent refuses immediately. Otherwise it echoes the prompt
        // upper-cased.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message. The message should be truncated
        // when a user prompt is refused.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2985
2986 async fn run_until_first_tool_call(
2987 thread: &Entity<AcpThread>,
2988 cx: &mut TestAppContext,
2989 ) -> usize {
2990 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2991
2992 let subscription = cx.update(|cx| {
2993 cx.subscribe(thread, move |thread, _, cx| {
2994 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2995 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2996 return tx.try_send(ix).unwrap();
2997 }
2998 }
2999 })
3000 });
3001
3002 select! {
3003 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3004 panic!("Timeout waiting for tool call")
3005 }
3006 ix = rx.next().fuse() => {
3007 drop(subscription);
3008 ix.unwrap()
3009 }
3010 }
3011 }
3012
3013 #[derive(Clone, Default)]
3014 struct FakeAgentConnection {
3015 auth_methods: Vec<acp::AuthMethod>,
3016 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3017 on_user_message: Option<
3018 Rc<
3019 dyn Fn(
3020 acp::PromptRequest,
3021 WeakEntity<AcpThread>,
3022 AsyncApp,
3023 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3024 + 'static,
3025 >,
3026 >,
3027 }
3028
3029 impl FakeAgentConnection {
3030 fn new() -> Self {
3031 Self {
3032 auth_methods: Vec::new(),
3033 on_user_message: None,
3034 sessions: Arc::default(),
3035 }
3036 }
3037
3038 #[expect(unused)]
3039 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3040 self.auth_methods = auth_methods;
3041 self
3042 }
3043
3044 fn on_user_message(
3045 mut self,
3046 handler: impl Fn(
3047 acp::PromptRequest,
3048 WeakEntity<AcpThread>,
3049 AsyncApp,
3050 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3051 + 'static,
3052 ) -> Self {
3053 self.on_user_message.replace(Rc::new(handler));
3054 self
3055 }
3056 }
3057
3058 impl AgentConnection for FakeAgentConnection {
3059 fn auth_methods(&self) -> &[acp::AuthMethod] {
3060 &self.auth_methods
3061 }
3062
3063 fn new_thread(
3064 self: Rc<Self>,
3065 project: Entity<Project>,
3066 _cwd: &Path,
3067 cx: &mut App,
3068 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3069 let session_id = acp::SessionId(
3070 rand::rng()
3071 .sample_iter(&distr::Alphanumeric)
3072 .take(7)
3073 .map(char::from)
3074 .collect::<String>()
3075 .into(),
3076 );
3077 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3078 let thread = cx.new(|cx| {
3079 AcpThread::new(
3080 "Test",
3081 self.clone(),
3082 project,
3083 action_log,
3084 session_id.clone(),
3085 watch::Receiver::constant(acp::PromptCapabilities {
3086 image: true,
3087 audio: true,
3088 embedded_context: true,
3089 }),
3090 cx,
3091 )
3092 });
3093 self.sessions.lock().insert(session_id, thread.downgrade());
3094 Task::ready(Ok(thread))
3095 }
3096
3097 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3098 if self.auth_methods().iter().any(|m| m.id == method) {
3099 Task::ready(Ok(()))
3100 } else {
3101 Task::ready(Err(anyhow!("Invalid Auth Method")))
3102 }
3103 }
3104
3105 fn prompt(
3106 &self,
3107 _id: Option<UserMessageId>,
3108 params: acp::PromptRequest,
3109 cx: &mut App,
3110 ) -> Task<gpui::Result<acp::PromptResponse>> {
3111 let sessions = self.sessions.lock();
3112 let thread = sessions.get(¶ms.session_id).unwrap();
3113 if let Some(handler) = &self.on_user_message {
3114 let handler = handler.clone();
3115 let thread = thread.clone();
3116 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3117 } else {
3118 Task::ready(Ok(acp::PromptResponse {
3119 stop_reason: acp::StopReason::EndTurn,
3120 }))
3121 }
3122 }
3123
3124 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3125 let sessions = self.sessions.lock();
3126 let thread = sessions.get(session_id).unwrap().clone();
3127
3128 cx.spawn(async move |cx| {
3129 thread
3130 .update(cx, |thread, cx| thread.cancel(cx))
3131 .unwrap()
3132 .await
3133 })
3134 .detach();
3135 }
3136
3137 fn truncate(
3138 &self,
3139 session_id: &acp::SessionId,
3140 _cx: &App,
3141 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3142 Some(Rc::new(FakeAgentSessionEditor {
3143 _session_id: session_id.clone(),
3144 }))
3145 }
3146
3147 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3148 self
3149 }
3150 }
3151
3152 struct FakeAgentSessionEditor {
3153 _session_id: acp::SessionId,
3154 }
3155
3156 impl AgentSessionTruncate for FakeAgentSessionEditor {
3157 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3158 Task::ready(Ok(()))
3159 }
3160 }
3161}