1mod connection;
2pub use connection::*;
3
4use agent_client_protocol as acp;
5use agentic_coding_protocol as acp_old;
6use anyhow::{Context as _, Result};
7use assistant_tool::ActionLog;
8use buffer_diff::BufferDiff;
9use editor::{Bias, MultiBuffer, PathKey};
10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
11use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
12use itertools::{Itertools, PeekingNext};
13use language::{
14 Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
15 text_diff,
16};
17use markdown::Markdown;
18use project::{AgentLocation, Project};
19use std::cell::RefCell;
20use std::collections::HashMap;
21use std::error::Error;
22use std::fmt::Formatter;
23use std::rc::Rc;
24use std::{
25 fmt::Display,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub content: ContentBlock,
36}
37
38impl UserMessage {
39 pub fn from_acp(
40 message: impl IntoIterator<Item = acp::ContentBlock>,
41 language_registry: Arc<LanguageRegistry>,
42 cx: &mut App,
43 ) -> Self {
44 let mut content = ContentBlock::Empty;
45 for chunk in message {
46 content.append(chunk, &language_registry, cx)
47 }
48 Self { content: content }
49 }
50
51 fn to_markdown(&self, cx: &App) -> String {
52 format!("## User\n{}\n", self.content.to_markdown(cx))
53 }
54}
55
/// A borrowed file path that is rendered as (and parsed back from) an
/// `@file:` mention link.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// URL prefix that marks a link target as a file mention.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` as a mention.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`; returns `None` when
    /// `url` does not start with the mention prefix.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// The mentioned path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats as a markdown link `[@<file name>](@file:<full path>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.display(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
