1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use futures::future::Shared;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_system_shell};
38use uuid::Uuid;
39
/// A message authored by the user within a thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of this message, when one has been assigned.
    pub id: Option<UserMessageId>,
    /// Renderable content accumulated from every chunk appended to this message.
    pub content: ContentBlock,
    /// The raw protocol content blocks of this message, in the order they were pushed.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
47
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the message's markdown heading is rendered as `## User (checkpoint)`.
    pub show: bool,
}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
/// A message produced by the agent, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in the order they were appended.
    pub chunks: Vec<AssistantMessageChunk>,
}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
/// One contiguous piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular response content.
    Message { block: ContentBlock },
    /// Thought content; wrapped in `<thinking>` tags when rendered to markdown.
    Thought { block: ContentBlock },
}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
/// A single entry in a thread: a user message, an assistant message, or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
/// A tool invocation made by the agent, as tracked in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown title; multi-line titles are cut to their first line followed by `…`.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content items (content blocks, diffs, terminals) produced by the call.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported by the agent for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`; looked up by the same index in
    /// `AgentThreadEntry::location`. An entry is `None` when resolution failed.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
    /// Applies a partial update to this tool call; only fields that are `Some` are touched.
    ///
    /// Returns an error if converting or updating any content item fails.
    fn update_fields(
        &mut self,
        fields: acp::ToolCallUpdateFields,
        language_registry: Arc<LanguageRegistry>,
        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
        cx: &mut App,
    ) -> Result<()> {
        let acp::ToolCallUpdateFields {
            kind,
            status,
            title,
            content,
            locations,
            raw_input,
            raw_output,
        } = fields;

        if let Some(kind) = kind {
            self.kind = kind;
        }

        if let Some(status) = status {
            self.status = status.into();
        }

        if let Some(title) = title {
            // Same truncation as at construction: keep only the first line of a
            // multi-line title and mark the cut with an ellipsis.
            self.label.update(cx, |label, cx| {
                if let Some((first_line, _)) = title.split_once("\n") {
                    label.replace(first_line.to_owned() + "…", cx)
                } else {
                    label.replace(title, cx);
                }
            });
        }

        if let Some(content) = content {
            let new_content_len = content.len();
            let mut content = content.into_iter();

            // Reuse existing content if we can
            // (`by_ref` leaves any new items beyond the existing length in the iterator.)
            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
            }
            // Append whatever new items did not have an existing slot.
            for new in content {
                self.content.push(ToolCallContent::from_acp(
                    new,
                    language_registry.clone(),
                    terminals,
                    cx,
                )?)
            }
            // Drop stale trailing items when the new content is shorter than the old.
            self.content.truncate(new_content_len);
        }

        if let Some(locations) = locations {
            self.locations = locations;
        }

        if let Some(raw_input) = raw_input {
            self.raw_input = Some(raw_input);
        }

        if let Some(raw_output) = raw_output {
            // Only when there is no content at all, show the raw output (if it can be
            // rendered to markdown) as the call's content.
            if self.content.is_empty()
                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
            {
                self.content
                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
                        markdown,
                    }));
            }
            self.raw_output = Some(raw_output);
        }
        Ok(())
    }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// Sender for the id of a permission option.
        // NOTE(review): presumably fired with the option the user picks — confirm at the use site.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
/// Renderable content accumulated from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; converted to markdown once further content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
    /// Appends `block` to this content.
    ///
    /// - A resource link appended to empty content is stored as `ResourceLink`.
    /// - Anything else is converted to a string and stored or appended as markdown.
    /// - Appending to an existing `ResourceLink` replaces it with a markdown block
    ///   holding the link's markdown, a newline, and the new content.
    pub fn append(
        &mut self,
        block: acp::ContentBlock,
        language_registry: &Arc<LanguageRegistry>,
        cx: &mut App,
    ) {
        // `block` is only moved out of when the pattern matches, in which case we
        // return, so it is still available below otherwise.
        if matches!(self, ContentBlock::Empty)
            && let acp::ContentBlock::ResourceLink(resource_link) = block
        {
            *self = ContentBlock::ResourceLink { resource_link };
            return;
        }

        let new_content = self.block_string_contents(block);

        match self {
            ContentBlock::Empty => {
                *self = Self::create_markdown_block(new_content, language_registry, cx);
            }
            ContentBlock::Markdown { markdown } => {
                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
            }
            ContentBlock::ResourceLink { resource_link } => {
                // Fold the previously stored link into the markdown text.
                let existing_content = Self::resource_link_md(&resource_link.uri);
                let combined = format!("{}\n{}", existing_content, new_content);

                *self = Self::create_markdown_block(combined, language_registry, cx);
            }
        }
    }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
    /// Converts a protocol content block into the text that is appended to markdown.
    ///
    /// Text is used verbatim; resource links and embedded text resources become a
    /// link built from their URI; images become a fixed placeholder; audio and any
    /// other embedded resource produce an empty string.
    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
        match block {
            acp::ContentBlock::Text(text_content) => text_content.text,
            acp::ContentBlock::ResourceLink(resource_link) => {
                Self::resource_link_md(&resource_link.uri)
            }
            // Must stay ahead of the catch-all `Resource(_)` arm below.
            acp::ContentBlock::Resource(acp::EmbeddedResource {
                resource:
                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
                        uri,
                        ..
                    }),
                ..
            }) => Self::resource_link_md(&uri),
            acp::ContentBlock::Image(image) => Self::image_md(&image),
            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
        }
    }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
/// One piece of content produced by a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Textual or link content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path,
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
/// An update to apply to an existing tool call.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Update individual fields from a protocol-level update.
    UpdateFields(acp::ToolCallUpdate),
    /// Replace the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replace the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
/// Update that replaces a tool call's content with a single diff.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content item.
    pub diff: Entity<Diff>,
}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
/// Update that replaces a tool call's content with a single terminal.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content item.
    pub terminal: Entity<Terminal>,
}
670
/// The agent's plan: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
675
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, held as a markdown entity.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
/// Token consumption of a thread relative to the available maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" by [`TokenUsage::ratio`].
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
/// Classification of token usage produced by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown.
    Normal,
    /// At or above the warning threshold but below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
