1mod connection;
2pub use connection::*;
3
4use agent_client_protocol::{self as acp};
5use agentic_coding_protocol as acp_old;
6use anyhow::{Context as _, Result};
7use assistant_tool::ActionLog;
8use buffer_diff::BufferDiff;
9use editor::{Bias, MultiBuffer, PathKey};
10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
11use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
12use itertools::{Itertools, PeekingNext};
13use language::{
14 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
15 text_diff,
16};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::collections::HashMap;
20use std::error::Error;
21use std::fmt::Formatter;
22use std::{
23 fmt::Display,
24 mem,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use ui::{App, IconName};
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub content: ContentBlock,
34}
35
36impl UserMessage {
37 pub fn from_acp(
38 message: impl IntoIterator<Item = acp::ContentBlock>,
39 language_registry: Arc<LanguageRegistry>,
40 cx: &mut App,
41 ) -> Self {
42 let mut content = ContentBlock::Empty;
43 for chunk in message {
44 content.append(chunk, &language_registry, cx)
45 }
46 Self { content: content }
47 }
48
49 fn to_markdown(&self, cx: &App) -> String {
50 format!("## User\n{}\n", self.content.to_markdown(cx))
51 }
52}
53
/// A file mention embedded in message text as a markdown link of the form
/// `[@<file name>](@file:<path>)`.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Marker that prefixes the link target of a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a mention link target (`@file:<path>`).
    ///
    /// Returns `None` when `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        let Self(inner) = self;
        *inner
    }
}

impl Display for MentionPath<'_> {
    /// Writes `[@<file name>](@file:<full path>)`; the file name is empty when
    /// the path has no final component (e.g. `/`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