87
88#[derive(Debug, PartialEq)]
89pub struct AssistantMessage {
90 pub chunks: Vec<AssistantMessageChunk>,
91}
92
93impl AssistantMessage {
94 pub fn to_markdown(&self, cx: &App) -> String {
95 format!(
96 "## Assistant\n\n{}\n\n",
97 self.chunks
98 .iter()
99 .map(|chunk| chunk.to_markdown(cx))
100 .join("\n\n")
101 )
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub enum AssistantMessageChunk {
107 Message { block: ContentBlock },
108 Thought { block: ContentBlock },
109}
110
111impl AssistantMessageChunk {
112 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
113 Self::Message {
114 block: ContentBlock::new(chunk.into(), language_registry, cx),
115 }
116 }
117
118 fn to_markdown(&self, cx: &App) -> String {
119 match self {
120 Self::Message { block } => block.to_markdown(cx).to_string(),
121 Self::Thought { block } => {
122 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
123 }
124 }
125 }
126}
127
128#[derive(Debug)]
129pub enum AgentThreadEntry {
130 UserMessage(UserMessage),
131 AssistantMessage(AssistantMessage),
132 ToolCall(ToolCall),
133}
134
135impl AgentThreadEntry {
136 fn to_markdown(&self, cx: &App) -> String {
137 match self {
138 Self::UserMessage(message) => message.to_markdown(cx),
139 Self::AssistantMessage(message) => message.to_markdown(cx),
140 Self::ToolCall(too_call) => too_call.to_markdown(cx),
141 }
142 }
143
144 // todo! return all diffs?
145 pub fn first_diff(&self) -> Option<&Diff> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 call.first_diff()
148 } else {
149 None
150 }
151 }
152
153 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
154 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
155 Some(locations)
156 } else {
157 None
158 }
159 }
160}
161
162#[derive(Debug)]
163pub struct ToolCall {
164 pub id: acp::ToolCallId,
165 pub label: Entity<Markdown>,
166 pub kind: acp::ToolKind,
167 pub content: Vec<ToolCallContent>,
168 pub status: ToolCallStatus,
169 pub locations: Vec<acp::ToolCallLocation>,
170}
171
172impl ToolCall {
173 fn from_acp(
174 tool_call: acp::ToolCall,
175 status: ToolCallStatus,
176 language_registry: Arc<LanguageRegistry>,
177 cx: &mut App,
178 ) -> Self {
179 Self {
180 id: tool_call.id,
181 label: cx.new(|cx| {
182 Markdown::new(
183 tool_call.label.into(),
184 Some(language_registry.clone()),
185 None,
186 cx,
187 )
188 }),
189 kind: tool_call.kind,
190 content: tool_call
191 .content
192 .into_iter()
193 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
194 .collect(),
195 locations: tool_call.locations,
196 status,
197 }
198 }
199
200 pub fn first_diff(&self) -> Option<&Diff> {
201 self.content.iter().find_map(|content| match content {
202 ToolCallContent::ContentBlock { .. } => None,
203 ToolCallContent::Diff { diff } => Some(diff),
204 })
205 }
206
207 fn to_markdown(&self, cx: &App) -> String {
208 let mut markdown = format!(
209 "**Tool Call: {}**\nStatus: {}\n\n",
210 self.label.read(cx).source(),
211 self.status
212 );
213 for content in &self.content {
214 markdown.push_str(content.to_markdown(cx).as_str());
215 markdown.push_str("\n\n");
216 }
217 markdown
218 }
219}
220
221#[derive(Debug)]
222pub enum ToolCallStatus {
223 WaitingForConfirmation {
224 options: Vec<acp::PermissionOption>,
225 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
226 },
227 Allowed {
228 status: acp::ToolCallStatus,
229 },
230 Rejected,
231 Canceled,
232}
233
234impl Display for ToolCallStatus {
235 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236 write!(
237 f,
238 "{}",
239 match self {
240 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
241 ToolCallStatus::Allowed { status } => match status {
242 acp::ToolCallStatus::InProgress => "In Progress",
243 acp::ToolCallStatus::Completed => "Completed",
244 acp::ToolCallStatus::Failed => "Failed",
245 },
246 ToolCallStatus::Rejected => "Rejected",
247 ToolCallStatus::Canceled => "Canceled",
248 }
249 )
250 }
251}
252
253#[derive(Debug, PartialEq, Clone)]
254pub enum ContentBlock {
255 Empty,
256 Markdown { markdown: Entity<Markdown> },
257}
258
259impl ContentBlock {
260 pub fn new(
261 block: acp::ContentBlock,
262 language_registry: &Arc<LanguageRegistry>,
263 cx: &mut App,
264 ) -> Self {
265 let mut this = Self::Empty;
266 this.append(block, language_registry, cx);
267 this
268 }
269
270 pub fn new_combined(
271 blocks: impl IntoIterator<Item = acp::ContentBlock>,
272 language_registry: Arc<LanguageRegistry>,
273 cx: &mut App,
274 ) -> Self {
275 let mut this = Self::Empty;
276 for block in blocks {
277 this.append(block, &language_registry, cx);
278 }
279 this
280 }
281
282 pub fn append(
283 &mut self,
284 block: acp::ContentBlock,
285 language_registry: &Arc<LanguageRegistry>,
286 cx: &mut App,
287 ) {
288 let new_content = match block {
289 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
290 acp::ContentBlock::ResourceLink(resource_link) => {
291 if let Some(path) = resource_link.uri.strip_prefix("file://") {
292 format!("{}", MentionPath(path.as_ref()))
293 } else {
294 resource_link.uri.clone()
295 }
296 }
297 acp::ContentBlock::Image(_)
298 | acp::ContentBlock::Audio(_)
299 | acp::ContentBlock::Resource(_) => String::new(),
300 };
301
302 match self {
303 ContentBlock::Empty => {
304 *self = ContentBlock::Markdown {
305 markdown: cx.new(|cx| {
306 Markdown::new(
307 new_content.into(),
308 Some(language_registry.clone()),
309 None,
310 cx,
311 )
312 }),
313 };
314 }
315 ContentBlock::Markdown { markdown } => {
316 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
317 }
318 }
319 }
320
321 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
322 match self {
323 ContentBlock::Empty => "",
324 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
325 }
326 }
327
328 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
329 match self {
330 ContentBlock::Empty => None,
331 ContentBlock::Markdown { markdown } => Some(markdown),
332 }
333 }
334}
335
336#[derive(Debug)]
337pub enum ToolCallContent {
338 ContentBlock { content: ContentBlock },
339 Diff { diff: Diff },
340}
341
342impl ToolCallContent {
343 pub fn from_acp(
344 content: acp::ToolCallContent,
345 language_registry: Arc<LanguageRegistry>,
346 cx: &mut App,
347 ) -> Self {
348 match content {
349 acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock {
350 content: ContentBlock::new(content, &language_registry, cx),
351 },
352 acp::ToolCallContent::Diff { diff } => Self::Diff {
353 diff: Diff::from_acp(diff, language_registry, cx),
354 },
355 }
356 }
357
358 pub fn from_acp_contents(
359 content: Vec<acp::ToolCallContent>,
360 language_registry: Arc<LanguageRegistry>,
361 cx: &mut App,
362 ) -> Vec<Self> {
363 content
364 .into_iter()
365 .peekable()
366 .batching(|it| match it.next()? {
367 acp::ToolCallContent::ContentBlock { content } => {
368 let mut block = ContentBlock::new(content, &language_registry, cx);
369 while let Some(acp::ToolCallContent::ContentBlock { content }) =
370 it.peeking_next(|c| matches!(c, acp::ToolCallContent::ContentBlock { .. }))
371 {
372 block.append(content, &language_registry, cx);
373 }
374 Some(ToolCallContent::ContentBlock { content: block })
375 }
376 content @ acp::ToolCallContent::Diff { .. } => Some(ToolCallContent::from_acp(
377 content,
378 language_registry.clone(),
379 cx,
380 )),
381 })
382 .collect()
383 }
384
385 pub fn to_markdown(&self, cx: &App) -> String {
386 match self {
387 Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
388 Self::Diff { diff } => diff.to_markdown(cx),
389 }
390 }
391}
392
/// A file diff attached to a tool call, presented as a read-only multibuffer
/// that shows the changed hunks of the new text.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer over `new_buffer` with the diff attached; its
    /// excerpts are filled in asynchronously by the task spawned in `from_acp`.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path reported by the agent; also used to look up the buffer's language.
    pub path: PathBuf,
    /// Buffer holding the proposed (new) text.
    pub new_buffer: Entity<Buffer>,
    /// Buffer holding the original text (empty if the agent sent none).
    pub old_buffer: Entity<Buffer>,
    /// Keeps the background diff/excerpt task alive as long as this value
    /// exists (dropping a gpui `Task` cancels it).
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff view from an ACP diff.
    ///
    /// The buffers are created synchronously. Computing the diff, populating
    /// the multibuffer excerpts, and assigning a language to `new_buffer` all
    /// happen in a spawned task, so the multibuffer starts out empty.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is treated as an empty file (a newly created file).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Diff the new text against the old text as the base.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the changed hunks (plus default context) of the new buffer.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Best effort: an unknown language is logged, not treated as a failure.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Serializes the diff for markdown export.
    ///
    /// Note: this emits the full text of the buffers in the multibuffer (only
    /// `new_buffer` is ever added above), not a unified diff. Before the
    /// background task has set the excerpts, the multibuffer may contain no
    /// buffers, and the code block is then empty.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