766
/// Payload of [`AcpThreadEvent::Retry`], describing a retry.
// NOTE(review): field meanings below are inferred from names only — confirm with the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the error that triggered the retry.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}
775
/// A conversation thread with an agent, backed by an [`AgentConnection`] session.
pub struct AcpThread {
    title: SharedString,
    /// User messages, assistant messages, and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While this is `Some`, `status()` reports the thread as not idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    prompt_capabilities: acp::PromptCapabilities,
    available_commands: Vec<acp::AvailableCommand>,
    /// Task that copies received capability updates into `prompt_capabilities`
    /// and emits `PromptCapabilitiesUpdated`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Shared background task that picks the shell: `bash` when found on
    /// non-Windows systems, otherwise the system shell.
    determine_shell: Shared<Task<String>>,
    /// Terminals known to this thread, keyed by id; used to resolve terminal
    /// content in tool calls.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
793
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The thread's title changed.
    TitleUpdated,
    /// The thread's token usage was replaced.
    TokenUsageUpdated,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    /// A retry status was reported.
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// The prompt capabilities were updated from the capabilities channel.
    PromptCapabilitiesUpdated,
    Refusal,
}
809
// Allows `AcpThread` to emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
811
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is active.
    Idle,
    /// A send task is active and the thread is waiting for a tool confirmation.
    WaitingForToolConfirmation,
    /// A send task is active and no tool confirmation is pending.
    Generating,
}
818
/// Reasons loading an agent can fail; carried by `AcpThreadEvent::LoadError`.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported version.
    Unsupported {
        /// Reported as the source of the version in the error message.
        command: SharedString,
        current_version: SharedString,
        minimum_version: SharedString,
    },
    /// Installation failed; the payload is the failure message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure; the payload is displayed verbatim.
    Other(SharedString),
}
832
833impl Display for LoadError {
834 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
835 match self {
836 LoadError::Unsupported {
837 command: path,
838 current_version,
839 minimum_version,
840 } => {
841 write!(
842 f,
843 "version {current_version} from {path} is not supported (need at least {minimum_version})"
844 )
845 }
846 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
847 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
848 LoadError::Other(msg) => write!(f, "{msg}"),
849 }
850 }
851}
852
// Marks `LoadError` as a standard error type; no methods are overridden.
impl Error for LoadError {}
854
855impl AcpThread {
    /// Creates a thread for the given session.
    ///
    /// Takes the current value of `prompt_capabilities_rx` as the initial prompt
    /// capabilities and spawns a task that applies every later value and emits
    /// `PromptCapabilitiesUpdated`. Also starts a shared background task that
    /// determines which shell to use. The thread starts with no entries, an empty
    /// plan, no token usage, and no terminals.
    pub fn new(
        title: impl Into<SharedString>,
        connection: Rc<dyn AgentConnection>,
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        session_id: acp::SessionId,
        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
        available_commands: Vec<acp::AvailableCommand>,
        cx: &mut Context<Self>,
    ) -> Self {
        let prompt_capabilities = *prompt_capabilities_rx.borrow();
        // The loop ends with an error once `recv` fails or the thread entity can
        // no longer be updated.
        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
            loop {
                let caps = prompt_capabilities_rx.recv().await?;
                this.update(cx, |this, cx| {
                    this.prompt_capabilities = caps;
                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
                })?;
            }
        });

        let determine_shell = cx
            .background_spawn(async move {
                // On Windows always use the system shell.
                if cfg!(windows) {
                    return get_system_shell();
                }

                // Elsewhere use `bash` when it can be found, else the system shell.
                if which::which("bash").is_ok() {
                    "bash".into()
                } else {
                    get_system_shell()
                }
            })
            .shared();

        Self {
            action_log,
            shared_buffers: Default::default(),
            entries: Default::default(),
            plan: Default::default(),
            title: title.into(),
            project,
            send_task: None,
            connection,
            session_id,
            token_usage: None,
            prompt_capabilities,
            available_commands,
            _observe_prompt_capabilities: task,
            terminals: HashMap::default(),
            determine_shell,
        }
    }
909
    /// Returns the prompt capabilities currently recorded for this thread.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }
913
    /// Returns a copy of the commands available in this thread.
    pub fn available_commands(&self) -> Vec<acp::AvailableCommand> {
        self.available_commands.clone()
    }
917
    /// Returns the agent connection backing this thread.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
921
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
925
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
929
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
933
    /// Returns all entries of the thread, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
937
    /// Returns the id of the session this thread belongs to.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
941
942 pub fn status(&self) -> ThreadStatus {
943 if self.send_task.is_some() {
944 if self.waiting_for_tool_confirmation() {
945 ThreadStatus::WaitingForToolConfirmation
946 } else {
947 ThreadStatus::Generating
948 }
949 } else {
950 ThreadStatus::Idle
951 }
952 }
953
    /// Returns the most recently recorded token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
957
958 pub fn has_pending_edit_tool_calls(&self) -> bool {
959 for entry in self.entries.iter().rev() {
960 match entry {
961 AgentThreadEntry::UserMessage(_) => return false,
962 AgentThreadEntry::ToolCall(
963 call @ ToolCall {
964 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
965 ..
966 },
967 ) if call.diffs().next().is_some() => {
968 return true;
969 }
970 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
971 }
972 }
973
974 false
975 }
976
977 pub fn used_tools_since_last_user_message(&self) -> bool {
978 for entry in self.entries.iter().rev() {
979 match entry {
980 AgentThreadEntry::UserMessage(..) => return false,
981 AgentThreadEntry::AssistantMessage(..) => continue,
982 AgentThreadEntry::ToolCall(..) => return true,
983 }
984 }
985
986 false
987 }
988
    /// Applies a session update received from the agent by dispatching to the
    /// matching thread operation.
    ///
    /// Errors from inserting or updating tool calls are propagated.
    pub fn handle_session_update(
        &mut self,
        update: acp::SessionUpdate,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        match update {
            acp::SessionUpdate::UserMessageChunk { content } => {
                self.push_user_content_block(None, content, cx);
            }
            acp::SessionUpdate::AgentMessageChunk { content } => {
                // `false`: regular message content, not a thought.
                self.push_assistant_content_block(content, false, cx);
            }
            acp::SessionUpdate::AgentThoughtChunk { content } => {
                // `true`: thought content.
                self.push_assistant_content_block(content, true, cx);
            }
            acp::SessionUpdate::ToolCall(tool_call) => {
                self.upsert_tool_call(tool_call, cx)?;
            }
            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
                self.update_tool_call(tool_call_update, cx)?;
            }
            acp::SessionUpdate::Plan(plan) => {
                self.update_plan(plan, cx);
            }
        }
        Ok(())
    }
1016
    /// Adds a chunk of user content to the thread.
    ///
    /// If the last entry is a user message, the chunk is appended to it and
    /// `EntryUpdated` is emitted; a provided `message_id` replaces the message's
    /// id, while `None` keeps the existing one. Otherwise a new user message
    /// entry (without checkpoint) is pushed.
    pub fn push_user_content_block(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        // Captured up front because `self.entries` is mutably borrowed below.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                ..
            }) = last_entry
        {
            // Prefer the new id; fall back to the one already stored.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                }),
                cx,
            );
        }
    }
1052
    /// Adds a chunk of assistant content (a thought when `is_thought` is true) to
    /// the thread.
    ///
    /// If the last entry is an assistant message, `EntryUpdated` is emitted and the
    /// chunk is merged into the message's last chunk when that chunk is of the same
    /// kind, or added as a new chunk otherwise. If the last entry is anything else,
    /// a new assistant message entry is pushed.
    pub fn push_assistant_content_block(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        // Captured up front because `self.entries` is mutably borrowed below.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
        {
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the last chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, cx)
                }
                // Kind changed (or there is no chunk yet): start a new chunk.
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                }),
                cx,
            );
        }
    }
1096
    /// Appends `entry` to the thread and emits `NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
1101
    /// Returns `true` when the connection provides a title setter for this session.
    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
        self.connection.set_title(&self.session_id, cx).is_some()
    }
