1mod connection;
2pub use connection::*;
3
4use agent_client_protocol::{self as acp};
5use agentic_coding_protocol as acp_old;
6use anyhow::{Context as _, Result};
7use assistant_tool::ActionLog;
8use buffer_diff::BufferDiff;
9use editor::{Bias, MultiBuffer, PathKey};
10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
11use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
12use itertools::{Itertools, PeekingNext};
13use language::{
14 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
15 text_diff,
16};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::{
23 fmt::Display,
24 mem,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use ui::{App, IconName};
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub content: ContentBlock,
34}
35
36impl UserMessage {
37 pub fn from_acp(
38 message: impl IntoIterator<Item = acp::ContentBlock>,
39 language_registry: Arc<LanguageRegistry>,
40 cx: &mut App,
41 ) -> Self {
42 let mut content = ContentBlock::Empty;
43 for chunk in message {
44 content.append(chunk, &language_registry, cx)
45 }
46 Self { content: content }
47 }
48
49 fn to_markdown(&self, cx: &App) -> String {
50 format!("## User\n{}\n", self.content.to_markdown(cx))
51 }
52}
53
/// A file mention, rendered as a markdown link whose destination is `@file:<path>`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Prefix that marks a link destination as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link destination of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(Path::new)
            .map(MentionPath)
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<full path>)`; the file name is empty for paths like `/`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
85
86#[derive(Debug, PartialEq)]
87pub struct AssistantMessage {
88 pub chunks: Vec<AssistantMessageChunk>,
89}
90
91impl AssistantMessage {
92 pub fn to_markdown(&self, cx: &App) -> String {
93 format!(
94 "## Assistant\n\n{}\n\n",
95 self.chunks
96 .iter()
97 .map(|chunk| chunk.to_markdown(cx))
98 .join("\n\n")
99 )
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub enum AssistantMessageChunk {
105 Message { block: ContentBlock },
106 Thought { block: ContentBlock },
107}
108
109impl AssistantMessageChunk {
110 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 Self::Message {
112 block: ContentBlock::new(
113 acp::ContentBlock::Text(acp::TextContent {
114 text: chunk.to_owned().into(),
115 annotations: None,
116 }),
117 language_registry,
118 cx,
119 ),
120 }
121 }
122
123 fn to_markdown(&self, cx: &App) -> String {
124 match self {
125 Self::Message { block } => block.to_markdown(cx).to_string(),
126 Self::Thought { block } => {
127 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
128 }
129 }
130 }
131}
132
133#[derive(Debug)]
134pub enum AgentThreadEntry {
135 UserMessage(UserMessage),
136 AssistantMessage(AssistantMessage),
137 ToolCall(ToolCall),
138}
139
140impl AgentThreadEntry {
141 fn to_markdown(&self, cx: &App) -> String {
142 match self {
143 Self::UserMessage(message) => message.to_markdown(cx),
144 Self::AssistantMessage(message) => message.to_markdown(cx),
145 Self::ToolCall(too_call) => too_call.to_markdown(cx),
146 }
147 }
148
149 pub fn diff(&self) -> Option<&Diff> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 content: Some(ToolCallContent::Diff { diff }),
152 ..
153 }) = self
154 {
155 Some(&diff)
156 } else {
157 None
158 }
159 }
160
161 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
162 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
163 Some(locations)
164 } else {
165 None
166 }
167 }
168}
169
170#[derive(Debug)]
171pub struct ToolCall {
172 pub id: acp::ToolCallId,
173 pub label: Entity<Markdown>,
174 pub kind: acp::ToolKind,
175 // todo! Should this be a vec?
176 pub content: Option<ToolCallContent>,
177 pub status: ToolCallStatus,
178 pub locations: Vec<acp::ToolCallLocation>,
179}
180
181impl ToolCall {
182 fn from_acp(
183 tool_call: acp::ToolCall,
184 status: ToolCallStatus,
185 language_registry: Arc<LanguageRegistry>,
186 cx: &mut App,
187 ) -> Self {
188 Self {
189 id: tool_call.id,
190 label: cx.new(|cx| {
191 Markdown::new(
192 tool_call.label.into(),
193 Some(language_registry.clone()),
194 None,
195 cx,
196 )
197 }),
198 kind: tool_call.kind,
199 // todo! Do we assume there is either a coalesced content OR diff?
200 content: ToolCallContent::from_acp_contents(tool_call.content, language_registry, cx)
201 .into_iter()
202 .next(),
203 locations: tool_call.locations,
204 status,
205 }
206 }
207 fn to_markdown(&self, cx: &App) -> String {
208 let mut markdown = format!(
209 "**Tool Call: {}**\nStatus: {}\n\n",
210 self.label.read(cx).source(),
211 self.status
212 );
213 if let Some(content) = &self.content {
214 markdown.push_str(content.to_markdown(cx).as_str());
215 markdown.push_str("\n\n");
216 }
217 markdown
218 }
219}
220
221#[derive(Debug)]
222pub enum ToolCallStatus {
223 WaitingForConfirmation {
224 possible_grants: Vec<acp::Grant>,
225 respond_tx: oneshot::Sender<acp::GrantId>,
226 },
227 Allowed {
228 status: acp::ToolCallStatus,
229 },
230 Rejected,
231 Canceled,
232}
233
234impl Display for ToolCallStatus {
235 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236 write!(
237 f,
238 "{}",
239 match self {
240 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
241 ToolCallStatus::Allowed { status } => match status {
242 acp::ToolCallStatus::InProgress => "In Progress",
243 acp::ToolCallStatus::Completed => "Completed",
244 acp::ToolCallStatus::Failed => "Failed",
245 },
246 ToolCallStatus::Rejected => "Rejected",
247 ToolCallStatus::Canceled => "Canceled",
248 }
249 )
250 }
251}
252
253#[derive(Debug, PartialEq, Clone)]
254enum ContentBlock {
255 Empty,
256 Markdown { markdown: Entity<Markdown> },
257}
258
259impl ContentBlock {
260 pub fn new(
261 block: acp::ContentBlock,
262 language_registry: &Arc<LanguageRegistry>,
263 cx: &mut App,
264 ) -> Self {
265 let mut this = Self::Empty;
266 this.append(block, language_registry, cx);
267 this
268 }
269
270 pub fn new_combined(
271 blocks: impl IntoIterator<Item = acp::ContentBlock>,
272 language_registry: &Arc<LanguageRegistry>,
273 cx: &mut App,
274 ) -> Self {
275 let mut this = Self::Empty;
276 for block in blocks {
277 this.append(block, language_registry, cx);
278 }
279 this
280 }
281
282 pub fn append(
283 &mut self,
284 block: acp::ContentBlock,
285 language_registry: &Arc<LanguageRegistry>,
286 cx: &mut App,
287 ) {
288 let new_content = match block {
289 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
290 acp::ContentBlock::ResourceLink(resource_link) => {
291 if let Some(path) = resource_link.uri.strip_prefix("file://") {
292 format!("{}", MentionPath(path.as_ref()))
293 } else {
294 resource_link.uri.clone()
295 }
296 }
297 acp::ContentBlock::Image(_)
298 | acp::ContentBlock::Audio(_)
299 | acp::ContentBlock::Resource(_) => String::new(),
300 };
301
302 match self {
303 ContentBlock::Empty => {
304 *self = ContentBlock::Markdown {
305 markdown: cx.new(|cx| {
306 Markdown::new(
307 new_content.into(),
308 Some(language_registry.clone()),
309 None,
310 cx,
311 )
312 }),
313 };
314 }
315 ContentBlock::Markdown { markdown } => {
316 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
317 }
318 }
319 }
320
321 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
322 match self {
323 ContentBlock::Empty => "",
324 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
325 }
326 }
327}
328
329#[derive(Debug)]
330pub enum ToolCallContent {
331 ContentBlock { content: ContentBlock },
332 Diff { diff: Diff },
333}
334
335impl ToolCallContent {
336 pub fn from_acp(
337 content: acp::ToolCallContent,
338 language_registry: Arc<LanguageRegistry>,
339 cx: &mut App,
340 ) -> Self {
341 match content {
342 acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock {
343 content: ContentBlock::new(content, &language_registry, cx),
344 },
345 acp::ToolCallContent::Diff { diff } => Self::Diff {
346 diff: Diff::from_acp(diff, language_registry, cx),
347 },
348 }
349 }
350
351 pub fn from_acp_contents(
352 content: Vec<acp::ToolCallContent>,
353 language_registry: Arc<LanguageRegistry>,
354 cx: &mut App,
355 ) -> Vec<Self> {
356 content
357 .into_iter()
358 .peekable()
359 .batching(|it| match it.next()? {
360 acp::ToolCallContent::ContentBlock { content } => {
361 let mut block = ContentBlock::new(content, &language_registry, cx);
362 while let Some(acp::ToolCallContent::ContentBlock { content }) =
363 it.peeking_next(|c| matches!(c, acp::ToolCallContent::ContentBlock { .. }))
364 {
365 block.append(content, &language_registry, cx);
366 }
367 Some(ToolCallContent::ContentBlock { content: block })
368 }
369 content @ acp::ToolCallContent::Diff { .. } => Some(ToolCallContent::from_acp(
370 content,
371 language_registry.clone(),
372 cx,
373 )),
374 })
375 .collect()
376 }
377
378 pub fn to_markdown(&self, cx: &App) -> String {
379 match self {
380 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
381 Self::Diff { diff } => diff.to_markdown(cx),
382 }
383 }
384}
385
/// A read-only diff view of an agent-proposed file change.
///
/// `old_buffer` holds the original text (the diff base) and `new_buffer` the proposed text;
/// `multibuffer` shows excerpts of `new_buffer` around each changed hunk.
#[derive(Debug)]
pub struct Diff {
    pub multibuffer: Entity<MultiBuffer>,
    pub path: PathBuf,
    pub new_buffer: Entity<Buffer>,
    pub old_buffer: Entity<Buffer>,
    // Background work started in `from_acp`; held for the lifetime of the diff.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff view from an ACP diff.
    ///
    /// A missing `old_text` is treated as an empty file. Hunk computation happens
    /// asynchronously. The spawned task waits for it, then fills `multibuffer` with excerpts of
    /// `new_buffer` around each hunk (with `editor::DEFAULT_MULTIBUFFER_CONTEXT` lines of
    /// context), attaches the `BufferDiff`, and finally sets `new_buffer`'s language from `path`.
    /// Until that task runs, `multibuffer` is empty.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Diff the new text against the old text; the returned task resolves when done.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Failures here are logged rather than propagated so language detection below
                // still runs.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Only the new buffer's language is set here; `old_buffer` is left as-is.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the full text of every buffer in the multibuffer inside a fenced code block.
    ///
    /// NOTE(review): the only buffer added in `from_acp` is `new_buffer`, so this prints the
    /// proposed text rather than a unified diff.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
483
484#[derive(Debug, Default)]
485pub struct Plan {
486 pub entries: Vec<PlanEntry>,
487}
488
489#[derive(Debug)]
490pub struct PlanStats<'a> {
491 pub in_progress_entry: Option<&'a PlanEntry>,
492 pub pending: u32,
493 pub completed: u32,
494}
495
496impl Plan {
497 pub fn is_empty(&self) -> bool {
498 self.entries.is_empty()
499 }
500
501 pub fn stats(&self) -> PlanStats<'_> {
502 let mut stats = PlanStats {
503 in_progress_entry: None,
504 pending: 0,
505 completed: 0,
506 };
507
508 for entry in &self.entries {
509 match &entry.status {
510 acp::PlanEntryStatus::Pending => {
511 stats.pending += 1;
512 }
513 acp::PlanEntryStatus::InProgress => {
514 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
515 }
516 acp::PlanEntryStatus::Completed => {
517 stats.completed += 1;
518 }
519 }
520 }
521
522 stats
523 }
524}
525
526#[derive(Debug)]
527pub struct PlanEntry {
528 pub content: Entity<Markdown>,
529 pub priority: acp::PlanEntryPriority,
530 pub status: acp::PlanEntryStatus,
531}
532
533impl PlanEntry {
534 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
535 Self {
536 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
537 priority: entry.priority,
538 status: entry.status,
539 }
540 }
541}
542
/// A single conversation with an agent: its history, plan, and the project resources the
/// agent reads and writes through it.
pub struct AcpThread {
    title: SharedString,
    /// Conversation history, oldest first.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Last snapshot handed to the agent for each buffer read via `read_text_file`.
    /// `write_text_file` diffs against this snapshot (when present) so edits are applied
    /// relative to what the agent saw.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight `send` work; `Some` while the thread is not idle.
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    /// Agent process status task; handed out once via `child_status()`.
    child_status: Option<Task<Result<()>>>,
}

/// Change notifications emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the history.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
561
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
///
/// This is a fieldless public enum, so it derives `Debug` (for logging and assertions) and
/// `Clone`/`Copy` (so callers can pass it by value freely).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No request is in flight.
    Idle,
    /// A request is in flight and the current turn has a tool call awaiting authorization.
    WaitingForToolConfirmation,
    /// A request is in flight and the agent is producing output.
    Generating,
}
568
569#[derive(Debug, Clone)]
570pub enum LoadError {
571 Unsupported {
572 error_message: SharedString,
573 upgrade_message: SharedString,
574 upgrade_command: String,
575 },
576 Exited(i32),
577 Other(SharedString),
578}
579
580impl Display for LoadError {
581 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
582 match self {
583 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
584 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
585 LoadError::Other(msg) => write!(f, "{}", msg),
586 }
587 }
588}
589
590impl Error for LoadError {}
591
592impl AcpThread {
593 pub fn new(
594 connection: impl AgentConnection + 'static,
595 title: SharedString,
596 child_status: Option<Task<Result<()>>>,
597 project: Entity<Project>,
598 cx: &mut Context<Self>,
599 ) -> Self {
600 let action_log = cx.new(|_| ActionLog::new(project.clone()));
601
602 Self {
603 action_log,
604 shared_buffers: Default::default(),
605 entries: Default::default(),
606 plan: Default::default(),
607 title,
608 project,
609 send_task: None,
610 connection: Arc::new(connection),
611 child_status,
612 }
613 }
614
615 /// Send a request to the agent and wait for a response.
616 pub fn request<R: acp_old::AgentRequest + 'static>(
617 &self,
618 params: R,
619 ) -> impl use<R> + Future<Output = Result<R::Response>> {
620 let params = params.into_any();
621 let result = self.connection.request_any(params);
622 async move {
623 let result = result.await?;
624 Ok(R::response_from_any(result)?)
625 }
626 }
627
628 pub fn action_log(&self) -> &Entity<ActionLog> {
629 &self.action_log
630 }
631
632 pub fn project(&self) -> &Entity<Project> {
633 &self.project
634 }
635
636 pub fn title(&self) -> SharedString {
637 self.title.clone()
638 }
639
640 pub fn entries(&self) -> &[AgentThreadEntry] {
641 &self.entries
642 }
643
644 pub fn status(&self) -> ThreadStatus {
645 if self.send_task.is_some() {
646 if self.waiting_for_tool_confirmation() {
647 ThreadStatus::WaitingForToolConfirmation
648 } else {
649 ThreadStatus::Generating
650 }
651 } else {
652 ThreadStatus::Idle
653 }
654 }
655
656 pub fn has_pending_edit_tool_calls(&self) -> bool {
657 for entry in self.entries.iter().rev() {
658 match entry {
659 AgentThreadEntry::UserMessage(_) => return false,
660 AgentThreadEntry::ToolCall(ToolCall {
661 status:
662 ToolCallStatus::Allowed {
663 status: acp::ToolCallStatus::InProgress,
664 ..
665 },
666 content: Some(ToolCallContent::Diff { .. }),
667 ..
668 }) => return true,
669 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
670 }
671 }
672
673 false
674 }
675
676 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
677 self.entries.push(entry);
678 cx.emit(AcpThreadEvent::NewEntry);
679 }
680
681 pub fn push_assistant_chunk(
682 &mut self,
683 chunk: acp::ContentBlock,
684 is_thought: bool,
685 cx: &mut Context<Self>,
686 ) {
687 let language_registry = self.project.read(cx).languages().clone();
688 let entries_len = self.entries.len();
689 if let Some(last_entry) = self.entries.last_mut()
690 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
691 {
692 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
693 match (chunks.last_mut(), is_thought) {
694 (Some(AssistantMessageChunk::Message { block }), false)
695 | (Some(AssistantMessageChunk::Thought { block }), true) => {
696 block.append(chunk, &language_registry, cx)
697 }
698 _ => {
699 let block = ContentBlock::new(chunk, &language_registry, cx);
700 if is_thought {
701 chunks.push(AssistantMessageChunk::Thought { block })
702 } else {
703 chunks.push(AssistantMessageChunk::Message { block })
704 }
705 }
706 }
707 } else {
708 let block = ContentBlock::new(chunk, &language_registry, cx);
709 let chunk = if is_thought {
710 AssistantMessageChunk::Thought { block }
711 } else {
712 AssistantMessageChunk::Message { block }
713 };
714
715 self.push_entry(
716 AgentThreadEntry::AssistantMessage(AssistantMessage {
717 chunks: vec![chunk],
718 }),
719 cx,
720 );
721 }
722 }
723
724 pub fn update_tool_call(
725 &mut self,
726 tool_call: acp::ToolCall,
727 cx: &mut Context<Self>,
728 ) -> Result<()> {
729 let language_registry = self.project.read(cx).languages().clone();
730 let status = ToolCallStatus::Allowed {
731 status: tool_call.status,
732 };
733 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
734
735 let location = call.locations.last().cloned();
736
737 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
738 match ¤t_call.status {
739 ToolCallStatus::WaitingForConfirmation { .. } => {
740 anyhow::bail!("Tool call hasn't been authorized yet")
741 }
742 ToolCallStatus::Rejected => {
743 anyhow::bail!("Tool call was rejected and therefore can't be updated")
744 }
745 ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {}
746 }
747
748 *current_call = call;
749
750 cx.emit(AcpThreadEvent::EntryUpdated(ix));
751 } else {
752 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
753 }
754
755 if let Some(location) = location {
756 self.set_project_location(location, cx)
757 }
758
759 Ok(())
760 }
761
762 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
763 self.entries
764 .iter_mut()
765 .enumerate()
766 .rev()
767 .find_map(|(index, tool_call)| {
768 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
769 && &tool_call.id == id
770 {
771 Some((index, tool_call))
772 } else {
773 None
774 }
775 })
776 }
777
778 pub fn request_tool_call_permission(
779 &mut self,
780 tool_call: acp::ToolCall,
781 possible_grants: Vec<acp::Grant>,
782 cx: &mut Context<Self>,
783 ) -> oneshot::Receiver<acp::GrantId> {
784 let (tx, rx) = oneshot::channel();
785
786 let status = ToolCallStatus::WaitingForConfirmation {
787 possible_grants,
788 respond_tx: tx,
789 };
790
791 self.insert_tool_call(tool_call, status, cx);
792 rx
793 }
794
795 fn insert_tool_call(
796 &mut self,
797 tool_call: acp::ToolCall,
798 status: ToolCallStatus,
799 cx: &mut Context<Self>,
800 ) {
801 let language_registry = self.project.read(cx).languages().clone();
802 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
803
804 let location = call.locations.last().cloned();
805 if let Some(location) = location {
806 self.set_project_location(location, cx)
807 }
808
809 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
810 }
811
812 pub fn authorize_tool_call(
813 &mut self,
814 id: acp::ToolCallId,
815 grant: acp::Grant,
816 cx: &mut Context<Self>,
817 ) {
818 let Some((ix, call)) = self.tool_call_mut(&id) else {
819 return;
820 };
821
822 let new_status = if grant.is_allowed {
823 ToolCallStatus::Allowed {
824 status: acp::ToolCallStatus::InProgress,
825 }
826 } else {
827 ToolCallStatus::Rejected
828 };
829
830 let curr_status = mem::replace(&mut call.status, new_status);
831
832 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
833 respond_tx.send(grant.id).log_err();
834 } else if cfg!(debug_assertions) {
835 panic!("tried to authorize an already authorized tool call");
836 }
837
838 cx.emit(AcpThreadEvent::EntryUpdated(ix));
839 }
840
841 pub fn plan(&self) -> &Plan {
842 &self.plan
843 }
844
845 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
846 self.plan = Plan {
847 entries: request
848 .entries
849 .into_iter()
850 .map(|entry| PlanEntry::from_acp(entry, cx))
851 .collect(),
852 };
853
854 cx.notify();
855 }
856
857 pub fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
858 self.plan
859 .entries
860 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
861 cx.notify();
862 }
863
864 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
865 self.project.update(cx, |project, cx| {
866 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
867 return;
868 };
869 let buffer = project.open_buffer(path, cx);
870 cx.spawn(async move |project, cx| {
871 let buffer = buffer.await?;
872
873 project.update(cx, |project, cx| {
874 let position = if let Some(line) = location.line {
875 let snapshot = buffer.read(cx).snapshot();
876 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
877 snapshot.anchor_before(point)
878 } else {
879 Anchor::MIN
880 };
881
882 project.set_agent_location(
883 Some(AgentLocation {
884 buffer: buffer.downgrade(),
885 position,
886 }),
887 cx,
888 );
889 })
890 })
891 .detach_and_log_err(cx);
892 });
893 }
894
895 /// Returns true if the last turn is awaiting tool authorization
896 pub fn waiting_for_tool_confirmation(&self) -> bool {
897 for entry in self.entries.iter().rev() {
898 match &entry {
899 AgentThreadEntry::ToolCall(call) => match call.status {
900 ToolCallStatus::WaitingForConfirmation { .. } => return true,
901 ToolCallStatus::Allowed { .. }
902 | ToolCallStatus::Rejected
903 | ToolCallStatus::Canceled => continue,
904 },
905 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
906 // Reached the beginning of the turn
907 return false;
908 }
909 }
910 }
911 false
912 }
913
914 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp_old::InitializeResponse>> {
915 self.request(acp_old::InitializeParams {
916 protocol_version: acp_old::ProtocolVersion::latest(),
917 })
918 }
919
920 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
921 self.request(acp_old::AuthenticateParams)
922 }
923
924 #[cfg(any(test, feature = "test-support"))]
925 pub fn send_raw(
926 &mut self,
927 message: &str,
928 cx: &mut Context<Self>,
929 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
930 self.send(
931 vec![acp::ContentBlock::Text(acp::TextContent {
932 text: message.to_string(),
933 annotations: None,
934 })],
935 cx,
936 )
937 }
938
939 pub fn send(
940 &mut self,
941 message: Vec<acp::ContentBlock>,
942 cx: &mut Context<Self>,
943 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
944 let block =
945 ContentBlock::new_combined(message.clone(), self.project.read(cx).languages(), cx);
946 self.push_entry(
947 AgentThreadEntry::UserMessage(UserMessage { content: block }),
948 cx,
949 );
950
951 let (tx, rx) = oneshot::channel();
952 let cancel = self.cancel(cx);
953
954 self.send_task = Some(cx.spawn(async move |this, cx| {
955 async {
956 cancel.await.log_err();
957
958 let result = this.update(cx, |this, _| this.request(message))?.await;
959 tx.send(result).log_err();
960 this.update(cx, |this, _cx| this.send_task.take())?;
961 anyhow::Ok(())
962 }
963 .await
964 .log_err();
965 }));
966
967 async move {
968 match rx.await {
969 Ok(Err(e)) => Err(e)?,
970 _ => Ok(()),
971 }
972 }
973 .boxed()
974 }
975
976 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp_old::Error>> {
977 if self.send_task.take().is_some() {
978 let request = self.request(acp_old::CancelSendMessageParams);
979 cx.spawn(async move |this, cx| {
980 request.await?;
981 this.update(cx, |this, _cx| {
982 for entry in this.entries.iter_mut() {
983 if let AgentThreadEntry::ToolCall(call) = entry {
984 let cancel = matches!(
985 call.status,
986 ToolCallStatus::WaitingForConfirmation { .. }
987 | ToolCallStatus::Allowed {
988 status: acp::ToolCallStatus::InProgress
989 }
990 );
991
992 if cancel {
993 let curr_status =
994 mem::replace(&mut call.status, ToolCallStatus::Canceled);
995
996 if let ToolCallStatus::WaitingForConfirmation {
997 respond_tx,
998 possible_grants,
999 } = curr_status
1000 {
1001 if let Some(grant_id) = possible_grants
1002 .iter()
1003 .find_map(|g| (!g.is_allowed).then(|| g.id.clone()))
1004 {
1005 // todo! do we need a way to cancel rather than reject?
1006 respond_tx.send(grant_id).ok();
1007 }
1008 }
1009 }
1010 }
1011 }
1012 })?;
1013 Ok(())
1014 })
1015 } else {
1016 Task::ready(Ok(()))
1017 }
1018 }
1019
1020 pub fn read_text_file(
1021 &self,
1022 request: acp::ReadTextFileArguments,
1023 reuse_shared_snapshot: bool,
1024 cx: &mut Context<Self>,
1025 ) -> Task<Result<String>> {
1026 let project = self.project.clone();
1027 let action_log = self.action_log.clone();
1028 cx.spawn(async move |this, cx| {
1029 let load = project.update(cx, |project, cx| {
1030 let path = project
1031 .project_path_for_absolute_path(&request.path, cx)
1032 .context("invalid path")?;
1033 anyhow::Ok(project.open_buffer(path, cx))
1034 });
1035 let buffer = load??.await?;
1036
1037 let snapshot = if reuse_shared_snapshot {
1038 this.read_with(cx, |this, _| {
1039 this.shared_buffers.get(&buffer.clone()).cloned()
1040 })
1041 .log_err()
1042 .flatten()
1043 } else {
1044 None
1045 };
1046
1047 let snapshot = if let Some(snapshot) = snapshot {
1048 snapshot
1049 } else {
1050 action_log.update(cx, |action_log, cx| {
1051 action_log.buffer_read(buffer.clone(), cx);
1052 })?;
1053 project.update(cx, |project, cx| {
1054 let position = buffer
1055 .read(cx)
1056 .snapshot()
1057 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1058 project.set_agent_location(
1059 Some(AgentLocation {
1060 buffer: buffer.downgrade(),
1061 position,
1062 }),
1063 cx,
1064 );
1065 })?;
1066
1067 buffer.update(cx, |buffer, _| buffer.snapshot())?
1068 };
1069
1070 this.update(cx, |this, _| {
1071 let text = snapshot.text();
1072 this.shared_buffers.insert(buffer.clone(), snapshot);
1073 if request.line.is_none() && request.limit.is_none() {
1074 return Ok(text);
1075 }
1076 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1077 let Some(line) = request.line else {
1078 return Ok(text.lines().take(limit).collect::<String>());
1079 };
1080
1081 let count = text.lines().count();
1082 if count < line as usize {
1083 anyhow::bail!("There are only {} lines", count);
1084 }
1085 Ok(text
1086 .lines()
1087 .skip(line as usize + 1)
1088 .take(limit)
1089 .collect::<String>())
1090 })?
1091 })
1092 }
1093
1094 pub fn write_text_file(
1095 &self,
1096 request: acp::WriteTextFileToolArguments,
1097 cx: &mut Context<Self>,
1098 ) -> Task<Result<()>> {
1099 let project = self.project.clone();
1100 let action_log = self.action_log.clone();
1101 cx.spawn(async move |this, cx| {
1102 let load = project.update(cx, |project, cx| {
1103 let path = project
1104 .project_path_for_absolute_path(&request.path, cx)
1105 .context("invalid path")?;
1106 anyhow::Ok(project.open_buffer(path, cx))
1107 });
1108 let buffer = load??.await?;
1109 let snapshot = this.update(cx, |this, cx| {
1110 this.shared_buffers
1111 .get(&buffer)
1112 .cloned()
1113 .unwrap_or_else(|| buffer.read(cx).snapshot())
1114 })?;
1115 let edits = cx
1116 .background_executor()
1117 .spawn(async move {
1118 let old_text = snapshot.text();
1119 text_diff(old_text.as_str(), &request.content)
1120 .into_iter()
1121 .map(|(range, replacement)| {
1122 (
1123 snapshot.anchor_after(range.start)
1124 ..snapshot.anchor_before(range.end),
1125 replacement,
1126 )
1127 })
1128 .collect::<Vec<_>>()
1129 })
1130 .await;
1131 cx.update(|cx| {
1132 project.update(cx, |project, cx| {
1133 project.set_agent_location(
1134 Some(AgentLocation {
1135 buffer: buffer.downgrade(),
1136 position: edits
1137 .last()
1138 .map(|(range, _)| range.end)
1139 .unwrap_or(Anchor::MIN),
1140 }),
1141 cx,
1142 );
1143 });
1144
1145 action_log.update(cx, |action_log, cx| {
1146 action_log.buffer_read(buffer.clone(), cx);
1147 });
1148 buffer.update(cx, |buffer, cx| {
1149 buffer.edit(edits, None, cx);
1150 });
1151 action_log.update(cx, |action_log, cx| {
1152 action_log.buffer_edited(buffer.clone(), cx);
1153 });
1154 })?;
1155 project
1156 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1157 .await
1158 })
1159 }
1160
1161 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1162 self.child_status.take()
1163 }
1164
1165 pub fn to_markdown(&self, cx: &App) -> String {
1166 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1167 }
1168}
1169
/// Adapter that services requests from an agent speaking the legacy
/// `agentic_coding_protocol` client API by forwarding them to an [`AcpThread`].
///
/// Holds only a weak handle to the thread; each call fails once the thread has been dropped.
#[derive(Clone)]
pub struct OldAcpClientDelegate {
    thread: WeakEntity<AcpThread>,
    cx: AsyncApp,
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}

impl OldAcpClientDelegate {
    pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
        Self { thread, cx }
    }

    /// Removes completed entries from the thread's plan.
    ///
    /// # Errors
    /// Fails if the app context is no longer available or the thread has been dropped
    /// (presumably — both come from the `update` calls below).
    pub async fn clear_completed_plan_entries(&self) -> Result<()> {
        let cx = &mut self.cx.clone();
        cx.update(|cx| {
            self.thread
                .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
        })?
        .context("Failed to update thread")?;

        Ok(())
    }

    /// Asks the user to confirm an already-pushed tool call and waits for the outcome.
    ///
    /// NOTE(review): `AcpThread::request_tool_call_confirmation` and `ToolCallRequest` are not
    /// defined in the visible code; this appears to be mid-migration to the new protocol.
    pub async fn request_existing_tool_call_confirmation(
        &self,
        tool_call_id: acp_old::ToolCallId,
        confirmation: acp_old::ToolCallConfirmation,
    ) -> Result<acp_old::ToolCallConfirmationOutcome> {
        let cx = &mut self.cx.clone();
        let ToolCallRequest { outcome, .. } = cx
            .update(|cx| {
                self.thread.update(cx, |thread, cx| {
                    thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
                })
            })?
            .context("Failed to update thread")??;

        Ok(outcome.await?)
    }

    /// Reads a file, preferring the snapshot previously shared with the agent.
    ///
    /// NOTE(review): `AcpThread::read_text_file` takes `acp::ReadTextFileArguments`, but this
    /// passes `acp_old::ReadTextFileParams`; a conversion looks to be missing — confirm.
    pub async fn read_text_file_reusing_snapshot(
        &self,
        request: acp_old::ReadTextFileParams,
    ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
        let content = self
            .cx
            .update(|cx| {
                self.thread
                    .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
            })?
            .context("Failed to update thread")?
            .await?;
        Ok(acp_old::ReadTextFileResponse { content })
    }
}
1225
1226impl acp_old::Client for OldAcpClientDelegate {
1227 async fn stream_assistant_message_chunk(
1228 &self,
1229 params: acp_old::StreamAssistantMessageChunkParams,
1230 ) -> Result<(), acp_old::Error> {
1231 let cx = &mut self.cx.clone();
1232
1233 cx.update(|cx| {
1234 self.thread
1235 .update(cx, |thread, cx| {
1236 thread.push_assistant_chunk(params.chunk, cx)
1237 })
1238 .ok();
1239 })?;
1240
1241 Ok(())
1242 }
1243
1244 async fn request_tool_call_confirmation(
1245 &self,
1246 request: acp_old::RequestToolCallConfirmationParams,
1247 ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
1248 let cx = &mut self.cx.clone();
1249 let ToolCallRequest { id, outcome } = cx
1250 .update(|cx| {
1251 self.thread
1252 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1253 })?
1254 .context("Failed to update thread")?;
1255
1256 Ok(acp_old::RequestToolCallConfirmationResponse {
1257 id,
1258 outcome: outcome.await.map_err(acp_old::Error::into_internal_error)?,
1259 })
1260 }
1261
1262 async fn push_tool_call(
1263 &self,
1264 request: acp_old::PushToolCallParams,
1265 ) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
1266 let cx = &mut self.cx.clone();
1267 let id = cx
1268 .update(|cx| {
1269 self.thread
1270 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1271 })?
1272 .context("Failed to update thread")?;
1273
1274 Ok(acp_old::PushToolCallResponse { id })
1275 }
1276
1277 async fn update_tool_call(
1278 &self,
1279 request: acp_old::UpdateToolCallParams,
1280 ) -> Result<(), acp_old::Error> {
1281 let cx = &mut self.cx.clone();
1282
1283 cx.update(|cx| {
1284 self.thread.update(cx, |thread, cx| {
1285 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1286 })
1287 })?
1288 .context("Failed to update thread")??;
1289
1290 Ok(())
1291 }
1292
1293 async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
1294 let cx = &mut self.cx.clone();
1295
1296 cx.update(|cx| {
1297 self.thread
1298 .update(cx, |thread, cx| thread.update_plan(request, cx))
1299 })?
1300 .context("Failed to update thread")?;
1301
1302 Ok(())
1303 }
1304
1305 async fn read_text_file(
1306 &self,
1307 request: acp_old::ReadTextFileParams,
1308 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1309 let content = self
1310 .cx
1311 .update(|cx| {
1312 self.thread
1313 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1314 })?
1315 .context("Failed to update thread")?
1316 .await?;
1317 Ok(acp_old::ReadTextFileResponse { content })
1318 }
1319
1320 async fn write_text_file(
1321 &self,
1322 request: acp_old::WriteTextFileParams,
1323 ) -> Result<(), acp_old::Error> {
1324 self.cx
1325 .update(|cx| {
1326 self.thread.update(cx, |thread, cx| {
1327 thread.write_text_file(request.path, request.content, cx)
1328 })
1329 })?
1330 .context("Failed to update thread")?
1331 .await?;
1332
1333 Ok(())
1334 }
1335}
1336
1337fn acp_icon_to_ui_icon(icon: acp_old::Icon) -> IconName {
1338 match icon {
1339 acp_old::Icon::FileSearch => IconName::ToolSearch,
1340 acp_old::Icon::Folder => IconName::ToolFolder,
1341 acp_old::Icon::Globe => IconName::ToolWeb,
1342 acp_old::Icon::Hammer => IconName::ToolHammer,
1343 acp_old::Icon::LightBulb => IconName::ToolBulb,
1344 acp_old::Icon::Pencil => IconName::ToolPencil,
1345 acp_old::Icon::Regex => IconName::ToolRegex,
1346 acp_old::Icon::Terminal => IconName::ToolTerminal,
1347 }
1348}
1349
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Installs the globals every test here needs: a test `SettingsStore`, project settings
    /// and language support. Logger initialization is best-effort because several tests may
    /// call it.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Two consecutive `Thought` chunks from the agent end up in a single thinking block.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// A user edit that lands between the agent's read and its write is kept: the agent's
    /// write is applied on top of it, both in the open buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the fake agent right after its read, so the test can edit the buffer
        // before the agent writes.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call shows as canceled after the turn is canceled, and as completed once the
    /// agent later reports it finished.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // The fake agent's turn stays open until the test drops `end_turn_tx`.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp_old::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.borrow_mut() = Some(tool_call_result.id);
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp_old::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp_old::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` emits an event while it contains a tool-call entry, and returns
    /// the index of the first such entry.
    ///
    /// # Panics
    /// Panics if no such event arrives within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let tool_call_ix = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = tool_call_ix {
                    // Every event after the first tool call reports it again. The channel is
                    // bounded, so once it is full the extra reports are dropped; unwrapping
                    // here would panic.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Builds an `AcpThread` connected to a `FakeAcpServer` through two in-memory pipes.
    ///
    /// The thread writes to the server's stdin and reads from its stdout. Returns the thread
    /// and the server, which tests script via `FakeAcpServer::on_user_message`.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
                OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
        });
        let server = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, server)
    }

    /// In-process stand-in for an ACP agent, speaking the legacy protocol over pipes.
    pub struct FakeAcpServer {
        connection: acp_old::ClientConnection,

        _io_task: Task<()>,
        /// Handler invoked for each `send_user_message` request; when unset, the request
        /// fails.
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp_old::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
            >,
        >,
    }

    /// Agent-side protocol handler that delegates user messages to the owning
    /// `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp_old::Agent for FakeAgent {
        /// Echoes the requested protocol version and reports the client as authenticated.
        async fn initialize(
            &self,
            params: acp_old::InitializeParams,
        ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
            Ok(acp_old::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        /// Runs the server's `on_user_message` handler, or fails when none is registered or
        /// the server entity can't be reached.
        async fn send_user_message(
            &self,
            request: acp_old::SendUserMessageParams,
        ) -> Result<(), acp_old::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                Err(anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        /// Connects a `FakeAgent` to the client end of the pipes and spawns the connection's
        /// I/O loop in the background.
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Sets the handler that scripts the agent's reply to each user message, replacing
        /// any previous handler.
        fn on_user_message<F>(
            &mut self,
            handler: impl for<'a> Fn(
                acp_old::SendUserMessageParams,
                Entity<FakeAcpServer>,
                AsyncApp,
            ) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp_old::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends `message` to Zed as a client request and resolves to its response, with
        /// protocol errors converted into `anyhow` errors.
        fn send_to_zed<T: acp_old::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}