490
491#[derive(Debug, Default)]
492pub struct Plan {
493 pub entries: Vec<PlanEntry>,
494}
495
496#[derive(Debug)]
497pub struct PlanStats<'a> {
498 pub in_progress_entry: Option<&'a PlanEntry>,
499 pub pending: u32,
500 pub completed: u32,
501}
502
503impl Plan {
504 pub fn is_empty(&self) -> bool {
505 self.entries.is_empty()
506 }
507
508 pub fn stats(&self) -> PlanStats<'_> {
509 let mut stats = PlanStats {
510 in_progress_entry: None,
511 pending: 0,
512 completed: 0,
513 };
514
515 for entry in &self.entries {
516 match &entry.status {
517 acp::PlanEntryStatus::Pending => {
518 stats.pending += 1;
519 }
520 acp::PlanEntryStatus::InProgress => {
521 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
522 }
523 acp::PlanEntryStatus::Completed => {
524 stats.completed += 1;
525 }
526 }
527 }
528
529 stats
530 }
531}
532
533#[derive(Debug)]
534pub struct PlanEntry {
535 pub content: Entity<Markdown>,
536 pub priority: acp::PlanEntryPriority,
537 pub status: acp::PlanEntryStatus,
538}
539
540impl PlanEntry {
541 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
542 Self {
543 content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
544 priority: entry.priority,
545 status: entry.status,
546 }
547 }
548}
549
/// A conversation with an agent over an [`AgentConnection`]. It holds the
/// transcript entries, the agent's plan, and the task driving the current turn.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries, oldest first.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    /// Records buffer reads and edits performed on the agent's behalf.
    action_log: Entity<ActionLog>,
    /// Last snapshot of each buffer handed to the agent by `read_text_file`;
    /// `write_text_file` diffs the agent's new content against it.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Task driving the current turn; `Some` while a prompt is in flight.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    /// Taken (at most once) through `child_status`.
    child_status: Option<Task<Result<()>>>,
    session_id: acp::SessionId,
}