1105
1106 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1107 if title != self.title {
1108 self.title = title.clone();
1109 cx.emit(AcpThreadEvent::TitleUpdated);
1110 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1111 return set_title.run(title, cx);
1112 }
1113 }
1114 Task::ready(Ok(()))
1115 }
1116
    /// Replaces the recorded token usage (possibly clearing it) and emits
    /// `TokenUsageUpdated`.
    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
        self.token_usage = usage;
        cx.emit(AcpThreadEvent::TokenUsageUpdated);
    }
1121
    /// Emits a `Retry` event carrying `status`; nothing is stored on the thread.
    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::Retry(status));
    }
1125
    /// Applies an update to an existing tool call and emits `EntryUpdated`.
    ///
    /// Field updates are merged into the call, and locations are re-resolved when
    /// the update carried locations. Diff and terminal updates replace the call's
    /// entire content with that single item.
    ///
    /// Returns an error when no tool call with the update's id exists or when
    /// applying field updates fails.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();

        let ix = self
            .index_for_tool_call(update.id())
            .context("Tool call not found")?;
        // `index_for_tool_call` only returns indices of tool call entries.
        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
            unreachable!()
        };

        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                call.update_fields(update.fields, languages, &self.terminals, cx)?;
                if location_updated {
                    self.resolve_locations(update.id, cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                call.content.clear();
                call.content.push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                call.content.clear();
                call.content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
1164
    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
    ///
    /// The thread-level status is derived from the protocol status of `tool_call`.
    pub fn upsert_tool_call(
        &mut self,
        tool_call: acp::ToolCall,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        // Read the status first; `tool_call` is consumed by the conversion below.
        let status = tool_call.status.into();
        self.upsert_tool_call_inner(tool_call.into(), status, cx)
    }
1174
    /// Updates the tool call whose id matches `update`, or inserts a new tool call
    /// entry when no such entry exists. In both cases the call ends up with
    /// `status`, and its locations are re-resolved afterwards.
    ///
    /// Fails when applying the field updates fails, or — for a new entry — when
    /// `update` cannot be converted into a full tool call or the call cannot be
    /// constructed.
    // NOTE(review): the conversion presumably fails when required fields are
    // missing from the update — confirm against the `TryFrom` impl.
    pub fn upsert_tool_call_inner(
        &mut self,
        update: acp::ToolCallUpdate,
        status: ToolCallStatus,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        let language_registry = self.project.read(cx).languages().clone();
        let id = update.id.clone();

        if let Some(ix) = self.index_for_tool_call(&id) {
            // `index_for_tool_call` only returns indices of tool call entries.
            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
                unreachable!()
            };

            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
            // The explicit status wins over any status carried in the fields.
            call.status = status;

            cx.emit(AcpThreadEvent::EntryUpdated(ix));
        } else {
            let call = ToolCall::from_acp(
                update.try_into()?,
                status,
                language_registry,
                &self.terminals,
                cx,
            )?;
            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
        };

        self.resolve_locations(id, cx);
        Ok(())
    }
1208
1209 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1210 self.entries
1211 .iter()
1212 .enumerate()
1213 .rev()
1214 .find_map(|(index, entry)| {
1215 if let AgentThreadEntry::ToolCall(tool_call) = entry
1216 && &tool_call.id == id
1217 {
1218 Some(index)
1219 } else {
1220 None
1221 }
1222 })
1223 }
1224
1225 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1226 // The tool call we are looking for is typically the last one, or very close to the end.
1227 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1228 self.entries
1229 .iter_mut()
1230 .enumerate()
1231 .rev()
1232 .find_map(|(index, tool_call)| {
1233 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1234 && &tool_call.id == id
1235 {
1236 Some((index, tool_call))
1237 } else {
1238 None
1239 }
1240 })
1241 }
1242
1243 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1244 self.entries
1245 .iter()
1246 .enumerate()
1247 .rev()
1248 .find_map(|(index, tool_call)| {
1249 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1250 && &tool_call.id == id
1251 {
1252 Some((index, tool_call))
1253 } else {
1254 None
1255 }
1256 })
1257 }
1258
    /// Starts resolving the locations of the tool call with the given `id` and,
    /// once resolved, stores them on the tool call.
    ///
    /// Does nothing if no tool call with that id exists. The resolution runs in
    /// a detached task; when it finishes, the tool call is looked up again
    /// (it may have been removed in the meantime). If the last resolved location
    /// is present, the project's agent location is moved to it, unless that
    /// would move the agent backwards on the row it is already on.
    /// [`AcpThreadEvent::EntryUpdated`] is emitted only when the resolved
    /// locations actually changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again: entries may have changed while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // Ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of the line.
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1304
1305 pub fn request_tool_call_authorization(
1306 &mut self,
1307 tool_call: acp::ToolCallUpdate,
1308 options: Vec<acp::PermissionOption>,
1309 cx: &mut Context<Self>,
1310 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1311 let (tx, rx) = oneshot::channel();
1312
1313 if AgentSettings::get_global(cx).always_allow_tool_actions {
1314 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1315 // some tools would (incorrectly) continue to auto-accept.
1316 if let Some(allow_once_option) = options.iter().find_map(|option| {
1317 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1318 Some(option.id.clone())
1319 } else {
1320 None
1321 }
1322 }) {
1323 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1324 return Ok(async {
1325 acp::RequestPermissionOutcome::Selected {
1326 option_id: allow_once_option,
1327 }
1328 }
1329 .boxed());
1330 }
1331 }
1332
1333 let status = ToolCallStatus::WaitingForConfirmation {
1334 options,
1335 respond_tx: tx,
1336 };
1337
1338 self.upsert_tool_call_inner(tool_call, status, cx)?;
1339 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1340
1341 let fut = async {
1342 match rx.await {
1343 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1344 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1345 }
1346 }
1347 .boxed();
1348
1349 Ok(fut)
1350 }
1351
1352 pub fn authorize_tool_call(
1353 &mut self,
1354 id: acp::ToolCallId,
1355 option_id: acp::PermissionOptionId,
1356 option_kind: acp::PermissionOptionKind,
1357 cx: &mut Context<Self>,
1358 ) {
1359 let Some((ix, call)) = self.tool_call_mut(&id) else {
1360 return;
1361 };
1362
1363 let new_status = match option_kind {
1364 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1365 ToolCallStatus::Rejected
1366 }
1367 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1368 ToolCallStatus::InProgress
1369 }
1370 };
1371
1372 let curr_status = mem::replace(&mut call.status, new_status);
1373
1374 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1375 respond_tx.send(option_id).log_err();
1376 } else if cfg!(debug_assertions) {
1377 panic!("tried to authorize an already authorized tool call");
1378 }
1379
1380 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1381 }
1382
1383 /// Returns true if the last turn is awaiting tool authorization
1384 pub fn waiting_for_tool_confirmation(&self) -> bool {
1385 for entry in self.entries.iter().rev() {
1386 match &entry {
1387 AgentThreadEntry::ToolCall(call) => match call.status {
1388 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1389 ToolCallStatus::Pending
1390 | ToolCallStatus::InProgress
1391 | ToolCallStatus::Completed
1392 | ToolCallStatus::Failed
1393 | ToolCallStatus::Rejected
1394 | ToolCallStatus::Canceled => continue,
1395 },
1396 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1397 // Reached the beginning of the turn
1398 return false;
1399 }
1400 }
1401 }
1402 false
1403 }
1404
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1408
1409 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1410 let new_entries_len = request.entries.len();
1411 let mut new_entries = request.entries.into_iter();
1412
1413 // Reuse existing markdown to prevent flickering
1414 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1415 let PlanEntry {
1416 content,
1417 priority,
1418 status,
1419 } = old;
1420 content.update(cx, |old, cx| {
1421 old.replace(new.content, cx);
1422 });
1423 *priority = new.priority;
1424 *status = new.status;
1425 }
1426 for new in new_entries {
1427 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1428 }
1429 self.plan.entries.truncate(new_entries_len);
1430
1431 cx.notify();
1432 }
1433
    /// Removes every plan entry whose status is completed and notifies observers.
    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
        self.plan
            .entries
            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
        cx.notify();
    }