85
86#[derive(Debug, PartialEq)]
87pub struct AssistantMessage {
88 pub chunks: Vec<AssistantMessageChunk>,
89}
90
91impl AssistantMessage {
92 pub fn to_markdown(&self, cx: &App) -> String {
93 format!(
94 "## Assistant\n\n{}\n\n",
95 self.chunks
96 .iter()
97 .map(|chunk| chunk.to_markdown(cx))
98 .join("\n\n")
99 )
100 }
101}
102
103#[derive(Debug, PartialEq)]
104pub enum AssistantMessageChunk {
105 Message { block: ContentBlock },
106 Thought { block: ContentBlock },
107}
108
109impl AssistantMessageChunk {
110 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 Self::Message {
112 block: ContentBlock::new(
113 acp::ContentBlock::Text(acp::TextContent {
114 text: chunk.to_owned().into(),
115 annotations: None,
116 }),
117 language_registry,
118 cx,
119 ),
120 }
121 }
122
123 fn to_markdown(&self, cx: &App) -> String {
124 match self {
125 Self::Message { block } => block.to_markdown(cx).to_string(),
126 Self::Thought { block } => {
127 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
128 }
129 }
130 }
131}
132
133#[derive(Debug)]
134pub enum AgentThreadEntry {
135 UserMessage(UserMessage),
136 AssistantMessage(AssistantMessage),
137 ToolCall(ToolCall),
138}
139
140impl AgentThreadEntry {
141 fn to_markdown(&self, cx: &App) -> String {
142 match self {
143 Self::UserMessage(message) => message.to_markdown(cx),
144 Self::AssistantMessage(message) => message.to_markdown(cx),
145 Self::ToolCall(too_call) => too_call.to_markdown(cx),
146 }
147 }
148
149 pub fn diff(&self) -> Option<&Diff> {
150 if let AgentThreadEntry::ToolCall(ToolCall {
151 content: Some(ToolCallContent::Diff { diff }),
152 ..
153 }) = self
154 {
155 Some(&diff)
156 } else {
157 None
158 }
159 }
160
161 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
162 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
163 Some(locations)
164 } else {
165 None
166 }
167 }
168}
169
170#[derive(Debug)]
171pub struct ToolCall {
172 pub id: acp::ToolCallId,
173 pub label: Entity<Markdown>,
174 pub kind: acp::ToolKind,
175 // todo! Should this be a vec?
176 pub content: Option<ToolCallContent>,
177 pub status: ToolCallStatus,
178 pub locations: Vec<acp::ToolCallLocation>,
179}
180
181impl ToolCall {
182 fn from_acp(
183 tool_call: acp::ToolCall,
184 status: ToolCallStatus,
185 language_registry: Arc<LanguageRegistry>,
186 cx: &mut App,
187 ) -> Self {
188 Self {
189 id: tool_call.id,
190 label: cx.new(|cx| {
191 Markdown::new(
192 tool_call.label.into(),
193 Some(language_registry.clone()),
194 None,
195 cx,
196 )
197 }),
198 kind: tool_call.kind,
199 // todo! Do we assume there is either a coalesced content OR diff?
200 content: ToolCallContent::from_acp_contents(tool_call.content, language_registry, cx)
201 .into_iter()
202 .next(),
203 locations: tool_call.locations,
204 status,
205 }
206 }
207 fn to_markdown(&self, cx: &App) -> String {
208 let mut markdown = format!(
209 "**Tool Call: {}**\nStatus: {}\n\n",
210 self.label.read(cx).source(),
211 self.status
212 );
213 if let Some(content) = &self.content {
214 markdown.push_str(content.to_markdown(cx).as_str());
215 markdown.push_str("\n\n");
216 }
217 markdown
218 }
219}
220
221#[derive(Debug)]
222pub enum ToolCallStatus {
223 WaitingForConfirmation {
224 possible_grants: Vec<acp::Grant>,
225 respond_tx: oneshot::Sender<acp::GrantId>,
226 },
227 Allowed {
228 status: acp::ToolCallStatus,
229 },
230 Rejected,
231 Canceled,
232}
233
234impl Display for ToolCallStatus {
235 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236 write!(
237 f,
238 "{}",
239 match self {
240 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
241 ToolCallStatus::Allowed { status } => match status {
242 acp::ToolCallStatus::InProgress => "In Progress",
243 acp::ToolCallStatus::Completed => "Completed",
244 acp::ToolCallStatus::Failed => "Failed",
245 },
246 ToolCallStatus::Rejected => "Rejected",
247 ToolCallStatus::Canceled => "Canceled",
248 }
249 )
250 }
251}
252
253#[derive(Debug)]
254pub enum ToolCallConfirmation {
255 Edit {
256 description: Option<Entity<Markdown>>,
257 },
258 Execute {
259 command: String,
260 root_command: String,
261 description: Option<Entity<Markdown>>,
262 },
263 Mcp {
264 server_name: String,
265 tool_name: String,
266 tool_display_name: String,
267 description: Option<Entity<Markdown>>,
268 },
269 Fetch {
270 urls: Vec<SharedString>,
271 description: Option<Entity<Markdown>>,
272 },
273 Other {
274 description: Entity<Markdown>,
275 },
276}
277
278impl ToolCallConfirmation {
279 pub fn from_acp(
280 confirmation: acp_old::ToolCallConfirmation,
281 language_registry: Arc<LanguageRegistry>,
282 cx: &mut App,
283 ) -> Self {
284 let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
285 cx.new(|cx| {
286 Markdown::new(
287 description.into(),
288 Some(language_registry.clone()),
289 None,
290 cx,
291 )
292 })
293 };
294
295 match confirmation {
296 acp_old::ToolCallConfirmation::Edit { description } => Self::Edit {
297 description: description.map(|description| to_md(description, cx)),
298 },
299 acp_old::ToolCallConfirmation::Execute {
300 command,
301 root_command,
302 description,
303 } => Self::Execute {
304 command,
305 root_command,
306 description: description.map(|description| to_md(description, cx)),
307 },
308 acp_old::ToolCallConfirmation::Mcp {
309 server_name,
310 tool_name,
311 tool_display_name,
312 description,
313 } => Self::Mcp {
314 server_name,
315 tool_name,
316 tool_display_name,
317 description: description.map(|description| to_md(description, cx)),
318 },
319 acp_old::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
320 urls: urls.iter().map(|url| url.into()).collect(),
321 description: description.map(|description| to_md(description, cx)),
322 },
323 acp_old::ToolCallConfirmation::Other { description } => Self::Other {
324 description: to_md(description, cx),
325 },
326 }
327 }
328}
329
330#[derive(Debug, PartialEq, Clone)]
331enum ContentBlock {
332 Empty,
333 Markdown { markdown: Entity<Markdown> },
334}
335
336impl ContentBlock {
337 pub fn new(
338 block: acp::ContentBlock,
339 language_registry: &Arc<LanguageRegistry>,
340 cx: &mut App,
341 ) -> Self {
342 let mut this = Self::Empty;
343 this.append(block, language_registry, cx);
344 this
345 }
346
347 pub fn new_combined(
348 blocks: impl IntoIterator<Item = acp::ContentBlock>,
349 language_registry: &Arc<LanguageRegistry>,
350 cx: &mut App,
351 ) -> Self {
352 let mut this = Self::Empty;
353 for block in blocks {
354 this.append(block, language_registry, cx);
355 }
356 this
357 }
358
359 pub fn append(
360 &mut self,
361 block: acp::ContentBlock,
362 language_registry: &Arc<LanguageRegistry>,
363 cx: &mut App,
364 ) {
365 let new_content = match block {
366 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
367 acp::ContentBlock::ResourceLink(resource_link) => {
368 if let Some(path) = resource_link.uri.strip_prefix("file://") {
369 format!("{}", MentionPath(path.as_ref()))
370 } else {
371 resource_link.uri.clone()
372 }
373 }
374 acp::ContentBlock::Image(_)
375 | acp::ContentBlock::Audio(_)
376 | acp::ContentBlock::Resource(_) => String::new(),
377 };
378
379 match self {
380 ContentBlock::Empty => {
381 *self = ContentBlock::Markdown {
382 markdown: cx.new(|cx| {
383 Markdown::new(
384 new_content.into(),
385 Some(language_registry.clone()),
386 None,
387 cx,
388 )
389 }),
390 };
391 }
392 ContentBlock::Markdown { markdown } => {
393 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
394 }
395 }
396 }
397
398 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
399 match self {
400 ContentBlock::Empty => "",
401 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
402 }
403 }
404}
405
406#[derive(Debug)]
407pub enum ToolCallContent {
408 ContentBlock { content: ContentBlock },
409 Diff { diff: Diff },
410}
411
412impl ToolCallContent {
413 pub fn from_acp(
414 content: acp::ToolCallContent,
415 language_registry: Arc<LanguageRegistry>,
416 cx: &mut App,
417 ) -> Self {
418 match content {
419 acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock {
420 content: ContentBlock::new(content, &language_registry, cx),
421 },
422 acp::ToolCallContent::Diff { diff } => Self::Diff {
423 diff: Diff::from_acp(diff, language_registry, cx),
424 },
425 }
426 }
427
428 pub fn from_acp_contents(
429 content: Vec<acp::ToolCallContent>,
430 language_registry: Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) -> Vec<Self> {
433 content
434 .into_iter()
435 .peekable()
436 .batching(|it| match it.next()? {
437 acp::ToolCallContent::ContentBlock { content } => {
438 let mut block = ContentBlock::new(content, &language_registry, cx);
439 while let Some(acp::ToolCallContent::ContentBlock { content }) =
440 it.peeking_next(|c| matches!(c, acp::ToolCallContent::ContentBlock { .. }))
441 {
442 block.append(content, &language_registry, cx);
443 }
444 Some(ToolCallContent::ContentBlock { content: block })
445 }
446 content @ acp::ToolCallContent::Diff { .. } => Some(ToolCallContent::from_acp(
447 content,
448 language_registry.clone(),
449 cx,
450 )),
451 })
452 .collect()
453 }
454
455 fn to_markdown(&self, cx: &App) -> String {
456 match self {
457 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
458 Self::Diff { diff } => diff.to_markdown(cx),
459 }
460 }
461}
462
/// A file edit proposed by a tool call.
///
/// It is displayed as a read-only multibuffer of excerpts from the new text
/// around each changed hunk, with a `BufferDiff` against the old text attached.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer; excerpts of `new_buffer` are added by `_task`.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path reported by the agent; also used to pick the buffer's language.
    pub path: PathBuf,
    /// Local buffer holding the proposed new contents.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer holding the original contents (empty when no old text was sent).
    pub old_buffer: Entity<Buffer>,
    /// Background work that fills `multibuffer`; held so it is owned by this `Diff`.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff view from an ACP diff.
    ///
    /// Synchronously this creates the buffers and starts diffing `new_text`
    /// against `old_text`, which is treated as `""` when absent. A spawned task
    /// then waits for the diff and adds excerpts around every hunk of the new
    /// buffer to the multibuffer. It attaches the `BufferDiff`, and finally
    /// assigns the language detected for `path` to the new buffer.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // The old text becomes the diff's base; the new buffer is the side being compared.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Failures updating the multibuffer are logged, not propagated,
                // so language detection below still runs.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Collect the new-buffer ranges of every hunk so only changed
                        // regions (plus context lines) are shown.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language lookup failures are logged and leave the buffer as-is.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders `Diff: <path>` followed by a fenced block with the text of every
    /// buffer in the multibuffer, joined with newlines.
    ///
    /// NOTE(review): this emits buffer text, not a unified diff. The only buffer
    /// added in this file is `new_buffer`, and it is added by `_task`, so the
    /// fence is presumably empty until that task has run. Confirm whether that
    /// is intended.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