/// Notifications emitted by [`AcpThread`] when its transcript changes.
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
569
570#[derive(PartialEq, Eq)]
571pub enum ThreadStatus {
572 Idle,
573 WaitingForToolConfirmation,
574 Generating,
575}
576
577#[derive(Debug, Clone)]
578pub enum LoadError {
579 Unsupported {
580 error_message: SharedString,
581 upgrade_message: SharedString,
582 upgrade_command: String,
583 },
584 Exited(i32),
585 Other(SharedString),
586}
587
588impl Display for LoadError {
589 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
590 match self {
591 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
592 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
593 LoadError::Other(msg) => write!(f, "{}", msg),
594 }
595 }
596}
597
598impl Error for LoadError {}
599
600impl AcpThread {
601 pub fn new(
602 connection: Rc<dyn AgentConnection>,
603 // todo! remove me
604 title: SharedString,
605 // todo! remove this?
606 child_status: Option<Task<Result<()>>>,
607 project: Entity<Project>,
608 session_id: acp::SessionId,
609 cx: &mut Context<Self>,
610 ) -> Self {
611 let action_log = cx.new(|_| ActionLog::new(project.clone()));
612
613 Self {
614 action_log,
615 shared_buffers: Default::default(),
616 entries: Default::default(),
617 plan: Default::default(),
618 title,
619 project,
620 send_task: None,
621 connection,
622 child_status,
623 session_id,
624 }
625 }
626
627 pub fn action_log(&self) -> &Entity<ActionLog> {
628 &self.action_log
629 }
630
631 pub fn project(&self) -> &Entity<Project> {
632 &self.project
633 }
634
635 pub fn title(&self) -> SharedString {
636 self.title.clone()
637 }
638
639 pub fn entries(&self) -> &[AgentThreadEntry] {
640 &self.entries
641 }
642
643 pub fn status(&self) -> ThreadStatus {
644 if self.send_task.is_some() {
645 if self.waiting_for_tool_confirmation() {
646 ThreadStatus::WaitingForToolConfirmation
647 } else {
648 ThreadStatus::Generating
649 }
650 } else {
651 ThreadStatus::Idle
652 }
653 }
654
655 pub fn has_pending_edit_tool_calls(&self) -> bool {
656 for entry in self.entries.iter().rev() {
657 match entry {
658 AgentThreadEntry::UserMessage(_) => return false,
659 AgentThreadEntry::ToolCall(call) if call.first_diff().is_some() => return true,
660 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
661 }
662 }
663
664 false
665 }
666
667 pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
668 self.entries.push(entry);
669 cx.emit(AcpThreadEvent::NewEntry);
670 }
671
672 pub fn push_assistant_chunk(
673 &mut self,
674 chunk: acp::ContentBlock,
675 is_thought: bool,
676 cx: &mut Context<Self>,
677 ) {
678 let language_registry = self.project.read(cx).languages().clone();
679 let entries_len = self.entries.len();
680 if let Some(last_entry) = self.entries.last_mut()
681 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
682 {
683 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
684 match (chunks.last_mut(), is_thought) {
685 (Some(AssistantMessageChunk::Message { block }), false)
686 | (Some(AssistantMessageChunk::Thought { block }), true) => {
687 block.append(chunk, &language_registry, cx)
688 }
689 _ => {
690 let block = ContentBlock::new(chunk, &language_registry, cx);
691 if is_thought {
692 chunks.push(AssistantMessageChunk::Thought { block })
693 } else {
694 chunks.push(AssistantMessageChunk::Message { block })
695 }
696 }
697 }
698 } else {
699 let block = ContentBlock::new(chunk, &language_registry, cx);
700 let chunk = if is_thought {
701 AssistantMessageChunk::Thought { block }
702 } else {
703 AssistantMessageChunk::Message { block }
704 };
705
706 self.push_entry(
707 AgentThreadEntry::AssistantMessage(AssistantMessage {
708 chunks: vec![chunk],
709 }),
710 cx,
711 );
712 }
713 }
714
715 pub fn update_tool_call(
716 &mut self,
717 id: acp::ToolCallId,
718 status: acp::ToolCallStatus,
719 content: Option<Vec<acp::ToolCallContent>>,
720 cx: &mut Context<Self>,
721 ) -> Result<()> {
722 let languages = self.project.read(cx).languages().clone();
723 let (ix, current_call) = self.tool_call_mut(&id).context("Tool call not found")?;
724
725 if let Some(content) = content {
726 current_call.content = content
727 .into_iter()
728 .map(|chunk| ToolCallContent::from_acp(chunk, languages.clone(), cx))
729 .collect();
730 }
731 current_call.status = ToolCallStatus::Allowed { status };
732
733 cx.emit(AcpThreadEvent::EntryUpdated(ix));
734
735 Ok(())
736 }
737
738 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
739 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
740 let status = ToolCallStatus::Allowed {
741 status: tool_call.status,
742 };
743 self.upsert_tool_call_inner(tool_call, status, cx)
744 }
745
746 pub fn upsert_tool_call_inner(
747 &mut self,
748 tool_call: acp::ToolCall,
749 status: ToolCallStatus,
750 cx: &mut Context<Self>,
751 ) {
752 let language_registry = self.project.read(cx).languages().clone();
753 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
754
755 let location = call.locations.last().cloned();
756
757 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
758 *current_call = call;
759
760 cx.emit(AcpThreadEvent::EntryUpdated(ix));
761 } else {
762 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
763 }
764
765 if let Some(location) = location {
766 self.set_project_location(location, cx)
767 }
768 }
769
770 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
771 // todo! use map
772 self.entries
773 .iter_mut()
774 .enumerate()
775 .rev()
776 .find_map(|(index, tool_call)| {
777 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
778 && &tool_call.id == id
779 {
780 Some((index, tool_call))
781 } else {
782 None
783 }
784 })
785 }
786
787 pub fn request_tool_call_permission(
788 &mut self,
789 tool_call: acp::ToolCall,
790 options: Vec<acp::PermissionOption>,
791 cx: &mut Context<Self>,
792 ) -> oneshot::Receiver<acp::PermissionOptionId> {
793 let (tx, rx) = oneshot::channel();
794
795 let status = ToolCallStatus::WaitingForConfirmation {
796 options,
797 respond_tx: tx,
798 };
799
800 self.upsert_tool_call_inner(tool_call, status, cx);
801 rx
802 }
803
804 pub fn authorize_tool_call(
805 &mut self,
806 id: acp::ToolCallId,
807 option_id: acp::PermissionOptionId,
808 option_kind: acp::PermissionOptionKind,
809 cx: &mut Context<Self>,
810 ) {
811 let Some((ix, call)) = self.tool_call_mut(&id) else {
812 return;
813 };
814
815 let new_status = match option_kind {
816 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
817 ToolCallStatus::Rejected
818 }
819 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
820 ToolCallStatus::Allowed {
821 status: acp::ToolCallStatus::InProgress,
822 }
823 }
824 };
825
826 let curr_status = mem::replace(&mut call.status, new_status);
827
828 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
829 respond_tx.send(option_id).log_err();
830 } else if cfg!(debug_assertions) {
831 panic!("tried to authorize an already authorized tool call");
832 }
833
834 cx.emit(AcpThreadEvent::EntryUpdated(ix));
835 }
836
837 pub fn plan(&self) -> &Plan {
838 &self.plan
839 }
840
841 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
842 self.plan = Plan {
843 entries: request
844 .entries
845 .into_iter()
846 .map(|entry| PlanEntry::from_acp(entry, cx))
847 .collect(),
848 };
849
850 cx.notify();
851 }
852
853 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
854 self.plan
855 .entries
856 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
857 cx.notify();
858 }
859
860 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
861 self.project.update(cx, |project, cx| {
862 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
863 return;
864 };
865 let buffer = project.open_buffer(path, cx);
866 cx.spawn(async move |project, cx| {
867 let buffer = buffer.await?;
868
869 project.update(cx, |project, cx| {
870 let position = if let Some(line) = location.line {
871 let snapshot = buffer.read(cx).snapshot();
872 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
873 snapshot.anchor_before(point)
874 } else {
875 Anchor::MIN
876 };
877
878 project.set_agent_location(
879 Some(AgentLocation {
880 buffer: buffer.downgrade(),
881 position,
882 }),
883 cx,
884 );
885 })
886 })
887 .detach_and_log_err(cx);
888 });
889 }
890
891 /// Returns true if the last turn is awaiting tool authorization
892 pub fn waiting_for_tool_confirmation(&self) -> bool {
893 for entry in self.entries.iter().rev() {
894 match &entry {
895 AgentThreadEntry::ToolCall(call) => match call.status {
896 ToolCallStatus::WaitingForConfirmation { .. } => return true,
897 ToolCallStatus::Allowed { .. }
898 | ToolCallStatus::Rejected
899 | ToolCallStatus::Canceled => continue,
900 },
901 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
902 // Reached the beginning of the turn
903 return false;
904 }
905 }
906 }
907 false
908 }
909
910 pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
911 self.connection.authenticate(cx)
912 }
913
914 #[cfg(any(test, feature = "test-support"))]
915 pub fn send_raw(
916 &mut self,
917 message: &str,
918 cx: &mut Context<Self>,
919 ) -> BoxFuture<'static, Result<()>> {
920 self.send(
921 vec![acp::ContentBlock::Text(acp::TextContent {
922 text: message.to_string(),
923 annotations: None,
924 })],
925 cx,
926 )
927 }
928
929 pub fn send(
930 &mut self,
931 message: Vec<acp::ContentBlock>,
932 cx: &mut Context<Self>,
933 ) -> BoxFuture<'static, Result<()>> {
934 let block = ContentBlock::new_combined(
935 message.clone(),
936 self.project.read(cx).languages().clone(),
937 cx,
938 );
939 self.push_entry(
940 AgentThreadEntry::UserMessage(UserMessage { content: block }),
941 cx,
942 );
943 self.clear_completed_plan_entries(cx);
944
945 let (tx, rx) = oneshot::channel();
946 let cancel_task = self.cancel(cx);
947
948 self.send_task = Some(cx.spawn(async move |this, cx| {
949 async {
950 cancel_task.await;
951
952 let result = this
953 .update(cx, |this, cx| {
954 this.connection.prompt(
955 acp::PromptToolArguments {
956 prompt: message,
957 session_id: this.session_id.clone(),
958 },
959 cx,
960 )
961 })?
962 .await;
963 tx.send(result).log_err();
964 this.update(cx, |this, _cx| this.send_task.take())?;
965 anyhow::Ok(())
966 }
967 .await
968 .log_err();
969 }));
970
971 async move {
972 match rx.await {
973 Ok(Err(e)) => Err(e)?,
974 _ => Ok(()),
975 }
976 }
977 .boxed()
978 }
979
980 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
981 let Some(send_task) = self.send_task.take() else {
982 return Task::ready(());
983 };
984
985 for entry in self.entries.iter_mut() {
986 if let AgentThreadEntry::ToolCall(call) = entry {
987 let cancel = matches!(
988 call.status,
989 ToolCallStatus::WaitingForConfirmation { .. }
990 | ToolCallStatus::Allowed {
991 status: acp::ToolCallStatus::InProgress
992 }
993 );
994
995 if cancel {
996 call.status = ToolCallStatus::Canceled;
997 }
998 }
999 }
1000
1001 self.connection.cancel(cx);
1002
1003 // Wait for the send task to complete
1004 cx.foreground_executor().spawn(send_task)
1005 }
1006
1007 pub fn read_text_file(
1008 &self,
1009 path: PathBuf,
1010 line: Option<u32>,
1011 limit: Option<u32>,
1012 reuse_shared_snapshot: bool,
1013 cx: &mut Context<Self>,
1014 ) -> Task<Result<String>> {
1015 let project = self.project.clone();
1016 let action_log = self.action_log.clone();
1017 cx.spawn(async move |this, cx| {
1018 let load = project.update(cx, |project, cx| {
1019 let path = project
1020 .project_path_for_absolute_path(&path, cx)
1021 .context("invalid path")?;
1022 anyhow::Ok(project.open_buffer(path, cx))
1023 });
1024 let buffer = load??.await?;
1025
1026 let snapshot = if reuse_shared_snapshot {
1027 this.read_with(cx, |this, _| {
1028 this.shared_buffers.get(&buffer.clone()).cloned()
1029 })
1030 .log_err()
1031 .flatten()
1032 } else {
1033 None
1034 };
1035
1036 let snapshot = if let Some(snapshot) = snapshot {
1037 snapshot
1038 } else {
1039 action_log.update(cx, |action_log, cx| {
1040 action_log.buffer_read(buffer.clone(), cx);
1041 })?;
1042 project.update(cx, |project, cx| {
1043 let position = buffer
1044 .read(cx)
1045 .snapshot()
1046 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1047 project.set_agent_location(
1048 Some(AgentLocation {
1049 buffer: buffer.downgrade(),
1050 position,
1051 }),
1052 cx,
1053 );
1054 })?;
1055
1056 buffer.update(cx, |buffer, _| buffer.snapshot())?
1057 };
1058
1059 this.update(cx, |this, _| {
1060 let text = snapshot.text();
1061 this.shared_buffers.insert(buffer.clone(), snapshot);
1062 if line.is_none() && limit.is_none() {
1063 return Ok(text);
1064 }
1065 let limit = limit.unwrap_or(u32::MAX) as usize;
1066 let Some(line) = line else {
1067 return Ok(text.lines().take(limit).collect::<String>());
1068 };
1069
1070 let count = text.lines().count();
1071 if count < line as usize {
1072 anyhow::bail!("There are only {} lines", count);
1073 }
1074 Ok(text
1075 .lines()
1076 .skip(line as usize + 1)
1077 .take(limit)
1078 .collect::<String>())
1079 })?
1080 })
1081 }
1082
1083 pub fn write_text_file(
1084 &self,
1085 path: PathBuf,
1086 content: String,
1087 cx: &mut Context<Self>,
1088 ) -> Task<Result<()>> {
1089 let project = self.project.clone();
1090 let action_log = self.action_log.clone();
1091 cx.spawn(async move |this, cx| {
1092 let load = project.update(cx, |project, cx| {
1093 let path = project
1094 .project_path_for_absolute_path(&path, cx)
1095 .context("invalid path")?;
1096 anyhow::Ok(project.open_buffer(path, cx))
1097 });
1098 let buffer = load??.await?;
1099 let snapshot = this.update(cx, |this, cx| {
1100 this.shared_buffers
1101 .get(&buffer)
1102 .cloned()
1103 .unwrap_or_else(|| buffer.read(cx).snapshot())
1104 })?;
1105 let edits = cx
1106 .background_executor()
1107 .spawn(async move {
1108 let old_text = snapshot.text();
1109 text_diff(old_text.as_str(), &content)
1110 .into_iter()
1111 .map(|(range, replacement)| {
1112 (
1113 snapshot.anchor_after(range.start)
1114 ..snapshot.anchor_before(range.end),
1115 replacement,
1116 )
1117 })
1118 .collect::<Vec<_>>()
1119 })
1120 .await;
1121 cx.update(|cx| {
1122 project.update(cx, |project, cx| {
1123 project.set_agent_location(
1124 Some(AgentLocation {
1125 buffer: buffer.downgrade(),
1126 position: edits
1127 .last()
1128 .map(|(range, _)| range.end)
1129 .unwrap_or(Anchor::MIN),
1130 }),
1131 cx,
1132 );
1133 });
1134
1135 action_log.update(cx, |action_log, cx| {
1136 action_log.buffer_read(buffer.clone(), cx);
1137 });
1138 buffer.update(cx, |buffer, cx| {
1139 buffer.edit(edits, None, cx);
1140 });
1141 action_log.update(cx, |action_log, cx| {
1142 action_log.buffer_edited(buffer.clone(), cx);
1143 });
1144 })?;
1145 project
1146 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1147 .await
1148 })
1149 }
1150
1151 pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1152 self.child_status.take()
1153 }
1154
1155 pub fn to_markdown(&self, cx: &App) -> String {
1156 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1157 }
1158}
1159
1160#[derive(Clone)]
1161pub struct OldAcpClientDelegate {
1162 thread: Rc<RefCell<WeakEntity<AcpThread>>>,
1163 cx: AsyncApp,
1164 next_tool_call_id: Rc<RefCell<u64>>,
1165 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1166}
1167
1168impl OldAcpClientDelegate {
1169 pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
1170 Self {
1171 thread,
1172 cx,
1173 next_tool_call_id: Rc::new(RefCell::new(0)),
1174 }
1175 }
1176}
1177
1178impl acp_old::Client for OldAcpClientDelegate {
1179 async fn stream_assistant_message_chunk(
1180 &self,
1181 params: acp_old::StreamAssistantMessageChunkParams,
1182 ) -> Result<(), acp_old::Error> {
1183 let cx = &mut self.cx.clone();
1184
1185 cx.update(|cx| {
1186 self.thread
1187 .borrow()
1188 .update(cx, |thread, cx| match params.chunk {
1189 acp_old::AssistantMessageChunk::Text { text } => {
1190 thread.push_assistant_chunk(text.into(), false, cx)
1191 }
1192 acp_old::AssistantMessageChunk::Thought { thought } => {
1193 thread.push_assistant_chunk(thought.into(), true, cx)
1194 }
1195 })
1196 .ok();
1197 })?;
1198
1199 Ok(())
1200 }
1201
1202 async fn request_tool_call_confirmation(
1203 &self,
1204 request: acp_old::RequestToolCallConfirmationParams,
1205 ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
1206 let cx = &mut self.cx.clone();
1207
1208 let old_acp_id = *self.next_tool_call_id.borrow() + 1;
1209 self.next_tool_call_id.replace(old_acp_id);
1210
1211 let tool_call = into_new_tool_call(
1212 acp::ToolCallId(old_acp_id.to_string().into()),
1213 request.tool_call,
1214 );
1215
1216 let mut options = match request.confirmation {
1217 acp_old::ToolCallConfirmation::Edit { .. } => vec![(
1218 acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
1219 acp::PermissionOptionKind::AllowAlways,
1220 "Always Allow Edits".to_string(),
1221 )],
1222 acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
1223 acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
1224 acp::PermissionOptionKind::AllowAlways,
1225 format!("Always Allow {}", root_command),
1226 )],
1227 acp_old::ToolCallConfirmation::Mcp {
1228 server_name,
1229 tool_name,
1230 ..
1231 } => vec![
1232 (
1233 acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
1234 acp::PermissionOptionKind::AllowAlways,
1235 format!("Always Allow {}", server_name),
1236 ),
1237 (
1238 acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
1239 acp::PermissionOptionKind::AllowAlways,
1240 format!("Always Allow {}", tool_name),
1241 ),
1242 ],
1243 acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
1244 acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
1245 acp::PermissionOptionKind::AllowAlways,
1246 "Always Allow".to_string(),
1247 )],
1248 acp_old::ToolCallConfirmation::Other { .. } => vec![(
1249 acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
1250 acp::PermissionOptionKind::AllowAlways,
1251 "Always Allow".to_string(),
1252 )],
1253 };
1254
1255 options.extend([
1256 (
1257 acp_old::ToolCallConfirmationOutcome::Allow,
1258 acp::PermissionOptionKind::AllowOnce,
1259 "Allow".to_string(),
1260 ),
1261 (
1262 acp_old::ToolCallConfirmationOutcome::Reject,
1263 acp::PermissionOptionKind::RejectOnce,
1264 "Reject".to_string(),
1265 ),
1266 ]);
1267
1268 let mut outcomes = Vec::with_capacity(options.len());
1269 let mut acp_options = Vec::with_capacity(options.len());
1270
1271 for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
1272 outcomes.push(outcome);
1273 acp_options.push(acp::PermissionOption {
1274 id: acp::PermissionOptionId(index.to_string().into()),
1275 label,
1276 kind,
1277 })
1278 }
1279
1280 let response = cx
1281 .update(|cx| {
1282 self.thread.borrow().update(cx, |thread, cx| {
1283 thread.request_tool_call_permission(tool_call, acp_options, cx)
1284 })
1285 })?
1286 .context("Failed to update thread")?
1287 .await;
1288
1289 let outcome = match response {
1290 Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
1291 Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
1292 };
1293
1294 Ok(acp_old::RequestToolCallConfirmationResponse {
1295 id: acp_old::ToolCallId(old_acp_id),
1296 outcome: outcome,
1297 })
1298 }
1299
1300 async fn push_tool_call(
1301 &self,
1302 request: acp_old::PushToolCallParams,
1303 ) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
1304 let cx = &mut self.cx.clone();
1305
1306 let old_acp_id = *self.next_tool_call_id.borrow() + 1;
1307 self.next_tool_call_id.replace(old_acp_id);
1308
1309 cx.update(|cx| {
1310 self.thread.borrow().update(cx, |thread, cx| {
1311 thread.upsert_tool_call(
1312 into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
1313 cx,
1314 )
1315 })
1316 })?
1317 .context("Failed to update thread")?;
1318
1319 Ok(acp_old::PushToolCallResponse {
1320 id: acp_old::ToolCallId(old_acp_id),
1321 })
1322 }
1323
1324 async fn update_tool_call(
1325 &self,
1326 request: acp_old::UpdateToolCallParams,
1327 ) -> Result<(), acp_old::Error> {
1328 let cx = &mut self.cx.clone();
1329
1330 cx.update(|cx| {
1331 self.thread.borrow().update(cx, |thread, cx| {
1332 let languages = thread.project.read(cx).languages().clone();
1333
1334 if let Some((ix, tool_call)) = thread
1335 .tool_call_mut(&acp::ToolCallId(request.tool_call_id.0.to_string().into()))
1336 {
1337 tool_call.status = ToolCallStatus::Allowed {
1338 status: into_new_tool_call_status(request.status),
1339 };
1340 tool_call.content = request
1341 .content
1342 .into_iter()
1343 .map(|content| {
1344 ToolCallContent::from_acp(
1345 into_new_tool_call_content(content),
1346 languages.clone(),
1347 cx,
1348 )
1349 })
1350 .collect();
1351
1352 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1353 anyhow::Ok(())
1354 } else {
1355 anyhow::bail!("Tool call not found")
1356 }
1357 })
1358 })?
1359 .context("Failed to update thread")??;
1360
1361 Ok(())
1362 }
1363
1364 async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
1365 let cx = &mut self.cx.clone();
1366
1367 cx.update(|cx| {
1368 self.thread.borrow().update(cx, |thread, cx| {
1369 thread.update_plan(
1370 acp::Plan {
1371 entries: request
1372 .entries
1373 .into_iter()
1374 .map(into_new_plan_entry)
1375 .collect(),
1376 },
1377 cx,
1378 )
1379 })
1380 })?
1381 .context("Failed to update thread")?;
1382
1383 Ok(())
1384 }
1385
1386 async fn read_text_file(
1387 &self,
1388 acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
1389 ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
1390 let content = self
1391 .cx
1392 .update(|cx| {
1393 self.thread.borrow().update(cx, |thread, cx| {
1394 thread.read_text_file(path, line, limit, false, cx)
1395 })
1396 })?
1397 .context("Failed to update thread")?
1398 .await?;
1399 Ok(acp_old::ReadTextFileResponse { content })
1400 }
1401
1402 async fn write_text_file(
1403 &self,
1404 acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
1405 ) -> Result<(), acp_old::Error> {
1406 self.cx
1407 .update(|cx| {
1408 self.thread
1409 .borrow()
1410 .update(cx, |thread, cx| thread.write_text_file(path, content, cx))
1411 })?
1412 .context("Failed to update thread")?
1413 .await?;
1414
1415 Ok(())
1416 }
1417}
1418
1419fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
1420 acp::ToolCall {
1421 id: id,
1422 label: request.label,
1423 kind: acp_kind_from_old_icon(request.icon),
1424 status: acp::ToolCallStatus::InProgress,
1425 content: request
1426 .content
1427 .into_iter()
1428 .map(into_new_tool_call_content)
1429 .collect(),
1430 locations: request
1431 .locations
1432 .into_iter()
1433 .map(into_new_tool_call_location)
1434 .collect(),
1435 }
1436}
1437
1438fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
1439 match icon {
1440 acp_old::Icon::FileSearch => acp::ToolKind::Search,
1441 acp_old::Icon::Folder => acp::ToolKind::Search,
1442 acp_old::Icon::Globe => acp::ToolKind::Search,
1443 acp_old::Icon::Hammer => acp::ToolKind::Other,
1444 acp_old::Icon::LightBulb => acp::ToolKind::Think,
1445 acp_old::Icon::Pencil => acp::ToolKind::Edit,
1446 acp_old::Icon::Regex => acp::ToolKind::Search,
1447 acp_old::Icon::Terminal => acp::ToolKind::Execute,
1448 }
1449}
1450
1451fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
1452 match status {
1453 acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
1454 acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
1455 acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
1456 }
1457}
1458
1459fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
1460 match content {
1461 acp_old::ToolCallContent::Markdown { markdown } => acp::ToolCallContent::ContentBlock {
1462 content: acp::ContentBlock::Text(acp::TextContent {
1463 annotations: None,
1464 text: markdown,
1465 }),
1466 },
1467 acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
1468 diff: into_new_diff(diff),
1469 },
1470 }
1471}
1472
1473fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
1474 acp::Diff {
1475 path: diff.path,
1476 old_text: diff.old_text,
1477 new_text: diff.new_text,
1478 }
1479}
1480
1481fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
1482 acp::ToolCallLocation {
1483 path: location.path,
1484 line: location.line,
1485 }
1486}
1487
1488fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
1489 acp::PlanEntry {
1490 content: entry.content,
1491 priority: into_new_plan_priority(entry.priority),
1492 status: into_new_plan_status(entry.status),
1493 }
1494}
1495
1496fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
1497 match priority {
1498 acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
1499 acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
1500 acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
1501 }
1502}
1503
1504fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
1505 match status {
1506 acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
1507 acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
1508 acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
1509 }
1510}
1511
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Common setup for every test in this module: installs a test
    /// `SettingsStore` global and initializes project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        // `.ok()` ignores the error returned when a logger is already installed
        // (e.g. by an earlier test in the same process).
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    // Two consecutive `Thought` chunks streamed by the agent must end up in a
    // single `<thinking>` block in the thread's markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
                            chunk: acp_old::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    // The agent reads /tmp/foo and then writes an extended version of what it
    // read. Between the read and the write the user inserts "zero\n" at the
    // top of the open buffer. The final buffer and the file on disk are both
    // expected to contain the user's line and the agent's added lines, i.e. the
    // agent's write must not clobber the concurrent user edit.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the fake agent once its read has completed, so the test body
        // knows when to apply the user edit.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    // A tool call that the agent reports as `Finished` after the user cancelled
    // the turn is expected to go InProgress -> Canceled -> Allowed { Completed }.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // Dropping `end_turn_tx` lets the fake agent's user-message handler return.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        // Old-protocol id of the pushed tool call, captured by the handler so
        // the test can later send an update for it.
        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp_old::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp_old::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is expected to be the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp_old::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp_old::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` contains a tool-call entry and returns its index.
    ///
    /// Subscribes to the thread's events and rescans all entries on each event.
    /// Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // NOTE(review): `try_send` panics if the bounded channel is
                        // full, i.e. if enough events fire before `rx` is polled and
                        // the subscription dropped — confirm that can't happen here.
                        return tx.try_send(ix).unwrap();
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Builds an `AcpThread` wired to a `FakeAcpServer` through two in-memory pipes.
    ///
    /// The thread's old-protocol connection writes to `stdin_tx` and reads from
    /// `stdout_rx`. The fake server reads from `stdin_rx` and writes to
    /// `stdout_tx`. Protocol futures are spawned on the foreground executor,
    /// and the connection I/O loops run as background tasks.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            // Shared weak handle to the thread being constructed, handed to both
            // the client delegate and the connection wrapper.
            let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade()));

            let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
                OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            let connection = OldAcpAgentConnection {
                connection,
                child_status: io_task,
                thread: thread_rc,
            };

            AcpThread::new(
                Rc::new(connection),
                "Test".into(),
                None,
                project,
                acp::SessionId("test".into()),
                cx,
            )
        });
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, agent)
    }

    /// Test double for an old-protocol agent process.
    ///
    /// Holds the client connection used to send requests to Zed, plus an
    /// optional handler invoked for each incoming `send_user_message`.
    pub struct FakeAcpServer {
        connection: acp_old::ClientConnection,

        // Keeps the connection's I/O loop alive for as long as the server exists.
        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp_old::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
            >,
        >,
    }

    /// `acp_old::Agent` implementation that backs a `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp_old::Agent for FakeAgent {
        /// Echoes the requested protocol version and reports the client as
        /// already authenticated.
        async fn initialize(
            &self,
            params: acp_old::InitializeParams,
        ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
            Ok(acp_old::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
            Ok(())
        }

        /// Forwards the user message to the handler registered with
        /// `FakeAcpServer::on_user_message`, or errors if none is registered
        /// (or the server entity can't be read).
        async fn send_user_message(
            &self,
            request: acp_old::SendUserMessageParams,
        ) -> Result<(), acp_old::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                Err(anyhow::anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        /// Connects a `FakeAgent` to the given pipes (reading `stdin`, writing
        /// `stdout`). Protocol futures are spawned on the foreground executor,
        /// and the I/O loop runs in a background task owned by the server.
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection: connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Registers the handler run for every user message; replaces any
        /// previously registered handler.
        fn on_user_message<F>(
            &mut self,
            handler: impl for<'a> Fn(
                acp_old::SendUserMessageParams,
                Entity<FakeAcpServer>,
                AsyncApp,
            ) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp_old::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a client request from the fake agent to Zed, converting the
        /// protocol error into an `anyhow::Error`.
        fn send_to_zed<T: acp_old::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}