1mod connection;
2mod diff;
3
4pub use connection::*;
5pub use diff::*;
6
7use agent_client_protocol as acp;
8use anyhow::{Context as _, Result};
9use assistant_tool::ActionLog;
10use editor::Bias;
11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
12use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
13use itertools::Itertools;
14use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
15use markdown::Markdown;
16use project::{AgentLocation, Project};
17use std::collections::HashMap;
18use std::error::Error;
19use std::fmt::Formatter;
20use std::process::ExitStatus;
21use std::rc::Rc;
22use std::{
23 fmt::Display,
24 mem,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use ui::App;
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub content: ContentBlock,
34}
35
36impl UserMessage {
37 pub fn from_acp(
38 message: impl IntoIterator<Item = acp::ContentBlock>,
39 language_registry: Arc<LanguageRegistry>,
40 cx: &mut App,
41 ) -> Self {
42 let mut content = ContentBlock::Empty;
43 for chunk in message {
44 content.append(chunk, &language_registry, cx)
45 }
46 Self { content: content }
47 }
48
49 fn to_markdown(&self, cx: &App) -> String {
50 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
51 }
52}
53
/// A file mention embedded in message markdown, encoded as `@file:<path>`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// URL prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`; returns `None` when the
    /// prefix is absent.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats as a markdown link: `[@<file name>](@file:<full path>)`. A path
    /// without a final component renders with an empty name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let full_path = self.0;
        let name = full_path.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            name.display(),
            Self::PREFIX,
            full_path.display()
        )
    }
}
85
86#[derive(Debug, PartialEq)]
87pub struct AssistantMessage {
88 pub chunks: Vec<AssistantMessageChunk>,
89}
90
91impl AssistantMessage {
92 pub fn to_markdown(&self, cx: &App) -> String {
93 format!(
94 "## Assistant\n\n{}\n\n",
95 self.chunks
96 .iter()
97 .map(|chunk| chunk.to_markdown(cx))
98 .join("\n\n")
99 )
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub enum AssistantMessageChunk {
105 Message { block: ContentBlock },
106 Thought { block: ContentBlock },
107}
108
109impl AssistantMessageChunk {
110 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 Self::Message {
112 block: ContentBlock::new(chunk.into(), language_registry, cx),
113 }
114 }
115
116 fn to_markdown(&self, cx: &App) -> String {
117 match self {
118 Self::Message { block } => block.to_markdown(cx).to_string(),
119 Self::Thought { block } => {
120 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
121 }
122 }
123 }
124}
125
126#[derive(Debug)]
127pub enum AgentThreadEntry {
128 UserMessage(UserMessage),
129 AssistantMessage(AssistantMessage),
130 ToolCall(ToolCall),
131}
132
133impl AgentThreadEntry {
134 fn to_markdown(&self, cx: &App) -> String {
135 match self {
136 Self::UserMessage(message) => message.to_markdown(cx),
137 Self::AssistantMessage(message) => message.to_markdown(cx),
138 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
139 }
140 }
141
142 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
143 if let AgentThreadEntry::ToolCall(call) = self {
144 itertools::Either::Left(call.diffs())
145 } else {
146 itertools::Either::Right(std::iter::empty())
147 }
148 }
149
150 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
151 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
152 Some(locations)
153 } else {
154 None
155 }
156 }
157}
158
159#[derive(Debug)]
160pub struct ToolCall {
161 pub id: acp::ToolCallId,
162 pub label: Entity<Markdown>,
163 pub kind: acp::ToolKind,
164 pub content: Vec<ToolCallContent>,
165 pub status: ToolCallStatus,
166 pub locations: Vec<acp::ToolCallLocation>,
167 pub raw_input: Option<serde_json::Value>,
168 pub raw_output: Option<serde_json::Value>,
169}
170
171impl ToolCall {
172 fn from_acp(
173 tool_call: acp::ToolCall,
174 status: ToolCallStatus,
175 language_registry: Arc<LanguageRegistry>,
176 cx: &mut App,
177 ) -> Self {
178 Self {
179 id: tool_call.id,
180 label: cx.new(|cx| {
181 Markdown::new(
182 tool_call.title.into(),
183 Some(language_registry.clone()),
184 None,
185 cx,
186 )
187 }),
188 kind: tool_call.kind,
189 content: tool_call
190 .content
191 .into_iter()
192 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
193 .collect(),
194 locations: tool_call.locations,
195 status,
196 raw_input: tool_call.raw_input,
197 raw_output: tool_call.raw_output,
198 }
199 }
200
201 fn update(
202 &mut self,
203 fields: acp::ToolCallUpdateFields,
204 language_registry: Arc<LanguageRegistry>,
205 cx: &mut App,
206 ) {
207 let acp::ToolCallUpdateFields {
208 kind,
209 status,
210 title,
211 content,
212 locations,
213 raw_input,
214 raw_output,
215 } = fields;
216
217 if let Some(kind) = kind {
218 self.kind = kind;
219 }
220
221 if let Some(status) = status {
222 self.status = ToolCallStatus::Allowed { status };
223 }
224
225 if let Some(title) = title {
226 self.label.update(cx, |label, cx| {
227 label.replace(title, cx);
228 });
229 }
230
231 if let Some(content) = content {
232 self.content = content
233 .into_iter()
234 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
235 .collect();
236 }
237
238 if let Some(locations) = locations {
239 self.locations = locations;
240 }
241
242 if let Some(raw_input) = raw_input {
243 self.raw_input = Some(raw_input);
244 }
245
246 if let Some(raw_output) = raw_output {
247 self.raw_output = Some(raw_output);
248 }
249 }
250
251 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
252 self.content.iter().filter_map(|content| match content {
253 ToolCallContent::ContentBlock { .. } => None,
254 ToolCallContent::Diff { diff } => Some(diff),
255 })
256 }
257
258 fn to_markdown(&self, cx: &App) -> String {
259 let mut markdown = format!(
260 "**Tool Call: {}**\nStatus: {}\n\n",
261 self.label.read(cx).source(),
262 self.status
263 );
264 for content in &self.content {
265 markdown.push_str(content.to_markdown(cx).as_str());
266 markdown.push_str("\n\n");
267 }
268 markdown
269 }
270}
271
272#[derive(Debug)]
273pub enum ToolCallStatus {
274 WaitingForConfirmation {
275 options: Vec<acp::PermissionOption>,
276 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
277 },
278 Allowed {
279 status: acp::ToolCallStatus,
280 },
281 Rejected,
282 Canceled,
283}
284
285impl Display for ToolCallStatus {
286 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287 write!(
288 f,
289 "{}",
290 match self {
291 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
292 ToolCallStatus::Allowed { status } => match status {
293 acp::ToolCallStatus::Pending => "Pending",
294 acp::ToolCallStatus::InProgress => "In Progress",
295 acp::ToolCallStatus::Completed => "Completed",
296 acp::ToolCallStatus::Failed => "Failed",
297 },
298 ToolCallStatus::Rejected => "Rejected",
299 ToolCallStatus::Canceled => "Canceled",
300 }
301 )
302 }
303}
304
305#[derive(Debug, PartialEq, Clone)]
306pub enum ContentBlock {
307 Empty,
308 Markdown { markdown: Entity<Markdown> },
309}
310
311impl ContentBlock {
312 pub fn new(
313 block: acp::ContentBlock,
314 language_registry: &Arc<LanguageRegistry>,
315 cx: &mut App,
316 ) -> Self {
317 let mut this = Self::Empty;
318 this.append(block, language_registry, cx);
319 this
320 }
321
322 pub fn new_combined(
323 blocks: impl IntoIterator<Item = acp::ContentBlock>,
324 language_registry: Arc<LanguageRegistry>,
325 cx: &mut App,
326 ) -> Self {
327 let mut this = Self::Empty;
328 for block in blocks {
329 this.append(block, &language_registry, cx);
330 }
331 this
332 }
333
334 pub fn append(
335 &mut self,
336 block: acp::ContentBlock,
337 language_registry: &Arc<LanguageRegistry>,
338 cx: &mut App,
339 ) {
340 let new_content = match block {
341 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
342 acp::ContentBlock::ResourceLink(resource_link) => {
343 if let Some(path) = resource_link.uri.strip_prefix("file://") {
344 format!("{}", MentionPath(path.as_ref()))
345 } else {
346 resource_link.uri.clone()
347 }
348 }
349 acp::ContentBlock::Image(_)
350 | acp::ContentBlock::Audio(_)
351 | acp::ContentBlock::Resource(_) => String::new(),
352 };
353
354 match self {
355 ContentBlock::Empty => {
356 *self = ContentBlock::Markdown {
357 markdown: cx.new(|cx| {
358 Markdown::new(
359 new_content.into(),
360 Some(language_registry.clone()),
361 None,
362 cx,
363 )
364 }),
365 };
366 }
367 ContentBlock::Markdown { markdown } => {
368 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
369 }
370 }
371 }
372
373 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
374 match self {
375 ContentBlock::Empty => "",
376 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
377 }
378 }
379
380 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
381 match self {
382 ContentBlock::Empty => None,
383 ContentBlock::Markdown { markdown } => Some(markdown),
384 }
385 }
386}
387
388#[derive(Debug)]
389pub enum ToolCallContent {
390 ContentBlock { content: ContentBlock },
391 Diff { diff: Entity<Diff> },
392}
393
394impl ToolCallContent {
395 pub fn from_acp(
396 content: acp::ToolCallContent,
397 language_registry: Arc<LanguageRegistry>,
398 cx: &mut App,
399 ) -> Self {
400 match content {
401 acp::ToolCallContent::Content { content } => Self::ContentBlock {
402 content: ContentBlock::new(content, &language_registry, cx),
403 },
404 acp::ToolCallContent::Diff { diff } => Self::Diff {
405 diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
406 },
407 }
408 }
409
410 pub fn to_markdown(&self, cx: &App) -> String {
411 match self {
412 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
413 Self::Diff { diff } => diff.read(cx).to_markdown(cx),
414 }
415 }
416}
417
418#[derive(Debug, Default)]
419pub struct Plan {
420 pub entries: Vec<PlanEntry>,
421}
422
423#[derive(Debug)]
424pub struct PlanStats<'a> {
425 pub in_progress_entry: Option<&'a PlanEntry>,
426 pub pending: u32,
427 pub completed: u32,
428}
429
430impl Plan {
431 pub fn is_empty(&self) -> bool {
432 self.entries.is_empty()
433 }
434
435 pub fn stats(&self) -> PlanStats<'_> {
436 let mut stats = PlanStats {
437 in_progress_entry: None,
438 pending: 0,
439 completed: 0,
440 };
441
442 for entry in &self.entries {
443 match &entry.status {
444 acp::PlanEntryStatus::Pending => {
445 stats.pending += 1;
446 }
447 acp::PlanEntryStatus::InProgress => {
448 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
449 }
450 acp::PlanEntryStatus::Completed => {
451 stats.completed += 1;
452 }
453 }
454 }
455
456 stats
457 }
458}
459
460#[derive(Debug)]
461pub struct PlanEntry {
462 pub content: Entity<Markdown>,
463 pub priority: acp::PlanEntryPriority,
464 pub status: acp::PlanEntryStatus,
465}
466
467impl PlanEntry {
468 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
469 Self {
470 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
471 priority: entry.priority,
472 status: entry.status,
473 }
474 }
475}
476
/// A conversation with an ACP agent: the ordered transcript, the agent's plan,
/// and the state needed to send prompts and serve the agent's file requests.
pub struct AcpThread {
    title: SharedString,
    // User messages, assistant messages, and tool calls, in arrival order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    // Records buffer reads/edits made on the agent's behalf.
    action_log: Entity<ActionLog>,
    // Snapshots returned to the agent by `read_text_file`; `write_text_file`
    // diffs against these so edits are relative to what the agent saw.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // The in-flight prompt task; `Some` while a prompt is being processed.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}

/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A prompt finished (including when it was cancelled).
    Stopped,
    /// A prompt request failed.
    Error,
    /// The agent server process exited with this status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Coarse state of an [`AcpThread`], derived from its in-flight prompt.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and the current turn has a tool call awaiting permission.
    WaitingForToolConfirmation,
    /// A prompt is in flight.
    Generating,
}
506
507#[derive(Debug, Clone)]
508pub enum LoadError {
509 Unsupported {
510 error_message: SharedString,
511 upgrade_message: SharedString,
512 upgrade_command: String,
513 },
514 Exited(i32),
515 Other(SharedString),
516}
517
518impl Display for LoadError {
519 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520 match self {
521 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
522 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
523 LoadError::Other(msg) => write!(f, "{}", msg),
524 }
525 }
526}
527
528impl Error for LoadError {}
529
530impl AcpThread {
531 pub fn new(
532 title: impl Into<SharedString>,
533 connection: Rc<dyn AgentConnection>,
534 project: Entity<Project>,
535 session_id: acp::SessionId,
536 cx: &mut Context<Self>,
537 ) -> Self {
538 let action_log = cx.new(|_| ActionLog::new(project.clone()));
539
540 Self {
541 action_log,
542 shared_buffers: Default::default(),
543 entries: Default::default(),
544 plan: Default::default(),
545 title: title.into(),
546 project,
547 send_task: None,
548 connection,
549 session_id,
550 }
551 }
552
553 pub fn action_log(&self) -> &Entity<ActionLog> {
554 &self.action_log
555 }
556
557 pub fn project(&self) -> &Entity<Project> {
558 &self.project
559 }
560
561 pub fn title(&self) -> SharedString {
562 self.title.clone()
563 }
564
565 pub fn entries(&self) -> &[AgentThreadEntry] {
566 &self.entries
567 }
568
569 pub fn session_id(&self) -> &acp::SessionId {
570 &self.session_id
571 }
572
573 pub fn status(&self) -> ThreadStatus {
574 if self.send_task.is_some() {
575 if self.waiting_for_tool_confirmation() {
576 ThreadStatus::WaitingForToolConfirmation
577 } else {
578 ThreadStatus::Generating
579 }
580 } else {
581 ThreadStatus::Idle
582 }
583 }
584
585 pub fn has_pending_edit_tool_calls(&self) -> bool {
586 for entry in self.entries.iter().rev() {
587 match entry {
588 AgentThreadEntry::UserMessage(_) => return false,
589 AgentThreadEntry::ToolCall(
590 call @ ToolCall {
591 status:
592 ToolCallStatus::Allowed {
593 status:
594 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
595 },
596 ..
597 },
598 ) if call.diffs().next().is_some() => {
599 return true;
600 }
601 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
602 }
603 }
604
605 false
606 }
607
608 pub fn used_tools_since_last_user_message(&self) -> bool {
609 for entry in self.entries.iter().rev() {
610 match entry {
611 AgentThreadEntry::UserMessage(..) => return false,
612 AgentThreadEntry::AssistantMessage(..) => continue,
613 AgentThreadEntry::ToolCall(..) => return true,
614 }
615 }
616
617 false
618 }
619
620 pub fn handle_session_update(
621 &mut self,
622 update: acp::SessionUpdate,
623 cx: &mut Context<Self>,
624 ) -> Result<()> {
625 match update {
626 acp::SessionUpdate::UserMessageChunk { content } => {
627 self.push_user_content_block(content, cx);
628 }
629 acp::SessionUpdate::AgentMessageChunk { content } => {
630 self.push_assistant_content_block(content, false, cx);
631 }
632 acp::SessionUpdate::AgentThoughtChunk { content } => {
633 self.push_assistant_content_block(content, true, cx);
634 }
635 acp::SessionUpdate::ToolCall(tool_call) => {
636 self.upsert_tool_call(tool_call, cx);
637 }
638 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
639 self.update_tool_call(tool_call_update, cx)?;
640 }
641 acp::SessionUpdate::Plan(plan) => {
642 self.update_plan(plan, cx);
643 }
644 }
645 Ok(())
646 }
647
648 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
649 let language_registry = self.project.read(cx).languages().clone();
650 let entries_len = self.entries.len();
651
652 if let Some(last_entry) = self.entries.last_mut()
653 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
654 {
655 content.append(chunk, &language_registry, cx);
656 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
657 } else {
658 let content = ContentBlock::new(chunk, &language_registry, cx);
659 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
660 }
661 }
662
663 pub fn push_assistant_content_block(
664 &mut self,
665 chunk: acp::ContentBlock,
666 is_thought: bool,
667 cx: &mut Context<Self>,
668 ) {
669 let language_registry = self.project.read(cx).languages().clone();
670 let entries_len = self.entries.len();
671 if let Some(last_entry) = self.entries.last_mut()
672 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
673 {
674 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
675 match (chunks.last_mut(), is_thought) {
676 (Some(AssistantMessageChunk::Message { block }), false)
677 | (Some(AssistantMessageChunk::Thought { block }), true) => {
678 block.append(chunk, &language_registry, cx)
679 }
680 _ => {
681 let block = ContentBlock::new(chunk, &language_registry, cx);
682 if is_thought {
683 chunks.push(AssistantMessageChunk::Thought { block })
684 } else {
685 chunks.push(AssistantMessageChunk::Message { block })
686 }
687 }
688 }
689 } else {
690 let block = ContentBlock::new(chunk, &language_registry, cx);
691 let chunk = if is_thought {
692 AssistantMessageChunk::Thought { block }
693 } else {
694 AssistantMessageChunk::Message { block }
695 };
696
697 self.push_entry(
698 AgentThreadEntry::AssistantMessage(AssistantMessage {
699 chunks: vec![chunk],
700 }),
701 cx,
702 );
703 }
704 }
705
706 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
707 self.entries.push(entry);
708 cx.emit(AcpThreadEvent::NewEntry);
709 }
710
711 pub fn update_tool_call(
712 &mut self,
713 update: acp::ToolCallUpdate,
714 cx: &mut Context<Self>,
715 ) -> Result<()> {
716 let languages = self.project.read(cx).languages().clone();
717
718 let (ix, current_call) = self
719 .tool_call_mut(&update.id)
720 .context("Tool call not found")?;
721 current_call.update(update.fields, languages, cx);
722
723 cx.emit(AcpThreadEvent::EntryUpdated(ix));
724
725 Ok(())
726 }
727
728 pub fn set_tool_call_diff(
729 &mut self,
730 tool_call_id: &acp::ToolCallId,
731 diff: Entity<Diff>,
732 cx: &mut Context<Self>,
733 ) -> Result<()> {
734 let (ix, current_call) = self
735 .tool_call_mut(tool_call_id)
736 .context("Tool call not found")?;
737 current_call.content.clear();
738 current_call.content.push(ToolCallContent::Diff { diff });
739 cx.emit(AcpThreadEvent::EntryUpdated(ix));
740 Ok(())
741 }
742
743 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
744 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
745 let status = ToolCallStatus::Allowed {
746 status: tool_call.status,
747 };
748 self.upsert_tool_call_inner(tool_call, status, cx)
749 }
750
751 pub fn upsert_tool_call_inner(
752 &mut self,
753 tool_call: acp::ToolCall,
754 status: ToolCallStatus,
755 cx: &mut Context<Self>,
756 ) {
757 let language_registry = self.project.read(cx).languages().clone();
758 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
759
760 let location = call.locations.last().cloned();
761
762 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
763 *current_call = call;
764
765 cx.emit(AcpThreadEvent::EntryUpdated(ix));
766 } else {
767 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
768 }
769
770 if let Some(location) = location {
771 self.set_project_location(location, cx)
772 }
773 }
774
775 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
776 // The tool call we are looking for is typically the last one, or very close to the end.
777 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
778 self.entries
779 .iter_mut()
780 .enumerate()
781 .rev()
782 .find_map(|(index, tool_call)| {
783 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
784 && &tool_call.id == id
785 {
786 Some((index, tool_call))
787 } else {
788 None
789 }
790 })
791 }
792
793 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
794 self.project.update(cx, |project, cx| {
795 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
796 return;
797 };
798 let buffer = project.open_buffer(path, cx);
799 cx.spawn(async move |project, cx| {
800 let buffer = buffer.await?;
801
802 project.update(cx, |project, cx| {
803 let position = if let Some(line) = location.line {
804 let snapshot = buffer.read(cx).snapshot();
805 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
806 snapshot.anchor_before(point)
807 } else {
808 Anchor::MIN
809 };
810
811 project.set_agent_location(
812 Some(AgentLocation {
813 buffer: buffer.downgrade(),
814 position,
815 }),
816 cx,
817 );
818 })
819 })
820 .detach_and_log_err(cx);
821 });
822 }
823
824 pub fn request_tool_call_authorization(
825 &mut self,
826 tool_call: acp::ToolCall,
827 options: Vec<acp::PermissionOption>,
828 cx: &mut Context<Self>,
829 ) -> oneshot::Receiver<acp::PermissionOptionId> {
830 let (tx, rx) = oneshot::channel();
831
832 let status = ToolCallStatus::WaitingForConfirmation {
833 options,
834 respond_tx: tx,
835 };
836
837 self.upsert_tool_call_inner(tool_call, status, cx);
838 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
839 rx
840 }
841
842 pub fn authorize_tool_call(
843 &mut self,
844 id: acp::ToolCallId,
845 option_id: acp::PermissionOptionId,
846 option_kind: acp::PermissionOptionKind,
847 cx: &mut Context<Self>,
848 ) {
849 let Some((ix, call)) = self.tool_call_mut(&id) else {
850 return;
851 };
852
853 let new_status = match option_kind {
854 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
855 ToolCallStatus::Rejected
856 }
857 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
858 ToolCallStatus::Allowed {
859 status: acp::ToolCallStatus::InProgress,
860 }
861 }
862 };
863
864 let curr_status = mem::replace(&mut call.status, new_status);
865
866 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
867 respond_tx.send(option_id).log_err();
868 } else if cfg!(debug_assertions) {
869 panic!("tried to authorize an already authorized tool call");
870 }
871
872 cx.emit(AcpThreadEvent::EntryUpdated(ix));
873 }
874
875 /// Returns true if the last turn is awaiting tool authorization
876 pub fn waiting_for_tool_confirmation(&self) -> bool {
877 for entry in self.entries.iter().rev() {
878 match &entry {
879 AgentThreadEntry::ToolCall(call) => match call.status {
880 ToolCallStatus::WaitingForConfirmation { .. } => return true,
881 ToolCallStatus::Allowed { .. }
882 | ToolCallStatus::Rejected
883 | ToolCallStatus::Canceled => continue,
884 },
885 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
886 // Reached the beginning of the turn
887 return false;
888 }
889 }
890 }
891 false
892 }
893
894 pub fn plan(&self) -> &Plan {
895 &self.plan
896 }
897
898 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
899 let new_entries_len = request.entries.len();
900 let mut new_entries = request.entries.into_iter();
901
902 // Reuse existing markdown to prevent flickering
903 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
904 let PlanEntry {
905 content,
906 priority,
907 status,
908 } = old;
909 content.update(cx, |old, cx| {
910 old.replace(new.content, cx);
911 });
912 *priority = new.priority;
913 *status = new.status;
914 }
915 for new in new_entries {
916 self.plan.entries.push(PlanEntry::from_acp(new, cx))
917 }
918 self.plan.entries.truncate(new_entries_len);
919
920 cx.notify();
921 }
922
923 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
924 self.plan
925 .entries
926 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
927 cx.notify();
928 }
929
930 #[cfg(any(test, feature = "test-support"))]
931 pub fn send_raw(
932 &mut self,
933 message: &str,
934 cx: &mut Context<Self>,
935 ) -> BoxFuture<'static, Result<()>> {
936 self.send(
937 vec![acp::ContentBlock::Text(acp::TextContent {
938 text: message.to_string(),
939 annotations: None,
940 })],
941 cx,
942 )
943 }
944
945 pub fn send(
946 &mut self,
947 message: Vec<acp::ContentBlock>,
948 cx: &mut Context<Self>,
949 ) -> BoxFuture<'static, Result<()>> {
950 let block = ContentBlock::new_combined(
951 message.clone(),
952 self.project.read(cx).languages().clone(),
953 cx,
954 );
955 self.push_entry(
956 AgentThreadEntry::UserMessage(UserMessage { content: block }),
957 cx,
958 );
959 self.clear_completed_plan_entries(cx);
960
961 let (tx, rx) = oneshot::channel();
962 let cancel_task = self.cancel(cx);
963
964 self.send_task = Some(cx.spawn(async move |this, cx| {
965 async {
966 cancel_task.await;
967
968 let result = this
969 .update(cx, |this, cx| {
970 this.connection.prompt(
971 acp::PromptRequest {
972 prompt: message,
973 session_id: this.session_id.clone(),
974 },
975 cx,
976 )
977 })?
978 .await;
979
980 tx.send(result).log_err();
981
982 anyhow::Ok(())
983 }
984 .await
985 .log_err();
986 }));
987
988 cx.spawn(async move |this, cx| match rx.await {
989 Ok(Err(e)) => {
990 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
991 .log_err();
992 Err(e)?
993 }
994 result => {
995 let cancelled = matches!(
996 result,
997 Ok(Ok(acp::PromptResponse {
998 stop_reason: acp::StopReason::Cancelled
999 }))
1000 );
1001
1002 // We only take the task if the current prompt wasn't cancelled.
1003 //
1004 // This prompt may have been cancelled because another one was sent
1005 // while it was still generating. In these cases, dropping `send_task`
1006 // would cause the next generation to be cancelled.
1007 if !cancelled {
1008 this.update(cx, |this, _cx| this.send_task.take()).ok();
1009 }
1010
1011 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1012 .log_err();
1013 Ok(())
1014 }
1015 })
1016 .boxed()
1017 }
1018
1019 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1020 let Some(send_task) = self.send_task.take() else {
1021 return Task::ready(());
1022 };
1023
1024 for entry in self.entries.iter_mut() {
1025 if let AgentThreadEntry::ToolCall(call) = entry {
1026 let cancel = matches!(
1027 call.status,
1028 ToolCallStatus::WaitingForConfirmation { .. }
1029 | ToolCallStatus::Allowed {
1030 status: acp::ToolCallStatus::InProgress
1031 }
1032 );
1033
1034 if cancel {
1035 call.status = ToolCallStatus::Canceled;
1036 }
1037 }
1038 }
1039
1040 self.connection.cancel(&self.session_id, cx);
1041
1042 // Wait for the send task to complete
1043 cx.foreground_executor().spawn(send_task)
1044 }
1045
1046 pub fn read_text_file(
1047 &self,
1048 path: PathBuf,
1049 line: Option<u32>,
1050 limit: Option<u32>,
1051 reuse_shared_snapshot: bool,
1052 cx: &mut Context<Self>,
1053 ) -> Task<Result<String>> {
1054 let project = self.project.clone();
1055 let action_log = self.action_log.clone();
1056 cx.spawn(async move |this, cx| {
1057 let load = project.update(cx, |project, cx| {
1058 let path = project
1059 .project_path_for_absolute_path(&path, cx)
1060 .context("invalid path")?;
1061 anyhow::Ok(project.open_buffer(path, cx))
1062 });
1063 let buffer = load??.await?;
1064
1065 let snapshot = if reuse_shared_snapshot {
1066 this.read_with(cx, |this, _| {
1067 this.shared_buffers.get(&buffer.clone()).cloned()
1068 })
1069 .log_err()
1070 .flatten()
1071 } else {
1072 None
1073 };
1074
1075 let snapshot = if let Some(snapshot) = snapshot {
1076 snapshot
1077 } else {
1078 action_log.update(cx, |action_log, cx| {
1079 action_log.buffer_read(buffer.clone(), cx);
1080 })?;
1081 project.update(cx, |project, cx| {
1082 let position = buffer
1083 .read(cx)
1084 .snapshot()
1085 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1086 project.set_agent_location(
1087 Some(AgentLocation {
1088 buffer: buffer.downgrade(),
1089 position,
1090 }),
1091 cx,
1092 );
1093 })?;
1094
1095 buffer.update(cx, |buffer, _| buffer.snapshot())?
1096 };
1097
1098 this.update(cx, |this, _| {
1099 let text = snapshot.text();
1100 this.shared_buffers.insert(buffer.clone(), snapshot);
1101 if line.is_none() && limit.is_none() {
1102 return Ok(text);
1103 }
1104 let limit = limit.unwrap_or(u32::MAX) as usize;
1105 let Some(line) = line else {
1106 return Ok(text.lines().take(limit).collect::<String>());
1107 };
1108
1109 let count = text.lines().count();
1110 if count < line as usize {
1111 anyhow::bail!("There are only {} lines", count);
1112 }
1113 Ok(text
1114 .lines()
1115 .skip(line as usize + 1)
1116 .take(limit)
1117 .collect::<String>())
1118 })?
1119 })
1120 }
1121
1122 pub fn write_text_file(
1123 &self,
1124 path: PathBuf,
1125 content: String,
1126 cx: &mut Context<Self>,
1127 ) -> Task<Result<()>> {
1128 let project = self.project.clone();
1129 let action_log = self.action_log.clone();
1130 cx.spawn(async move |this, cx| {
1131 let load = project.update(cx, |project, cx| {
1132 let path = project
1133 .project_path_for_absolute_path(&path, cx)
1134 .context("invalid path")?;
1135 anyhow::Ok(project.open_buffer(path, cx))
1136 });
1137 let buffer = load??.await?;
1138 let snapshot = this.update(cx, |this, cx| {
1139 this.shared_buffers
1140 .get(&buffer)
1141 .cloned()
1142 .unwrap_or_else(|| buffer.read(cx).snapshot())
1143 })?;
1144 let edits = cx
1145 .background_executor()
1146 .spawn(async move {
1147 let old_text = snapshot.text();
1148 text_diff(old_text.as_str(), &content)
1149 .into_iter()
1150 .map(|(range, replacement)| {
1151 (
1152 snapshot.anchor_after(range.start)
1153 ..snapshot.anchor_before(range.end),
1154 replacement,
1155 )
1156 })
1157 .collect::<Vec<_>>()
1158 })
1159 .await;
1160 cx.update(|cx| {
1161 project.update(cx, |project, cx| {
1162 project.set_agent_location(
1163 Some(AgentLocation {
1164 buffer: buffer.downgrade(),
1165 position: edits
1166 .last()
1167 .map(|(range, _)| range.end)
1168 .unwrap_or(Anchor::MIN),
1169 }),
1170 cx,
1171 );
1172 });
1173
1174 action_log.update(cx, |action_log, cx| {
1175 action_log.buffer_read(buffer.clone(), cx);
1176 });
1177 buffer.update(cx, |buffer, cx| {
1178 buffer.edit(edits, None, cx);
1179 });
1180 action_log.update(cx, |action_log, cx| {
1181 action_log.buffer_edited(buffer.clone(), cx);
1182 });
1183 })?;
1184 project
1185 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1186 .await
1187 })
1188 }
1189
1190 pub fn to_markdown(&self, cx: &App) -> String {
1191 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1192 }
1193
1194 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1195 cx.emit(AcpThreadEvent::ServerExited(status));
1196 }
1197}
1198
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, rc::Rc, time::Duration};

    use util::path;

    /// Installs the global state the tests in this module rely on: an env_logger
    /// (ignored if one is already installed), a test `SettingsStore`, and the
    /// project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user chunks are appended to the same `UserMessage`, while a user
    /// chunk arriving after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Successive `AgentThoughtChunk` updates are merged into a single `<thinking>`
    /// section of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// The agent reads a file, the user edits the open buffer, and then the agent
    /// writes the file. The write must be applied on top of the user's edit, and
    /// the merged result must reach both the buffer and the file system.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has read the file, so the user edit
        // lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// The thread cancels an in-progress tool call, and then the agent reports that
    /// call as completed. The later `Completed` update must move the entry from
    /// `Canceled` back to `Allowed { Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call that arrives already `Completed` must not leave the thread
    /// reporting pending edit tool calls once the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            // Use `path!` like every other path in these tests so
                                            // the diff points inside the worktree on Windows too.
                                            path: path!("/test/test.txt").into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits until `thread` contains a tool-call entry and returns the index of the
    /// first one.
    ///
    /// # Panics
    ///
    /// Panics if no tool call appears within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Several events can fire before the receiver is polled, so the
                    // bounded channel may already be full. Only the first index is
                    // consumed, so dropping later sends is intended rather than a panic.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-process `AgentConnection` used by the tests. It tracks the threads it has
    /// created by session id and delegates each prompt to an optional test handler.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Methods accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by their generated session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Handler run for every `prompt`. When `None`, prompts end the turn at once.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions and no prompt handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Replaces the auth methods that `authenticate` accepts.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs `handler` to be run for every prompt sent to this connection.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread under a random 7-character alphanumeric session id and
        /// registers it so later `prompt`/`cancel` calls can find it.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        /// Succeeds only when `method` is one of the configured auth methods.
        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Runs the installed `on_user_message` handler, or ends the turn immediately
        /// when there is none.
        ///
        /// # Panics
        ///
        /// Panics if `params.session_id` was not created by this connection.
        fn prompt(
            &self,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            // The lock guard is a temporary of this statement, so it is released as
            // soon as the weak handle has been copied out.
            let thread = self
                .sessions
                .lock()
                .get(&params.session_id)
                .cloned()
                .expect("prompt called for an unknown session");
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        /// Cancels the thread registered under `session_id` on a detached task.
        ///
        /// # Panics
        ///
        /// Panics if `session_id` was not created by this connection.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            // `session_id` is already a reference; borrowing it again would ask for
            // `SessionId: Borrow<&SessionId>`.
            let thread = self
                .sessions
                .lock()
                .get(session_id)
                .cloned()
                .expect("cancel called for an unknown session");

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}