1440
1441 #[cfg(any(test, feature = "test-support"))]
1442 pub fn send_raw(
1443 &mut self,
1444 message: &str,
1445 cx: &mut Context<Self>,
1446 ) -> BoxFuture<'static, Result<()>> {
1447 self.send(
1448 vec![acp::ContentBlock::Text(acp::TextContent {
1449 text: message.to_string(),
1450 annotations: None,
1451 })],
1452 cx,
1453 )
1454 }
1455
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The returned future resolves when the turn finishes (see `run_turn`).
    /// Within the turn, the user message is appended to the thread, a git
    /// checkpoint is taken and attached to it (hidden by default), and the
    /// prompt is then forwarded to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Only assign a message id when the connection supports truncation.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // A failure to take the checkpoint is logged, not fatal: the turn still runs.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1509
    /// Returns true if the connection provides a resume handle for this session.
    pub fn can_resume(&self, cx: &App) -> bool {
        self.connection.resume(&self.session_id, cx).is_some()
    }
1513
    /// Resumes the session as a new turn.
    ///
    /// The returned future fails with "resuming a session is not supported"
    /// when the connection provides no resume handle for this session.
    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .resume(&this.session_id, cx)
                    .map(|resume| resume.run(cx))
            })?
            .context("resuming a session is not supported")?
            .await
        })
    }
1525
    /// Runs `f` as the thread's current turn and returns a future that resolves
    /// once the turn has been fully processed.
    ///
    /// Before `f` runs, completed plan entries are cleared and any turn already
    /// in flight is cancelled and awaited. After `f` finishes, the last
    /// checkpoint is refreshed, the agent location is cleared, and the result is
    /// handled:
    ///
    /// - an error from `f` emits [`AcpThreadEvent::Error`] and is returned;
    /// - a refusal emits [`AcpThreadEvent::Refusal`], truncating the thread back
    ///   to before the last user message unless a completed tool call with
    ///   output follows that message;
    /// - every non-error outcome emits [`AcpThreadEvent::Stopped`].
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn itself: wait for the previous turn to wind down, then run `f`
        // and hand its result to the future returned below.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `response` is `Err` when the sender was dropped before sending.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1619
1620 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1621 let Some(send_task) = self.send_task.take() else {
1622 return Task::ready(());
1623 };
1624
1625 for entry in self.entries.iter_mut() {
1626 if let AgentThreadEntry::ToolCall(call) = entry {
1627 let cancel = matches!(
1628 call.status,
1629 ToolCallStatus::Pending
1630 | ToolCallStatus::WaitingForConfirmation { .. }
1631 | ToolCallStatus::InProgress
1632 );
1633
1634 if cancel {
1635 call.status = ToolCallStatus::Canceled;
1636 }
1637 }
1638 }
1639
1640 self.connection.cancel(&self.session_id, cx);
1641
1642 // Wait for the send task to complete
1643 cx.foreground_executor().spawn(send_task)
1644 }
1645
1646 /// Rewinds this thread to before the entry at `index`, removing it and all
1647 /// subsequent entries while reverting any changes made from that point.
1648 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1649 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1650 return Task::ready(Err(anyhow!("not supported")));
1651 };
1652 let Some(message) = self.user_message(&id) else {
1653 return Task::ready(Err(anyhow!("message not found")));
1654 };
1655
1656 let checkpoint = message
1657 .checkpoint
1658 .as_ref()
1659 .map(|c| c.git_checkpoint.clone());
1660
1661 let git_store = self.project.read(cx).git_store().clone();
1662 cx.spawn(async move |this, cx| {
1663 if let Some(checkpoint) = checkpoint {
1664 git_store
1665 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1666 .await?;
1667 }
1668
1669 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1670 this.update(cx, |this, cx| {
1671 if let Some((ix, _)) = this.user_message_mut(&id) {
1672 let range = ix..this.entries.len();
1673 this.entries.truncate(ix);
1674 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1675 }
1676 })
1677 })
1678 }
1679
1680 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1681 let git_store = self.project.read(cx).git_store().clone();
1682
1683 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1684 if let Some(checkpoint) = message.checkpoint.as_ref() {
1685 checkpoint.git_checkpoint.clone()
1686 } else {
1687 return Task::ready(Ok(()));
1688 }
1689 } else {
1690 return Task::ready(Ok(()));
1691 };
1692
1693 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1694 cx.spawn(async move |this, cx| {
1695 let new_checkpoint = new_checkpoint
1696 .await
1697 .context("failed to get new checkpoint")
1698 .log_err();
1699 if let Some(new_checkpoint) = new_checkpoint {
1700 let equal = git_store
1701 .update(cx, |git, cx| {
1702 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1703 })?
1704 .await
1705 .unwrap_or(true);
1706 this.update(cx, |this, cx| {
1707 let (ix, message) = this.last_user_message().context("no user message")?;
1708 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1709 checkpoint.show = !equal;
1710 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1711 anyhow::Ok(())
1712 })??;
1713 }
1714
1715 Ok(())
1716 })
1717 }
1718
1719 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1720 self.entries
1721 .iter_mut()
1722 .enumerate()
1723 .rev()
1724 .find_map(|(ix, entry)| {
1725 if let AgentThreadEntry::UserMessage(message) = entry {
1726 Some((ix, message))
1727 } else {
1728 None
1729 }
1730 })
1731 }
1732
1733 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1734 self.entries.iter().find_map(|entry| {
1735 if let AgentThreadEntry::UserMessage(message) = entry {
1736 if message.id.as_ref() == Some(id) {
1737 Some(message)
1738 } else {
1739 None
1740 }
1741 } else {
1742 None
1743 }
1744 })
1745 }
1746
1747 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1748 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1749 if let AgentThreadEntry::UserMessage(message) = entry {
1750 if message.id.as_ref() == Some(id) {
1751 Some((ix, message))
1752 } else {
1753 None
1754 }
1755 } else {
1756 None
1757 }
1758 })
1759 }
1760
1761 pub fn read_text_file(
1762 &self,
1763 path: PathBuf,
1764 line: Option<u32>,
1765 limit: Option<u32>,
1766 reuse_shared_snapshot: bool,
1767 cx: &mut Context<Self>,
1768 ) -> Task<Result<String>> {
1769 let project = self.project.clone();
1770 let action_log = self.action_log.clone();
1771 cx.spawn(async move |this, cx| {
1772 let load = project.update(cx, |project, cx| {
1773 let path = project
1774 .project_path_for_absolute_path(&path, cx)
1775 .context("invalid path")?;
1776 anyhow::Ok(project.open_buffer(path, cx))
1777 });
1778 let buffer = load??.await?;
1779
1780 let snapshot = if reuse_shared_snapshot {
1781 this.read_with(cx, |this, _| {
1782 this.shared_buffers.get(&buffer.clone()).cloned()
1783 })
1784 .log_err()
1785 .flatten()
1786 } else {
1787 None
1788 };
1789
1790 let snapshot = if let Some(snapshot) = snapshot {
1791 snapshot
1792 } else {
1793 action_log.update(cx, |action_log, cx| {
1794 action_log.buffer_read(buffer.clone(), cx);
1795 })?;
1796 project.update(cx, |project, cx| {
1797 let position = buffer
1798 .read(cx)
1799 .snapshot()
1800 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1801 project.set_agent_location(
1802 Some(AgentLocation {
1803 buffer: buffer.downgrade(),
1804 position,
1805 }),
1806 cx,
1807 );
1808 })?;
1809
1810 buffer.update(cx, |buffer, _| buffer.snapshot())?
1811 };
1812
1813 this.update(cx, |this, _| {
1814 let text = snapshot.text();
1815 this.shared_buffers.insert(buffer.clone(), snapshot);
1816 if line.is_none() && limit.is_none() {
1817 return Ok(text);
1818 }
1819 let limit = limit.unwrap_or(u32::MAX) as usize;
1820 let Some(line) = line else {
1821 return Ok(text.lines().take(limit).collect::<String>());
1822 };
1823
1824 let count = text.lines().count();
1825 if count < line as usize {
1826 anyhow::bail!("There are only {} lines", count);
1827 }
1828 Ok(text
1829 .lines()
1830 .skip(line as usize + 1)
1831 .take(limit)
1832 .collect::<String>())
1833 })?
1834 })
1835 }
1836
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// The new content is diffed against the snapshot last shared with the
    /// agent for this buffer (or the buffer's current snapshot when none was
    /// shared), and the resulting edits are applied to the live buffer. The
    /// buffer is formatted before saving unless the `format_on_save` language
    /// setting is off, and the action log is told about the read and the edits.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not inside the project, the buffer cannot be opened,
    /// or saving fails. A formatting failure is logged and does not fail the write.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the snapshot the agent last saw, when there is one.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the foreground thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at the
            // start of the buffer when the diff is empty.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Report the buffer as edited again once formatting has run.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1933