560
561#[derive(Debug, Default)]
562pub struct Plan {
563 pub entries: Vec<PlanEntry>,
564}
565
566#[derive(Debug)]
567pub struct PlanStats<'a> {
568 pub in_progress_entry: Option<&'a PlanEntry>,
569 pub pending: u32,
570 pub completed: u32,
571}
572
573impl Plan {
574 pub fn is_empty(&self) -> bool {
575 self.entries.is_empty()
576 }
577
578 pub fn stats(&self) -> PlanStats<'_> {
579 let mut stats = PlanStats {
580 in_progress_entry: None,
581 pending: 0,
582 completed: 0,
583 };
584
585 for entry in &self.entries {
586 match &entry.status {
587 acp_old::PlanEntryStatus::Pending => {
588 stats.pending += 1;
589 }
590 acp_old::PlanEntryStatus::InProgress => {
591 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
592 }
593 acp_old::PlanEntryStatus::Completed => {
594 stats.completed += 1;
595 }
596 }
597 }
598
599 stats
600 }
601}
602
603#[derive(Debug)]
604pub struct PlanEntry {
605 pub content: Entity<Markdown>,
606 pub priority: acp_old::PlanEntryPriority,
607 pub status: acp_old::PlanEntryStatus,
608}
609
610impl PlanEntry {
611 pub fn from_acp(entry: acp_old::PlanEntry, cx: &mut App) -> Self {
612 Self {
613 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
614 priority: entry.priority,
615 status: entry.status,
616 }
617 }
618}
619
/// A conversation with an agent.
///
/// It holds the transcript, the agent's plan, and the state needed to service
/// the agent's file reads and writes inside `project`.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries, oldest first (new entries are pushed to the end).
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    /// Records buffers the agent has read and edited.
    action_log: Entity<ActionLog>,
    /// Last snapshot of each buffer returned to the agent by `read_text_file`;
    /// `write_text_file` diffs new content against it.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight `send` work; `Some` means the thread is busy.
    send_task: Option<Task<()>>,
    connection: Arc<dyn AgentConnection>,
    /// Handed out once via `child_status()`; presumably tracks the agent
    /// process — confirm at the construction site.
    child_status: Option<Task<Result<()>>>,
}
631
/// Notifications emitted by [`AcpThread`] when its transcript changes.
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
638
/// Coarse state of a thread, as reported by `AcpThread::status`.
///
/// This fieldless public enum is returned by value, so it derives `Debug`
/// (for logging and assertions) and `Clone`/`Copy` in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and the current turn has a tool call awaiting the
    /// user's decision.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is waiting on the user.
    Generating,
}
645
646#[derive(Debug, Clone)]
647pub enum LoadError {
648 Unsupported {
649 error_message: SharedString,
650 upgrade_message: SharedString,
651 upgrade_command: String,
652 },
653 Exited(i32),
654 Other(SharedString),
655}
656
657impl Display for LoadError {
658 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
659 match self {
660 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
661 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
662 LoadError::Other(msg) => write!(f, "{}", msg),
663 }
664 }
665}
666
667impl Error for LoadError {}
668
669impl AcpThread {
670 pub fn new(
671 connection: impl AgentConnection + 'static,
672 title: SharedString,
673 child_status: Option<Task<Result<()>>>,
674 project: Entity<Project>,
675 cx: &mut Context<Self>,
676 ) -> Self {
677 let action_log = cx.new(|_| ActionLog::new(project.clone()));
678
679 Self {
680 action_log,
681 shared_buffers: Default::default(),
682 entries: Default::default(),
683 plan: Default::default(),
684 title,
685 project,
686 send_task: None,
687 connection: Arc::new(connection),
688 child_status,
689 }
690 }
691
692 /// Send a request to the agent and wait for a response.
693 pub fn request<R: acp_old::AgentRequest + 'static>(
694 &self,
695 params: R,
696 ) -> impl use<R> + Future<Output = Result<R::Response>> {
697 let params = params.into_any();
698 let result = self.connection.request_any(params);
699 async move {
700 let result = result.await?;
701 Ok(R::response_from_any(result)?)
702 }
703 }
704
705 pub fn action_log(&self) -> &Entity<ActionLog> {
706 &self.action_log
707 }
708
709 pub fn project(&self) -> &Entity<Project> {
710 &self.project
711 }
712
713 pub fn title(&self) -> SharedString {
714 self.title.clone()
715 }
716
717 pub fn entries(&self) -> &[AgentThreadEntry] {
718 &self.entries
719 }
720
721 pub fn status(&self) -> ThreadStatus {
722 if self.send_task.is_some() {
723 if self.waiting_for_tool_confirmation() {
724 ThreadStatus::WaitingForToolConfirmation
725 } else {
726 ThreadStatus::Generating
727 }
728 } else {
729 ThreadStatus::Idle
730 }
731 }
732
733 pub fn has_pending_edit_tool_calls(&self) -> bool {
734 for entry in self.entries.iter().rev() {
735 match entry {
736 AgentThreadEntry::UserMessage(_) => return false,
737 AgentThreadEntry::ToolCall(ToolCall {
738 status:
739 ToolCallStatus::Allowed {
740 status: acp::ToolCallStatus::InProgress,
741 ..
742 },
743 content: Some(ToolCallContent::Diff { .. }),
744 ..
745 }) => return true,
746 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
747 }
748 }
749
750 false
751 }
752
753 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
754 self.entries.push(entry);
755 cx.emit(AcpThreadEvent::NewEntry);
756 }
757
758 pub fn push_assistant_chunk(
759 &mut self,
760 chunk: acp::ContentBlock,
761 is_thought: bool,
762 cx: &mut Context<Self>,
763 ) {
764 let language_registry = self.project.read(cx).languages().clone();
765 let entries_len = self.entries.len();
766 if let Some(last_entry) = self.entries.last_mut()
767 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
768 {
769 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
770 match (chunks.last_mut(), is_thought) {
771 (Some(AssistantMessageChunk::Message { block }), false)
772 | (Some(AssistantMessageChunk::Thought { block }), true) => {
773 block.append(chunk, &language_registry, cx)
774 }
775 _ => {
776 let block = ContentBlock::new(chunk, &language_registry, cx);
777 if is_thought {
778 chunks.push(AssistantMessageChunk::Thought { block })
779 } else {
780 chunks.push(AssistantMessageChunk::Message { block })
781 }
782 }
783 }
784 } else {
785 let block = ContentBlock::new(chunk, &language_registry, cx);
786 let chunk = if is_thought {
787 AssistantMessageChunk::Thought { block }
788 } else {
789 AssistantMessageChunk::Message { block }
790 };
791
792 self.push_entry(
793 AgentThreadEntry::AssistantMessage(AssistantMessage {
794 chunks: vec![chunk],
795 }),
796 cx,
797 );
798 }
799 }
800
801 pub fn update_tool_call(
802 &mut self,
803 tool_call: acp::ToolCall,
804 cx: &mut Context<Self>,
805 ) -> Result<()> {
806 let language_registry = self.project.read(cx).languages().clone();
807 let status = ToolCallStatus::Allowed {
808 status: tool_call.status,
809 };
810 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
811
812 let location = call.locations.last().cloned();
813
814 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
815 match ¤t_call.status {
816 ToolCallStatus::WaitingForConfirmation { .. } => {
817 anyhow::bail!("Tool call hasn't been authorized yet")
818 }
819 ToolCallStatus::Rejected => {
820 anyhow::bail!("Tool call was rejected and therefore can't be updated")
821 }
822 ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {}
823 }
824
825 *current_call = call;
826
827 cx.emit(AcpThreadEvent::EntryUpdated(ix));
828 } else {
829 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
830 }
831
832 if let Some(location) = location {
833 self.set_project_location(location, cx)
834 }
835
836 Ok(())
837 }
838
839 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
840 self.entries
841 .iter_mut()
842 .enumerate()
843 .rev()
844 .find_map(|(index, tool_call)| {
845 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
846 && &tool_call.id == id
847 {
848 Some((index, tool_call))
849 } else {
850 None
851 }
852 })
853 }
854
855 pub fn request_tool_call_permission(
856 &mut self,
857 tool_call: acp::ToolCall,
858 possible_grants: Vec<acp::Grant>,
859 cx: &mut Context<Self>,
860 ) -> oneshot::Receiver<acp::GrantId> {
861 let (tx, rx) = oneshot::channel();
862
863 let status = ToolCallStatus::WaitingForConfirmation {
864 possible_grants,
865 respond_tx: tx,
866 };
867
868 self.insert_tool_call(tool_call, status, cx);
869 rx
870 }
871
872 fn insert_tool_call(
873 &mut self,
874 tool_call: acp::ToolCall,
875 status: ToolCallStatus,
876 cx: &mut Context<Self>,
877 ) {
878 let language_registry = self.project.read(cx).languages().clone();
879 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
880
881 let location = call.locations.last().cloned();
882 if let Some(location) = location {
883 self.set_project_location(location, cx)
884 }
885
886 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
887 }
888
889 pub fn authorize_tool_call(
890 &mut self,
891 id: acp::ToolCallId,
892 grant: acp::Grant,
893 cx: &mut Context<Self>,
894 ) {
895 let Some((ix, call)) = self.tool_call_mut(&id) else {
896 return;
897 };
898
899 let new_status = if grant.is_allowed {
900 ToolCallStatus::Allowed {
901 status: acp::ToolCallStatus::InProgress,
902 }
903 } else {
904 ToolCallStatus::Rejected
905 };
906
907 let curr_status = mem::replace(&mut call.status, new_status);
908
909 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
910 respond_tx.send(grant.id).log_err();
911 } else if cfg!(debug_assertions) {
912 panic!("tried to authorize an already authorized tool call");
913 }
914
915 cx.emit(AcpThreadEvent::EntryUpdated(ix));
916 }
917
918 pub fn plan(&self) -> &Plan {
919 &self.plan
920 }
921
922 pub fn update_plan(&mut self, request: acp_old::UpdatePlanParams, cx: &mut Context<Self>) {
923 self.plan = Plan {
924 entries: request
925 .entries
926 .into_iter()
927 .map(|entry| PlanEntry::from_acp(entry, cx))
928 .collect(),
929 };
930
931 cx.notify();
932 }
933
934 pub fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
935 self.plan
936 .entries
937 .retain(|entry| !matches!(entry.status, acp_old::PlanEntryStatus::Completed));
938 cx.notify();
939 }
940
941 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
942 self.project.update(cx, |project, cx| {
943 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
944 return;
945 };
946 let buffer = project.open_buffer(path, cx);
947 cx.spawn(async move |project, cx| {
948 let buffer = buffer.await?;
949
950 project.update(cx, |project, cx| {
951 let position = if let Some(line) = location.line {
952 let snapshot = buffer.read(cx).snapshot();
953 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
954 snapshot.anchor_before(point)
955 } else {
956 Anchor::MIN
957 };
958
959 project.set_agent_location(
960 Some(AgentLocation {
961 buffer: buffer.downgrade(),
962 position,
963 }),
964 cx,
965 );
966 })
967 })
968 .detach_and_log_err(cx);
969 });
970 }
971
972 /// Returns true if the last turn is awaiting tool authorization
973 pub fn waiting_for_tool_confirmation(&self) -> bool {
974 for entry in self.entries.iter().rev() {
975 match &entry {
976 AgentThreadEntry::ToolCall(call) => match call.status {
977 ToolCallStatus::WaitingForConfirmation { .. } => return true,
978 ToolCallStatus::Allowed { .. }
979 | ToolCallStatus::Rejected
980 | ToolCallStatus::Canceled => continue,
981 },
982 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
983 // Reached the beginning of the turn
984 return false;
985 }
986 }
987 }
988 false
989 }
990
991 pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp_old::InitializeResponse>> {
992 self.request(acp_old::InitializeParams {
993 protocol_version: acp_old::ProtocolVersion::latest(),
994 })
995 }
996
997 pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
998 self.request(acp_old::AuthenticateParams)
999 }
1000
1001 #[cfg(any(test, feature = "test-support"))]
1002 pub fn send_raw(
1003 &mut self,
1004 message: &str,
1005 cx: &mut Context<Self>,
1006 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
1007 self.send(
1008 vec![acp::ContentBlock::Text(acp::TextContent {
1009 text: message.to_string(),
1010 annotations: None,
1011 })],
1012 cx,
1013 )
1014 }
1015
1016 pub fn send(
1017 &mut self,
1018 message: Vec<acp::ContentBlock>,
1019 cx: &mut Context<Self>,
1020 ) -> BoxFuture<'static, Result<(), acp_old::Error>> {
1021 let block =
1022 ContentBlock::new_combined(message.clone(), self.project.read(cx).languages(), cx);
1023 self.push_entry(
1024 AgentThreadEntry::UserMessage(UserMessage { content: block }),
1025 cx,
1026 );
1027
1028 let (tx, rx) = oneshot::channel();
1029 let cancel = self.cancel(cx);
1030
1031 self.send_task = Some(cx.spawn(async move |this, cx| {
1032 async {
1033 cancel.await.log_err();
1034
1035 let result = this.update(cx, |this, _| this.request(message))?.await;
1036 tx.send(result).log_err();
1037 this.update(cx, |this, _cx| this.send_task.take())?;
1038 anyhow::Ok(())
1039 }
1040 .await
1041 .log_err();
1042 }));
1043
1044 async move {
1045 match rx.await {
1046 Ok(Err(e)) => Err(e)?,
1047 _ => Ok(()),
1048 }
1049 }
1050 .boxed()
1051 }
1052
1053 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp_old::Error>> {
1054 if self.send_task.take().is_some() {
1055 let request = self.request(acp_old::CancelSendMessageParams);
1056 cx.spawn(async move |this, cx| {
1057 request.await?;
1058 this.update(cx, |this, _cx| {
1059 for entry in this.entries.iter_mut() {
1060 if let AgentThreadEntry::ToolCall(call) = entry {
1061 let cancel = matches!(
1062 call.status,
1063 ToolCallStatus::WaitingForConfirmation { .. }
1064 | ToolCallStatus::Allowed {
1065 status: acp::ToolCallStatus::InProgress
1066 }
1067 );
1068
1069 if cancel {
1070 let curr_status =
1071 mem::replace(&mut call.status, ToolCallStatus::Canceled);
1072
1073 if let ToolCallStatus::WaitingForConfirmation {
1074 respond_tx,
1075 possible_grants,
1076 } = curr_status
1077 {
1078 if let Some(grant_id) = possible_grants
1079 .iter()
1080 .find_map(|g| (!g.is_allowed).then(|| g.id.clone()))
1081 {
1082 // todo! do we need a way to cancel rather than reject?
1083 respond_tx.send(grant_id).ok();
1084 }
1085 }
1086 }
1087 }
1088 }
1089 })?;
1090 Ok(())
1091 })
1092 } else {
1093 Task::ready(Ok(()))
1094 }
1095 }
1096
1097 pub fn read_text_file(
1098 &self,
1099 request: acp_old::ReadTextFileParams,
1100 reuse_shared_snapshot: bool,
1101 cx: &mut Context<Self>,
1102 ) -> Task<Result<String>> {
1103 let project = self.project.clone();
1104 let action_log = self.action_log.clone();
1105 cx.spawn(async move |this, cx| {
1106 let load = project.update(cx, |project, cx| {
1107 let path = project
1108 .project_path_for_absolute_path(&request.path, cx)
1109 .context("invalid path")?;
1110 anyhow::Ok(project.open_buffer(path, cx))
1111 });
1112 let buffer = load??.await?;
1113
1114 let snapshot = if reuse_shared_snapshot {
1115 this.read_with(cx, |this, _| {
1116 this.shared_buffers.get(&buffer.clone()).cloned()
1117 })
1118 .log_err()
1119 .flatten()
1120 } else {
1121 None
1122 };
1123
1124 let snapshot = if let Some(snapshot) = snapshot {
1125 snapshot
1126 } else {
1127 action_log.update(cx, |action_log, cx| {
1128 action_log.buffer_read(buffer.clone(), cx);
1129 })?;
1130 project.update(cx, |project, cx| {
1131 let position = buffer
1132 .read(cx)
1133 .snapshot()
1134 .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1135 project.set_agent_location(
1136 Some(AgentLocation {
1137 buffer: buffer.downgrade(),
1138 position,
1139 }),
1140 cx,
1141 );
1142 })?;
1143
1144 buffer.update(cx, |buffer, _| buffer.snapshot())?
1145 };
1146
1147 this.update(cx, |this, _| {
1148 let text = snapshot.text();
1149 this.shared_buffers.insert(buffer.clone(), snapshot);
1150 if request.line.is_none() && request.limit.is_none() {
1151 return Ok(text);
1152 }
1153 let limit = request.limit.unwrap_or(u32::MAX) as usize;
1154 let Some(line) = request.line else {
1155 return Ok(text.lines().take(limit).collect::<String>());
1156 };
1157
1158 let count = text.lines().count();
1159 if count < line as usize {
1160 anyhow::bail!("There are only {} lines", count);
1161 }
1162 Ok(text
1163 .lines()
1164 .skip(line as usize + 1)
1165 .take(limit)
1166 .collect::<String>())
1167 })?
1168 })
1169 }
1170
1171 pub fn write_text_file(
1172 &self,
1173 path: PathBuf,
1174 content: String,
1175 cx: &mut Context<Self>,
1176 ) -> Task<Result<()>> {
1177 let project = self.project.clone();
1178 let action_log = self.action_log.clone();
1179 cx.spawn(async move |this, cx| {
1180 let load = project.update(cx, |project, cx| {
1181 let path = project
1182 .project_path_for_absolute_path(&path, cx)
1183 .context("invalid path")?;
1184 anyhow::Ok(project.open_buffer(path, cx))
1185 });
1186 let buffer = load??.await?;
1187 let snapshot = this.update(cx, |this, cx| {
1188 this.shared_buffers
1189 .get(&buffer)
1190 .cloned()
1191 .unwrap_or_else(|| buffer.read(cx).snapshot())
1192 })?;
1193 let edits = cx
1194 .background_executor()
1195 .spawn(async move {
1196 let old_text = snapshot.text();
1197 text_diff(old_text.as_str(), &content)
1198 .into_iter()
1199 .map(|(range, replacement)| {
1200 (
1201 snapshot.anchor_after(range.start)
1202 ..snapshot.anchor_before(range.end),
1203 replacement,
1204 )
1205 })
1206 .collect::<Vec<_>>()
1207 })
1208 .await;
1209 cx.update(|cx| {
1210 project.update(cx, |project, cx| {
1211 project.set_agent_location(
1212 Some(AgentLocation {
1213 buffer: buffer.downgrade(),
1214 position: edits
1215 .last()
1216 .map(|(range, _)| range.end)
1217 .unwrap_or(Anchor::MIN),
1218 }),
1219 cx,
1220 );
1221 });
1222
1223 action_log.update(cx, |action_log, cx| {
1224 action_log.buffer_read(buffer.clone(), cx);
1225 });
1226 buffer.update(cx, |buffer, cx| {
1227 buffer.edit(edits, None, cx);
1228 });
1229 action_log.update(cx, |action_log, cx| {
1230 action_log.buffer_edited(buffer.clone(), cx);
1231 });
1232 })?;
1233 project
1234 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1235 .await
1236 })
1237 }
1238
1239 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1240 self.child_status.take()
1241 }
1242
1243 pub fn to_markdown(&self, cx: &App) -> String {
1244 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1245 }
1246}
1247
1248#[derive(Clone)]
1249pub struct OldAcpClientDelegate {
1250 thread: WeakEntity<AcpThread>,
1251 cx: AsyncApp,
1252 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1253}
1254
1255impl OldAcpClientDelegate {
1256 pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1257 Self { thread, cx }
1258 }
1259
1260 pub async fn clear_completed_plan_entries(&self) -> Result<()> {
1261 let cx = &mut self.cx.clone();
1262 cx.update(|cx| {
1263 self.thread
1264 .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
1265 })?
1266 .context("Failed to update thread")?;
1267
1268 Ok(())
1269 }
1270
1271 pub async fn request_existing_tool_call_confirmation(
1272 &self,
1273 tool_call_id: acp_old::ToolCallId,
1274 confirmation: acp_old::ToolCallConfirmation,
1275 ) -> Result<acp_old::ToolCallConfirmationOutcome> {
1276 let cx = &mut self.cx.clone();
1277 let ToolCallRequest { outcome, .. } = cx
1278 .update(|cx| {
1279 self.thread.update(cx, |thread, cx| {
1280 thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1281 })
1282 })?
1283 .context("Failed to update thread")??;
1284
1285 Ok(outcome.await?)
1286 }
1287
1288 pub async fn read_text_file_reusing_snapshot(
1289 &self,
1290 request: acp_old::ReadTextFileParams,
1291 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1292 let content = self
1293 .cx
1294 .update(|cx| {
1295 self.thread
1296 .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1297 })?
1298 .context("Failed to update thread")?
1299 .await?;
1300 Ok(acp_old::ReadTextFileResponse { content })
1301 }
1302}
1303
1304impl acp_old::Client for OldAcpClientDelegate {
1305 async fn stream_assistant_message_chunk(
1306 &self,
1307 params: acp_old::StreamAssistantMessageChunkParams,
1308 ) -> Result<(), acp_old::Error> {
1309 let cx = &mut self.cx.clone();
1310
1311 cx.update(|cx| {
1312 self.thread
1313 .update(cx, |thread, cx| {
1314 thread.push_assistant_chunk(params.chunk, cx)
1315 })
1316 .ok();
1317 })?;
1318
1319 Ok(())
1320 }
1321
1322 async fn request_tool_call_confirmation(
1323 &self,
1324 request: acp_old::RequestToolCallConfirmationParams,
1325 ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
1326 let cx = &mut self.cx.clone();
1327 let ToolCallRequest { id, outcome } = cx
1328 .update(|cx| {
1329 self.thread
1330 .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1331 })?
1332 .context("Failed to update thread")?;
1333
1334 Ok(acp_old::RequestToolCallConfirmationResponse {
1335 id,
1336 outcome: outcome.await.map_err(acp_old::Error::into_internal_error)?,
1337 })
1338 }
1339
1340 async fn push_tool_call(
1341 &self,
1342 request: acp_old::PushToolCallParams,
1343 ) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
1344 let cx = &mut self.cx.clone();
1345 let id = cx
1346 .update(|cx| {
1347 self.thread
1348 .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1349 })?
1350 .context("Failed to update thread")?;
1351
1352 Ok(acp_old::PushToolCallResponse { id })
1353 }
1354
1355 async fn update_tool_call(
1356 &self,
1357 request: acp_old::UpdateToolCallParams,
1358 ) -> Result<(), acp_old::Error> {
1359 let cx = &mut self.cx.clone();
1360
1361 cx.update(|cx| {
1362 self.thread.update(cx, |thread, cx| {
1363 thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1364 })
1365 })?
1366 .context("Failed to update thread")??;
1367
1368 Ok(())
1369 }
1370
1371 async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
1372 let cx = &mut self.cx.clone();
1373
1374 cx.update(|cx| {
1375 self.thread
1376 .update(cx, |thread, cx| thread.update_plan(request, cx))
1377 })?
1378 .context("Failed to update thread")?;
1379
1380 Ok(())
1381 }
1382
1383 async fn read_text_file(
1384 &self,
1385 request: acp_old::ReadTextFileParams,
1386 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1387 let content = self
1388 .cx
1389 .update(|cx| {
1390 self.thread
1391 .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1392 })?
1393 .context("Failed to update thread")?
1394 .await?;
1395 Ok(acp_old::ReadTextFileResponse { content })
1396 }
1397
1398 async fn write_text_file(
1399 &self,
1400 request: acp_old::WriteTextFileParams,
1401 ) -> Result<(), acp_old::Error> {
1402 self.cx
1403 .update(|cx| {
1404 self.thread.update(cx, |thread, cx| {
1405 thread.write_text_file(request.path, request.content, cx)
1406 })
1407 })?
1408 .context("Failed to update thread")?
1409 .await?;
1410
1411 Ok(())
1412 }
1413}
1414
1415fn acp_icon_to_ui_icon(icon: acp_old::Icon) -> IconName {
1416 match icon {
1417 acp_old::Icon::FileSearch => IconName::ToolSearch,
1418 acp_old::Icon::Folder => IconName::ToolFolder,
1419 acp_old::Icon::Globe => IconName::ToolWeb,
1420 acp_old::Icon::Hammer => IconName::ToolHammer,
1421 acp_old::Icon::LightBulb => IconName::ToolBulb,
1422 acp_old::Icon::Pencil => IconName::ToolPencil,
1423 acp_old::Icon::Regex => IconName::ToolRegex,
1424 acp_old::Icon::Terminal => IconName::ToolTerminal,
1425 }
1426}
1427
1428#[cfg(test)]
1429mod tests {
1430 use super::*;
1431 use anyhow::anyhow;
1432 use async_pipe::{PipeReader, PipeWriter};
1433 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1434 use gpui::{AsyncApp, TestAppContext};
1435 use indoc::indoc;
1436 use project::FakeFs;
1437 use serde_json::json;
1438 use settings::SettingsStore;
1439 use smol::{future::BoxedLocal, stream::StreamExt as _};
1440 use std::{cell::RefCell, rc::Rc, time::Duration};
1441 use util::path;
1442
1443 fn init_test(cx: &mut TestAppContext) {
1444 env_logger::try_init().ok();
1445 cx.update(|cx| {
1446 let settings_store = SettingsStore::test(cx);
1447 cx.set_global(settings_store);
1448 Project::init_settings(cx);
1449 language::init(cx);
1450 });
1451 }
1452
    // Two `Thought` chunks streamed by the agent during one turn must be
    // rendered as a single `<thinking>` section containing both, joined verbatim.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // Script the agent: on the user's message, stream two separate thought chunks.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // The Markdown rendering must contain a single merged thinking section.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1510
    // The agent reads a file, the user then edits the open buffer, and the
    // agent writes extended content. Both the user's edit and the agent's
    // write must end up in the buffer and in the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        // Open the buffer for /tmp/foo so the test can edit it directly.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the agent once its read has completed, so the test can make
        // the user edit after that read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Shared cell so the `Fn` handler can hand out the one-shot sender once.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulate the user typing while the agent's turn is still in flight.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's prefix and the agent's suffix must be present...
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // ...and the saved file must match the buffer.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1584
    // A tool call canceled together with its turn, then reported as finished
    // by the agent, must end up as `Allowed { status: Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // Keeps the agent's turn open until the test drops `end_turn_tx`.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        // Filled in by the handler with the id of the tool call it pushes.
        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp_old::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
                    // Block the turn until the test releases it.
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // The pushed tool call (entries[1]) starts out allowed and in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        // Canceling the turn marks the still-running tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports the tool call as finished after the cancellation.
        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp_old::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp_old::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        // Dropping the sender resolves `end_turn_rx`, letting the handler return.
        drop(end_turn_tx);
        request.await.unwrap();

        // The late completion replaces the canceled state.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1685
1686 async fn run_until_first_tool_call(
1687 thread: &Entity<AcpThread>,
1688 cx: &mut TestAppContext,
1689 ) -> usize {
1690 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1691
1692 let subscription = cx.update(|cx| {
1693 cx.subscribe(thread, move |thread, _, cx| {
1694 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1695 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1696 return tx.try_send(ix).unwrap();
1697 }
1698 }
1699 })
1700 });
1701
1702 select! {
1703 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1704 panic!("Timeout waiting for tool call")
1705 }
1706 ix = rx.next().fuse() => {
1707 drop(subscription);
1708 ix.unwrap()
1709 }
1710 }
1711 }
1712
1713 pub fn fake_acp_thread(
1714 project: Entity<Project>,
1715 cx: &mut TestAppContext,
1716 ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1717 let (stdin_tx, stdin_rx) = async_pipe::pipe();
1718 let (stdout_tx, stdout_rx) = async_pipe::pipe();
1719
1720 let thread = cx.new(|cx| {
1721 let foreground_executor = cx.foreground_executor().clone();
1722 let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
1723 OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
1724 stdin_tx,
1725 stdout_rx,
1726 move |fut| {
1727 foreground_executor.spawn(fut).detach();
1728 },
1729 );
1730
1731 let io_task = cx.background_spawn({
1732 async move {
1733 io_fut.await.log_err();
1734 Ok(())
1735 }
1736 });
1737 AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
1738 });
1739 let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1740 (thread, agent)
1741 }
1742
1743 pub struct FakeAcpServer {
1744 connection: acp_old::ClientConnection,
1745
1746 _io_task: Task<()>,
1747 on_user_message: Option<
1748 Rc<
1749 dyn Fn(
1750 acp_old::SendUserMessageParams,
1751 Entity<FakeAcpServer>,
1752 AsyncApp,
1753 ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
1754 >,
1755 >,
1756 }
1757
    /// Agent-side implementation handed to the fake server's connection; its
    /// `send_user_message` forwards to the handler registered on `server`.
    #[derive(Clone)]
    struct FakeAgent {
        /// Server whose `on_user_message` handler receives user messages.
        server: Entity<FakeAcpServer>,
        /// Async app handle used to read the server and passed to the handler.
        cx: AsyncApp,
    }
1763
1764 impl acp_old::Agent for FakeAgent {
1765 async fn initialize(
1766 &self,
1767 params: acp_old::InitializeParams,
1768 ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
1769 Ok(acp_old::InitializeResponse {
1770 protocol_version: params.protocol_version,
1771 is_authenticated: true,
1772 })
1773 }
1774
1775 async fn authenticate(&self) -> Result<(), acp_old::Error> {
1776 Ok(())
1777 }
1778
1779 async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
1780 Ok(())
1781 }
1782
1783 async fn send_user_message(
1784 &self,
1785 request: acp_old::SendUserMessageParams,
1786 ) -> Result<(), acp_old::Error> {
1787 let mut cx = self.cx.clone();
1788 let handler = self
1789 .server
1790 .update(&mut cx, |server, _| server.on_user_message.clone())
1791 .ok()
1792 .flatten();
1793 if let Some(handler) = handler {
1794 handler(request, self.server.clone(), self.cx.clone()).await
1795 } else {
1796 Err(anyhow::anyhow!("No handler for on_user_message").into())
1797 }
1798 }
1799 }
1800
1801 impl FakeAcpServer {
1802 fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1803 let agent = FakeAgent {
1804 server: cx.entity(),
1805 cx: cx.to_async(),
1806 };
1807 let foreground_executor = cx.foreground_executor().clone();
1808
1809 let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
1810 agent.clone(),
1811 stdout,
1812 stdin,
1813 move |fut| {
1814 foreground_executor.spawn(fut).detach();
1815 },
1816 );
1817 FakeAcpServer {
1818 connection: connection,
1819 on_user_message: None,
1820 _io_task: cx.background_spawn(async move {
1821 io_fut.await.log_err();
1822 }),
1823 }
1824 }
1825
1826 fn on_user_message<F>(
1827 &mut self,
1828 handler: impl for<'a> Fn(
1829 acp_old::SendUserMessageParams,
1830 Entity<FakeAcpServer>,
1831 AsyncApp,
1832 ) -> F
1833 + 'static,
1834 ) where
1835 F: Future<Output = Result<(), acp_old::Error>> + 'static,
1836 {
1837 self.on_user_message
1838 .replace(Rc::new(move |request, server, cx| {
1839 handler(request, server, cx).boxed_local()
1840 }));
1841 }
1842
1843 fn send_to_zed<T: acp_old::ClientRequest + 'static>(
1844 &self,
1845 message: T,
1846 ) -> BoxedLocal<Result<T::Response>> {
1847 self.connection
1848 .request(message)
1849 .map(|f| f.map_err(|err| anyhow!(err)))
1850 .boxed_local()
1851 }
1852 }
1853}