1934 pub fn create_terminal(
1935 &self,
1936 mut command: String,
1937 args: Vec<String>,
1938 extra_env: Vec<acp::EnvVariable>,
1939 cwd: Option<PathBuf>,
1940 output_byte_limit: Option<u64>,
1941 cx: &mut Context<Self>,
1942 ) -> Task<Result<Entity<Terminal>>> {
1943 for arg in args {
1944 command.push(' ');
1945 command.push_str(&arg);
1946 }
1947
1948 let shell_command = if cfg!(windows) {
1949 format!("$null | & {{{}}}", command.replace("\"", "'"))
1950 } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1951 // Make sure once we're *inside* the shell, we cd into `cwd`
1952 format!("(cd {cwd}; {}) </dev/null", command)
1953 } else {
1954 format!("({}) </dev/null", command)
1955 };
1956 let args = vec!["-c".into(), shell_command];
1957
1958 let env = match &cwd {
1959 Some(dir) => self.project.update(cx, |project, cx| {
1960 project.directory_environment(dir.as_path().into(), cx)
1961 }),
1962 None => Task::ready(None).shared(),
1963 };
1964
1965 let env = cx.spawn(async move |_, _| {
1966 let mut env = env.await.unwrap_or_default();
1967 if cfg!(unix) {
1968 env.insert("PAGER".into(), "cat".into());
1969 }
1970 for var in extra_env {
1971 env.insert(var.name, var.value);
1972 }
1973 env
1974 });
1975
1976 let project = self.project.clone();
1977 let language_registry = project.read(cx).languages().clone();
1978 let determine_shell = self.determine_shell.clone();
1979
1980 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1981 let terminal_task = cx.spawn({
1982 let terminal_id = terminal_id.clone();
1983 async move |_this, cx| {
1984 let program = determine_shell.await;
1985 let env = env.await;
1986 let terminal = project
1987 .update(cx, |project, cx| {
1988 project.create_terminal_task(
1989 task::SpawnInTerminal {
1990 command: Some(program),
1991 args,
1992 cwd: cwd.clone(),
1993 env,
1994 ..Default::default()
1995 },
1996 cx,
1997 )
1998 })?
1999 .await?;
2000
2001 cx.new(|cx| {
2002 Terminal::new(
2003 terminal_id,
2004 command,
2005 cwd,
2006 output_byte_limit.map(|l| l as usize),
2007 terminal,
2008 language_registry,
2009 cx,
2010 )
2011 })
2012 }
2013 });
2014
2015 cx.spawn(async move |this, cx| {
2016 let terminal = terminal_task.await?;
2017 this.update(cx, |this, _cx| {
2018 this.terminals.insert(terminal_id, terminal.clone());
2019 terminal
2020 })
2021 })
2022 }
2023
2024 pub fn kill_terminal(
2025 &mut self,
2026 terminal_id: acp::TerminalId,
2027 cx: &mut Context<Self>,
2028 ) -> Result<()> {
2029 self.terminals
2030 .get(&terminal_id)
2031 .context("Terminal not found")?
2032 .update(cx, |terminal, cx| {
2033 terminal.kill(cx);
2034 });
2035
2036 Ok(())
2037 }
2038
2039 pub fn release_terminal(
2040 &mut self,
2041 terminal_id: acp::TerminalId,
2042 cx: &mut Context<Self>,
2043 ) -> Result<()> {
2044 self.terminals
2045 .remove(&terminal_id)
2046 .context("Terminal not found")?
2047 .update(cx, |terminal, cx| {
2048 terminal.kill(cx);
2049 });
2050
2051 Ok(())
2052 }
2053
    /// Returns a handle to the terminal with the given id.
    ///
    /// # Errors
    ///
    /// Fails with "Terminal not found" when no terminal has that id.
    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
        self.terminals
            .get(&terminal_id)
            .context("Terminal not found")
            .cloned()
    }
2060
    /// Renders the whole thread as markdown by concatenating the markdown of
    /// every entry, in order.
    pub fn to_markdown(&self, cx: &App) -> String {
        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
    }
2064
    /// Emits [`AcpThreadEvent::LoadError`] carrying `error`.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
2068}
2069
2070fn markdown_for_raw_output(
2071 raw_output: &serde_json::Value,
2072 language_registry: &Arc<LanguageRegistry>,
2073 cx: &mut App,
2074) -> Option<Entity<Markdown>> {
2075 match raw_output {
2076 serde_json::Value::Null => None,
2077 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2078 Markdown::new(
2079 value.to_string().into(),
2080 Some(language_registry.clone()),
2081 None,
2082 cx,
2083 )
2084 })),
2085 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2086 Markdown::new(
2087 value.to_string().into(),
2088 Some(language_registry.clone()),
2089 None,
2090 cx,
2091 )
2092 })),
2093 serde_json::Value::String(value) => Some(cx.new(|cx| {
2094 Markdown::new(
2095 value.clone().into(),
2096 Some(language_registry.clone()),
2097 None,
2098 cx,
2099 )
2100 })),
2101 value => Some(cx.new(|cx| {
2102 Markdown::new(
2103 format!("```json\n{}\n```", value).into(),
2104 Some(language_registry.clone()),
2105 None,
2106 cx,
2107 )
2108 })),
2109 }
2110}
2111
2112#[cfg(test)]
2113mod tests {
2114 use super::*;
2115 use anyhow::anyhow;
2116 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2117 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2118 use indoc::indoc;
2119 use project::{FakeFs, Fs};
2120 use rand::Rng as _;
2121 use serde_json::json;
2122 use settings::SettingsStore;
2123 use smol::stream::StreamExt as _;
2124 use std::{
2125 any::Any,
2126 cell::RefCell,
2127 path::Path,
2128 rc::Rc,
2129 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2130 time::Duration,
2131 };
2132 use util::path;
2133
    /// Shared test setup: initializes logging (ignoring repeat initialization)
    /// and installs a test settings store plus project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
2143
    /// Checks that consecutive user content blocks are merged into a single
    /// user message (adopting the id of the later block), and that a user
    /// block pushed after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the block was appended and the id was adopted.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2235
    /// Checks that consecutive agent thought chunks are concatenated into a
    /// single `<thinking>` section in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every user message with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2298
    /// Checks that a file write by the agent is merged with an edit the user
    /// made to the same buffer after the agent read it: both the user's
    /// inserted line and the agent's appended lines end up in the buffer and
    /// on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has read the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then edit it as the user
        // before the agent's write is applied.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2377
2378 #[gpui::test]
2379 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2380 init_test(cx);
2381
2382 let fs = FakeFs::new(cx.executor());
2383 let project = Project::test(fs, [], cx).await;
2384 let id = acp::ToolCallId("test".into());
2385
2386 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2387 let id = id.clone();
2388 move |_, thread, mut cx| {
2389 let id = id.clone();
2390 async move {
2391 thread
2392 .update(&mut cx, |thread, cx| {
2393 thread.handle_session_update(
2394 acp::SessionUpdate::ToolCall(acp::ToolCall {
2395 id: id.clone(),
2396 title: "Label".into(),
2397 kind: acp::ToolKind::Fetch,
2398 status: acp::ToolCallStatus::InProgress,
2399 content: vec![],
2400 locations: vec![],
2401 raw_input: None,
2402 raw_output: None,
2403 }),
2404 cx,
2405 )
2406 })
2407 .unwrap()
2408 .unwrap();
2409 Ok(acp::PromptResponse {
2410 stop_reason: acp::StopReason::EndTurn,
2411 })
2412 }
2413 .boxed_local()
2414 }
2415 }));
2416
2417 let thread = cx
2418 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2419 .await
2420 .unwrap();
2421
2422 let request = thread.update(cx, |thread, cx| {
2423 thread.send_raw("Fetch https://example.com", cx)
2424 });
2425
2426 run_until_first_tool_call(&thread, cx).await;
2427
2428 thread.read_with(cx, |thread, _| {
2429 assert!(matches!(
2430 thread.entries[1],
2431 AgentThreadEntry::ToolCall(ToolCall {
2432 status: ToolCallStatus::InProgress,
2433 ..
2434 })
2435 ));
2436 });
2437
2438 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2439
2440 thread.read_with(cx, |thread, _| {
2441 assert!(matches!(
2442 &thread.entries[1],
2443 AgentThreadEntry::ToolCall(ToolCall {
2444 status: ToolCallStatus::Canceled,
2445 ..
2446 })
2447 ));
2448 });
2449
2450 thread
2451 .update(cx, |thread, cx| {
2452 thread.handle_session_update(
2453 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2454 id,
2455 fields: acp::ToolCallUpdateFields {
2456 status: Some(acp::ToolCallStatus::Completed),
2457 ..Default::default()
2458 },
2459 }),
2460 cx,
2461 )
2462 })
2463 .unwrap();
2464
2465 request.await.unwrap();
2466
2467 thread.read_with(cx, |thread, _| {
2468 assert!(matches!(
2469 thread.entries[1],
2470 AgentThreadEntry::ToolCall(ToolCall {
2471 status: ToolCallStatus::Completed,
2472 ..
2473 })
2474 ));
2475 });
2476 }
2477
2478 #[gpui::test]
2479 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2480 init_test(cx);
2481 let fs = FakeFs::new(cx.background_executor.clone());
2482 fs.insert_tree(path!("/test"), json!({})).await;
2483 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2484
2485 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2486 move |_, thread, mut cx| {
2487 async move {
2488 thread
2489 .update(&mut cx, |thread, cx| {
2490 thread.handle_session_update(
2491 acp::SessionUpdate::ToolCall(acp::ToolCall {
2492 id: acp::ToolCallId("test".into()),
2493 title: "Label".into(),
2494 kind: acp::ToolKind::Edit,
2495 status: acp::ToolCallStatus::Completed,
2496 content: vec![acp::ToolCallContent::Diff {
2497 diff: acp::Diff {
2498 path: "/test/test.txt".into(),
2499 old_text: None,
2500 new_text: "foo".into(),
2501 },
2502 }],
2503 locations: vec![],
2504 raw_input: None,
2505 raw_output: None,
2506 }),
2507 cx,
2508 )
2509 })
2510 .unwrap()
2511 .unwrap();
2512 Ok(acp::PromptResponse {
2513 stop_reason: acp::StopReason::EndTurn,
2514 })
2515 }
2516 .boxed_local()
2517 }
2518 }));
2519
2520 let thread = cx
2521 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2522 .await
2523 .unwrap();
2524
2525 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2526 .await
2527 .unwrap();
2528
2529 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2530 }
2531
/// Exercises checkpoint bookkeeping across several turns: turns during which the
/// fake agent writes a new file render as "## User (checkpoint)", a turn without
/// file changes renders as a plain "## User", and rewinding to an earlier user
/// message truncates the transcript and removes the file written after it.
#[gpui::test(iterations = 10)]
async fn test_checkpoints(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.background_executor.clone());
    // The worktree contains a `.git` directory.
    fs.insert_tree(
        path!("/test"),
        json!({
            ".git": {}
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // While `simulate_changes` is set, every prompt makes the fake agent create a
    // new empty file named `/test/file-<n>`; `next_filename` supplies `<n>`.
    let simulate_changes = Arc::new(AtomicBool::new(true));
    let next_filename = Arc::new(AtomicUsize::new(0));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        let simulate_changes = simulate_changes.clone();
        let next_filename = next_filename.clone();
        let fs = fs.clone();
        move |request, thread, mut cx| {
            let fs = fs.clone();
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            async move {
                if simulate_changes.load(SeqCst) {
                    let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                    fs.write(Path::new(&filename), b"").await?;
                }

                // Echo the first prompt block back in upper case as the agent's reply.
                let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                    panic!("expected text content block");
                };
                thread.update(&mut cx, |thread, cx| {
                    thread
                        .handle_session_update(
                            acp::SessionUpdate::AgentMessageChunk {
                                content: content.text.to_uppercase().into(),
                            },
                            cx,
                        )
                        .unwrap();
                })?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        }
    }));
    let thread = cx
        .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
        .await
        .unwrap();

    // First turn: the agent creates `file-0`, so the user message shows a checkpoint.
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

            "}
        );
    });
    assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

    // Second turn: the agent creates `file-1`, giving a second checkpoint.
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

            "}
        );
    });
    assert_eq!(
        fs.files(),
        vec![
            Path::new(path!("/test/file-0")),
            Path::new(path!("/test/file-1"))
        ]
    );

    // Checkpoint isn't stored when there are no changes.
    simulate_changes.store(false, SeqCst);
    cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

            "}
        );
    });
    assert_eq!(
        fs.files(),
        vec![
            Path::new(path!("/test/file-0")),
            Path::new(path!("/test/file-1"))
        ]
    );

    // Rewinding the conversation truncates the history and restores the checkpoint.
    // Entry 2 is the second user message ("ipsum"); after the rewind only the
    // first exchange and `file-0` remain.
    thread
        .update(cx, |thread, cx| {
            let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                panic!("unexpected entries {:?}", thread.entries)
            };
            thread.rewind(message.id.clone().unwrap(), cx)
        })
        .await
        .unwrap();
    thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

            "}
        );
    });
    assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
}
2711
2712 #[gpui::test]
2713 async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2714 use std::sync::atomic::AtomicUsize;
2715 init_test(cx);
2716
2717 let fs = FakeFs::new(cx.executor());
2718 let project = Project::test(fs, None, cx).await;
2719
2720 // Create a connection that simulates refusal after tool result
2721 let prompt_count = Arc::new(AtomicUsize::new(0));
2722 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2723 let prompt_count = prompt_count.clone();
2724 move |_request, thread, mut cx| {
2725 let count = prompt_count.fetch_add(1, SeqCst);
2726 async move {
2727 if count == 0 {
2728 // First prompt: Generate a tool call with result
2729 thread.update(&mut cx, |thread, cx| {
2730 thread
2731 .handle_session_update(
2732 acp::SessionUpdate::ToolCall(acp::ToolCall {
2733 id: acp::ToolCallId("tool1".into()),
2734 title: "Test Tool".into(),
2735 kind: acp::ToolKind::Fetch,
2736 status: acp::ToolCallStatus::Completed,
2737 content: vec![],
2738 locations: vec![],
2739 raw_input: Some(serde_json::json!({"query": "test"})),
2740 raw_output: Some(
2741 serde_json::json!({"result": "inappropriate content"}),
2742 ),
2743 }),
2744 cx,
2745 )
2746 .unwrap();
2747 })?;
2748
2749 // Now return refusal because of the tool result
2750 Ok(acp::PromptResponse {
2751 stop_reason: acp::StopReason::Refusal,
2752 })
2753 } else {
2754 Ok(acp::PromptResponse {
2755 stop_reason: acp::StopReason::EndTurn,
2756 })
2757 }
2758 }
2759 .boxed_local()
2760 }
2761 }));
2762
2763 let thread = cx
2764 .update(|cx| connection.new_thread(project, Path::new("/test"), cx))
2765 .await
2766 .unwrap();
2767
2768 // Track if we see a Refusal event
2769 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2770 let saw_refusal_event_captured = saw_refusal_event.clone();
2771 thread.update(cx, |_thread, cx| {
2772 cx.subscribe(
2773 &thread,
2774 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2775 if matches!(event, AcpThreadEvent::Refusal) {
2776 *saw_refusal_event_captured.lock().unwrap() = true;
2777 }
2778 },
2779 )
2780 .detach();
2781 });
2782
2783 // Send a user message - this will trigger tool call and then refusal
2784 let send_task = thread.update(cx, |thread, cx| {
2785 thread.send(
2786 vec![acp::ContentBlock::Text(acp::TextContent {
2787 text: "Hello".into(),
2788 annotations: None,
2789 })],
2790 cx,
2791 )
2792 });
2793 cx.background_executor.spawn(send_task).detach();
2794 cx.run_until_parked();
2795
2796 // Verify that:
2797 // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2798 // 2. The user message was NOT truncated
2799 assert!(
2800 *saw_refusal_event.lock().unwrap(),
2801 "Refusal event should be emitted for tool result refusals"
2802 );
2803
2804 thread.read_with(cx, |thread, _| {
2805 let entries = thread.entries();
2806 assert!(entries.len() >= 2, "Should have user message and tool call");
2807
2808 // Verify user message is still there
2809 assert!(
2810 matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2811 "User message should not be truncated"
2812 );
2813
2814 // Verify tool call is there with result
2815 if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2816 assert!(
2817 tool_call.raw_output.is_some(),
2818 "Tool call should have output"
2819 );
2820 } else {
2821 panic!("Expected tool call at index 1");
2822 }
2823 });
2824 }
2825
2826 #[gpui::test]
2827 async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2828 init_test(cx);
2829
2830 let fs = FakeFs::new(cx.executor());
2831 let project = Project::test(fs, None, cx).await;
2832
2833 let refuse_next = Arc::new(AtomicBool::new(false));
2834 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2835 let refuse_next = refuse_next.clone();
2836 move |_request, _thread, _cx| {
2837 if refuse_next.load(SeqCst) {
2838 async move {
2839 Ok(acp::PromptResponse {
2840 stop_reason: acp::StopReason::Refusal,
2841 })
2842 }
2843 .boxed_local()
2844 } else {
2845 async move {
2846 Ok(acp::PromptResponse {
2847 stop_reason: acp::StopReason::EndTurn,
2848 })
2849 }
2850 .boxed_local()
2851 }
2852 }
2853 }));
2854
2855 let thread = cx
2856 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2857 .await
2858 .unwrap();
2859
2860 // Track if we see a Refusal event
2861 let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2862 let saw_refusal_event_captured = saw_refusal_event.clone();
2863 thread.update(cx, |_thread, cx| {
2864 cx.subscribe(
2865 &thread,
2866 move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2867 if matches!(event, AcpThreadEvent::Refusal) {
2868 *saw_refusal_event_captured.lock().unwrap() = true;
2869 }
2870 },
2871 )
2872 .detach();
2873 });
2874
2875 // Send a message that will be refused
2876 refuse_next.store(true, SeqCst);
2877 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2878 .await
2879 .unwrap();
2880
2881 // Verify that a Refusal event WAS emitted for user prompt refusal
2882 assert!(
2883 *saw_refusal_event.lock().unwrap(),
2884 "Refusal event should be emitted for user prompt refusals"
2885 );
2886
2887 // Verify the message was truncated (user prompt refusal)
2888 thread.read_with(cx, |thread, cx| {
2889 assert_eq!(thread.to_markdown(cx), "");
2890 });
2891 }
2892
2893 #[gpui::test]
2894 async fn test_refusal(cx: &mut TestAppContext) {
2895 init_test(cx);
2896 let fs = FakeFs::new(cx.background_executor.clone());
2897 fs.insert_tree(path!("/"), json!({})).await;
2898 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2899
2900 let refuse_next = Arc::new(AtomicBool::new(false));
2901 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2902 let refuse_next = refuse_next.clone();
2903 move |request, thread, mut cx| {
2904 let refuse_next = refuse_next.clone();
2905 async move {
2906 if refuse_next.load(SeqCst) {
2907 return Ok(acp::PromptResponse {
2908 stop_reason: acp::StopReason::Refusal,
2909 });
2910 }
2911
2912 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2913 panic!("expected text content block");
2914 };
2915 thread.update(&mut cx, |thread, cx| {
2916 thread
2917 .handle_session_update(
2918 acp::SessionUpdate::AgentMessageChunk {
2919 content: content.text.to_uppercase().into(),
2920 },
2921 cx,
2922 )
2923 .unwrap();
2924 })?;
2925 Ok(acp::PromptResponse {
2926 stop_reason: acp::StopReason::EndTurn,
2927 })
2928 }
2929 .boxed_local()
2930 }
2931 }));
2932 let thread = cx
2933 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2934 .await
2935 .unwrap();
2936
2937 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2938 .await
2939 .unwrap();
2940 thread.read_with(cx, |thread, cx| {
2941 assert_eq!(
2942 thread.to_markdown(cx),
2943 indoc! {"
2944 ## User
2945
2946 hello
2947
2948 ## Assistant
2949
2950 HELLO
2951
2952 "}
2953 );
2954 });
2955
2956 // Simulate refusing the second message. The message should be truncated
2957 // when a user prompt is refused.
2958 refuse_next.store(true, SeqCst);
2959 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2960 .await
2961 .unwrap();
2962 thread.read_with(cx, |thread, cx| {
2963 assert_eq!(
2964 thread.to_markdown(cx),
2965 indoc! {"
2966 ## User
2967
2968 hello
2969
2970 ## Assistant
2971
2972 HELLO
2973
2974 "}
2975 );
2976 });
2977 }
2978
2979 async fn run_until_first_tool_call(
2980 thread: &Entity<AcpThread>,
2981 cx: &mut TestAppContext,
2982 ) -> usize {
2983 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2984
2985 let subscription = cx.update(|cx| {
2986 cx.subscribe(thread, move |thread, _, cx| {
2987 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2988 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2989 return tx.try_send(ix).unwrap();
2990 }
2991 }
2992 })
2993 });
2994
2995 select! {
2996 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2997 panic!("Timeout waiting for tool call")
2998 }
2999 ix = rx.next().fuse() => {
3000 drop(subscription);
3001 ix.unwrap()
3002 }
3003 }
3004 }
3005
3006 #[derive(Clone, Default)]
3007 struct FakeAgentConnection {
3008 auth_methods: Vec<acp::AuthMethod>,
3009 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3010 on_user_message: Option<
3011 Rc<
3012 dyn Fn(
3013 acp::PromptRequest,
3014 WeakEntity<AcpThread>,
3015 AsyncApp,
3016 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3017 + 'static,
3018 >,
3019 >,
3020 }
3021
3022 impl FakeAgentConnection {
3023 fn new() -> Self {
3024 Self {
3025 auth_methods: Vec::new(),
3026 on_user_message: None,
3027 sessions: Arc::default(),
3028 }
3029 }
3030
3031 #[expect(unused)]
3032 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3033 self.auth_methods = auth_methods;
3034 self
3035 }
3036
3037 fn on_user_message(
3038 mut self,
3039 handler: impl Fn(
3040 acp::PromptRequest,
3041 WeakEntity<AcpThread>,
3042 AsyncApp,
3043 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3044 + 'static,
3045 ) -> Self {
3046 self.on_user_message.replace(Rc::new(handler));
3047 self
3048 }
3049 }
3050
3051 impl AgentConnection for FakeAgentConnection {
3052 fn auth_methods(&self) -> &[acp::AuthMethod] {
3053 &self.auth_methods
3054 }
3055
3056 fn new_thread(
3057 self: Rc<Self>,
3058 project: Entity<Project>,
3059 _cwd: &Path,
3060 cx: &mut App,
3061 ) -> Task<gpui::Result<Entity<AcpThread>>> {
3062 let session_id = acp::SessionId(
3063 rand::thread_rng()
3064 .sample_iter(&rand::distributions::Alphanumeric)
3065 .take(7)
3066 .map(char::from)
3067 .collect::<String>()
3068 .into(),
3069 );
3070 let action_log = cx.new(|_| ActionLog::new(project.clone()));
3071 let thread = cx.new(|cx| {
3072 AcpThread::new(
3073 "Test",
3074 self.clone(),
3075 project,
3076 action_log,
3077 session_id.clone(),
3078 watch::Receiver::constant(acp::PromptCapabilities {
3079 image: true,
3080 audio: true,
3081 embedded_context: true,
3082 }),
3083 vec![],
3084 cx,
3085 )
3086 });
3087 self.sessions.lock().insert(session_id, thread.downgrade());
3088 Task::ready(Ok(thread))
3089 }
3090
3091 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3092 if self.auth_methods().iter().any(|m| m.id == method) {
3093 Task::ready(Ok(()))
3094 } else {
3095 Task::ready(Err(anyhow!("Invalid Auth Method")))
3096 }
3097 }
3098
3099 fn prompt(
3100 &self,
3101 _id: Option<UserMessageId>,
3102 params: acp::PromptRequest,
3103 cx: &mut App,
3104 ) -> Task<gpui::Result<acp::PromptResponse>> {
3105 let sessions = self.sessions.lock();
3106 let thread = sessions.get(¶ms.session_id).unwrap();
3107 if let Some(handler) = &self.on_user_message {
3108 let handler = handler.clone();
3109 let thread = thread.clone();
3110 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3111 } else {
3112 Task::ready(Ok(acp::PromptResponse {
3113 stop_reason: acp::StopReason::EndTurn,
3114 }))
3115 }
3116 }
3117
3118 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3119 let sessions = self.sessions.lock();
3120 let thread = sessions.get(session_id).unwrap().clone();
3121
3122 cx.spawn(async move |cx| {
3123 thread
3124 .update(cx, |thread, cx| thread.cancel(cx))
3125 .unwrap()
3126 .await
3127 })
3128 .detach();
3129 }
3130
3131 fn truncate(
3132 &self,
3133 session_id: &acp::SessionId,
3134 _cx: &App,
3135 ) -> Option<Rc<dyn AgentSessionTruncate>> {
3136 Some(Rc::new(FakeAgentSessionEditor {
3137 _session_id: session_id.clone(),
3138 }))
3139 }
3140
3141 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3142 self
3143 }
3144 }
3145
/// Truncation handle returned by `FakeAgentConnection::truncate`.
struct FakeAgentSessionEditor {
    // Session the editor was created for; stored but not read in the visible code.
    _session_id: acp::SessionId,
}
3149
impl AgentSessionTruncate for FakeAgentSessionEditor {
    /// No-op truncation: ignores the message id and reports success immediately.
    fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
3155}