1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol::{self as acp};
13use anyhow::{Context as _, Result};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::Formatter;
24use std::process::ExitStatus;
25use std::rc::Rc;
26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
27use ui::App;
28use util::ResultExt;
29
30#[derive(Debug)]
31pub struct UserMessage {
32 pub content: ContentBlock,
33}
34
35impl UserMessage {
36 pub fn from_acp(
37 message: impl IntoIterator<Item = acp::ContentBlock>,
38 language_registry: Arc<LanguageRegistry>,
39 cx: &mut App,
40 ) -> Self {
41 let mut content = ContentBlock::Empty;
42 for chunk in message {
43 content.append(chunk, &language_registry, cx)
44 }
45 Self { content: content }
46 }
47
48 fn to_markdown(&self, cx: &App) -> String {
49 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
50 }
51}
52
53#[derive(Debug, PartialEq)]
54pub struct AssistantMessage {
55 pub chunks: Vec<AssistantMessageChunk>,
56}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
70#[derive(Debug, PartialEq)]
71pub enum AssistantMessageChunk {
72 Message { block: ContentBlock },
73 Thought { block: ContentBlock },
74}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
94pub enum AgentThreadEntry {
95 UserMessage(UserMessage),
96 AssistantMessage(AssistantMessage),
97 ToolCall(ToolCall),
98}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
126 if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
127 Some(locations)
128 } else {
129 None
130 }
131 }
132}
133
134#[derive(Debug)]
135pub struct ToolCall {
136 pub id: acp::ToolCallId,
137 pub label: Entity<Markdown>,
138 pub kind: acp::ToolKind,
139 pub content: Vec<ToolCallContent>,
140 pub status: ToolCallStatus,
141 pub locations: Vec<acp::ToolCallLocation>,
142 pub raw_input: Option<serde_json::Value>,
143 pub raw_output: Option<serde_json::Value>,
144}
145
146impl ToolCall {
147 fn from_acp(
148 tool_call: acp::ToolCall,
149 status: ToolCallStatus,
150 language_registry: Arc<LanguageRegistry>,
151 cx: &mut App,
152 ) -> Self {
153 Self {
154 id: tool_call.id,
155 label: cx.new(|cx| {
156 Markdown::new(
157 tool_call.title.into(),
158 Some(language_registry.clone()),
159 None,
160 cx,
161 )
162 }),
163 kind: tool_call.kind,
164 content: tool_call
165 .content
166 .into_iter()
167 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
168 .collect(),
169 locations: tool_call.locations,
170 status,
171 raw_input: tool_call.raw_input,
172 raw_output: tool_call.raw_output,
173 }
174 }
175
176 fn update_fields(
177 &mut self,
178 fields: acp::ToolCallUpdateFields,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) {
182 let acp::ToolCallUpdateFields {
183 kind,
184 status,
185 title,
186 content,
187 locations,
188 raw_input,
189 raw_output,
190 } = fields;
191
192 if let Some(kind) = kind {
193 self.kind = kind;
194 }
195
196 if let Some(status) = status {
197 self.status = ToolCallStatus::Allowed { status };
198 }
199
200 if let Some(title) = title {
201 self.label.update(cx, |label, cx| {
202 label.replace(title, cx);
203 });
204 }
205
206 if let Some(content) = content {
207 self.content = content
208 .into_iter()
209 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
210 .collect();
211 }
212
213 if let Some(locations) = locations {
214 self.locations = locations;
215 }
216
217 if let Some(raw_input) = raw_input {
218 self.raw_input = Some(raw_input);
219 }
220
221 if let Some(raw_output) = raw_output {
222 if self.content.is_empty() {
223 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
224 {
225 self.content
226 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
227 markdown,
228 }));
229 }
230 }
231 self.raw_output = Some(raw_output);
232 }
233 }
234
235 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
236 self.content.iter().filter_map(|content| match content {
237 ToolCallContent::Diff(diff) => Some(diff),
238 ToolCallContent::ContentBlock(_) => None,
239 ToolCallContent::Terminal(_) => None,
240 })
241 }
242
243 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
244 self.content.iter().filter_map(|content| match content {
245 ToolCallContent::Terminal(terminal) => Some(terminal),
246 ToolCallContent::ContentBlock(_) => None,
247 ToolCallContent::Diff(_) => None,
248 })
249 }
250
251 fn to_markdown(&self, cx: &App) -> String {
252 let mut markdown = format!(
253 "**Tool Call: {}**\nStatus: {}\n\n",
254 self.label.read(cx).source(),
255 self.status
256 );
257 for content in &self.content {
258 markdown.push_str(content.to_markdown(cx).as_str());
259 markdown.push_str("\n\n");
260 }
261 markdown
262 }
263}
264
265#[derive(Debug)]
266pub enum ToolCallStatus {
267 WaitingForConfirmation {
268 options: Vec<acp::PermissionOption>,
269 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
270 },
271 Allowed {
272 status: acp::ToolCallStatus,
273 },
274 Rejected,
275 Canceled,
276}
277
278impl Display for ToolCallStatus {
279 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280 write!(
281 f,
282 "{}",
283 match self {
284 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
285 ToolCallStatus::Allowed { status } => match status {
286 acp::ToolCallStatus::Pending => "Pending",
287 acp::ToolCallStatus::InProgress => "In Progress",
288 acp::ToolCallStatus::Completed => "Completed",
289 acp::ToolCallStatus::Failed => "Failed",
290 },
291 ToolCallStatus::Rejected => "Rejected",
292 ToolCallStatus::Canceled => "Canceled",
293 }
294 )
295 }
296}
297
298#[derive(Debug, PartialEq, Clone)]
299pub enum ContentBlock {
300 Empty,
301 Markdown { markdown: Entity<Markdown> },
302}
303
304impl ContentBlock {
305 pub fn new(
306 block: acp::ContentBlock,
307 language_registry: &Arc<LanguageRegistry>,
308 cx: &mut App,
309 ) -> Self {
310 let mut this = Self::Empty;
311 this.append(block, language_registry, cx);
312 this
313 }
314
315 pub fn new_combined(
316 blocks: impl IntoIterator<Item = acp::ContentBlock>,
317 language_registry: Arc<LanguageRegistry>,
318 cx: &mut App,
319 ) -> Self {
320 let mut this = Self::Empty;
321 for block in blocks {
322 this.append(block, &language_registry, cx);
323 }
324 this
325 }
326
327 pub fn append(
328 &mut self,
329 block: acp::ContentBlock,
330 language_registry: &Arc<LanguageRegistry>,
331 cx: &mut App,
332 ) {
333 let new_content = match block {
334 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
335 acp::ContentBlock::Resource(acp::EmbeddedResource {
336 resource:
337 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
338 uri,
339 ..
340 }),
341 ..
342 }) => {
343 if let Some(uri) = MentionUri::parse(&uri).log_err() {
344 uri.to_link()
345 } else {
346 uri.clone()
347 }
348 }
349 acp::ContentBlock::Image(_)
350 | acp::ContentBlock::Audio(_)
351 | acp::ContentBlock::Resource(acp::EmbeddedResource { .. })
352 | acp::ContentBlock::ResourceLink(_) => String::new(),
353 };
354
355 match self {
356 ContentBlock::Empty => {
357 *self = ContentBlock::Markdown {
358 markdown: cx.new(|cx| {
359 Markdown::new(
360 new_content.into(),
361 Some(language_registry.clone()),
362 None,
363 cx,
364 )
365 }),
366 };
367 }
368 ContentBlock::Markdown { markdown } => {
369 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
370 }
371 }
372 }
373
374 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
375 match self {
376 ContentBlock::Empty => "",
377 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
378 }
379 }
380
381 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
382 match self {
383 ContentBlock::Empty => None,
384 ContentBlock::Markdown { markdown } => Some(markdown),
385 }
386 }
387}
388
389#[derive(Debug)]
390pub enum ToolCallContent {
391 ContentBlock(ContentBlock),
392 Diff(Entity<Diff>),
393 Terminal(Entity<Terminal>),
394}
395
396impl ToolCallContent {
397 pub fn from_acp(
398 content: acp::ToolCallContent,
399 language_registry: Arc<LanguageRegistry>,
400 cx: &mut App,
401 ) -> Self {
402 match content {
403 acp::ToolCallContent::Content { content } => {
404 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
405 }
406 acp::ToolCallContent::Diff { diff } => {
407 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
408 }
409 }
410 }
411
412 pub fn to_markdown(&self, cx: &App) -> String {
413 match self {
414 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
415 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
416 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
417 }
418 }
419}
420
421#[derive(Debug, PartialEq)]
422pub enum ToolCallUpdate {
423 UpdateFields(acp::ToolCallUpdate),
424 UpdateDiff(ToolCallUpdateDiff),
425 UpdateTerminal(ToolCallUpdateTerminal),
426}
427
428impl ToolCallUpdate {
429 fn id(&self) -> &acp::ToolCallId {
430 match self {
431 Self::UpdateFields(update) => &update.id,
432 Self::UpdateDiff(diff) => &diff.id,
433 Self::UpdateTerminal(terminal) => &terminal.id,
434 }
435 }
436}
437
438impl From<acp::ToolCallUpdate> for ToolCallUpdate {
439 fn from(update: acp::ToolCallUpdate) -> Self {
440 Self::UpdateFields(update)
441 }
442}
443
444impl From<ToolCallUpdateDiff> for ToolCallUpdate {
445 fn from(diff: ToolCallUpdateDiff) -> Self {
446 Self::UpdateDiff(diff)
447 }
448}
449
450#[derive(Debug, PartialEq)]
451pub struct ToolCallUpdateDiff {
452 pub id: acp::ToolCallId,
453 pub diff: Entity<Diff>,
454}
455
456impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
457 fn from(terminal: ToolCallUpdateTerminal) -> Self {
458 Self::UpdateTerminal(terminal)
459 }
460}
461
462#[derive(Debug, PartialEq)]
463pub struct ToolCallUpdateTerminal {
464 pub id: acp::ToolCallId,
465 pub terminal: Entity<Terminal>,
466}
467
468#[derive(Debug, Default)]
469pub struct Plan {
470 pub entries: Vec<PlanEntry>,
471}
472
473#[derive(Debug)]
474pub struct PlanStats<'a> {
475 pub in_progress_entry: Option<&'a PlanEntry>,
476 pub pending: u32,
477 pub completed: u32,
478}
479
480impl Plan {
481 pub fn is_empty(&self) -> bool {
482 self.entries.is_empty()
483 }
484
485 pub fn stats(&self) -> PlanStats<'_> {
486 let mut stats = PlanStats {
487 in_progress_entry: None,
488 pending: 0,
489 completed: 0,
490 };
491
492 for entry in &self.entries {
493 match &entry.status {
494 acp::PlanEntryStatus::Pending => {
495 stats.pending += 1;
496 }
497 acp::PlanEntryStatus::InProgress => {
498 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
499 }
500 acp::PlanEntryStatus::Completed => {
501 stats.completed += 1;
502 }
503 }
504 }
505
506 stats
507 }
508}
509
510#[derive(Debug)]
511pub struct PlanEntry {
512 pub content: Entity<Markdown>,
513 pub priority: acp::PlanEntryPriority,
514 pub status: acp::PlanEntryStatus,
515}
516
517impl PlanEntry {
518 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
519 Self {
520 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
521 priority: entry.priority,
522 status: entry.status,
523 }
524 }
525}
526
/// State of a single conversation with an agent over an [`AgentConnection`].
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries in the order they were pushed.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// Snapshots of buffers handed out by `read_text_file`, keyed by buffer.
    /// `write_text_file` diffs new content against these when available.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Task driving the current prompt; `status` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}

/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
    /// A tool call is waiting for the user to authorize it.
    ToolAuthorizationRequired,
    /// A prompt finished without returning an error.
    Stopped,
    /// A prompt returned an error.
    Error,
    /// Emitted by `emit_server_exited` with the given exit status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Coarse activity state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is set.
    Idle,
    /// A send task is set and the current turn has a tool call awaiting confirmation.
    WaitingForToolConfirmation,
    /// A send task is set and nothing is awaiting confirmation.
    Generating,
}
556
557#[derive(Debug, Clone)]
558pub enum LoadError {
559 Unsupported {
560 error_message: SharedString,
561 upgrade_message: SharedString,
562 upgrade_command: String,
563 },
564 Exited(i32),
565 Other(SharedString),
566}
567
568impl Display for LoadError {
569 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
570 match self {
571 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
572 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
573 LoadError::Other(msg) => write!(f, "{}", msg),
574 }
575 }
576}
577
578impl Error for LoadError {}
579
580impl AcpThread {
581 pub fn new(
582 title: impl Into<SharedString>,
583 connection: Rc<dyn AgentConnection>,
584 project: Entity<Project>,
585 session_id: acp::SessionId,
586 cx: &mut Context<Self>,
587 ) -> Self {
588 let action_log = cx.new(|_| ActionLog::new(project.clone()));
589
590 Self {
591 action_log,
592 shared_buffers: Default::default(),
593 entries: Default::default(),
594 plan: Default::default(),
595 title: title.into(),
596 project,
597 send_task: None,
598 connection,
599 session_id,
600 }
601 }
602
    /// Returns the action log entity created for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Returns a clone of the thread title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }

    /// Returns the transcript entries in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }

    /// Returns the ACP session id this thread was created with.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
622
623 pub fn status(&self) -> ThreadStatus {
624 if self.send_task.is_some() {
625 if self.waiting_for_tool_confirmation() {
626 ThreadStatus::WaitingForToolConfirmation
627 } else {
628 ThreadStatus::Generating
629 }
630 } else {
631 ThreadStatus::Idle
632 }
633 }
634
635 pub fn has_pending_edit_tool_calls(&self) -> bool {
636 for entry in self.entries.iter().rev() {
637 match entry {
638 AgentThreadEntry::UserMessage(_) => return false,
639 AgentThreadEntry::ToolCall(
640 call @ ToolCall {
641 status:
642 ToolCallStatus::Allowed {
643 status:
644 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
645 },
646 ..
647 },
648 ) if call.diffs().next().is_some() => {
649 return true;
650 }
651 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
652 }
653 }
654
655 false
656 }
657
658 pub fn used_tools_since_last_user_message(&self) -> bool {
659 for entry in self.entries.iter().rev() {
660 match entry {
661 AgentThreadEntry::UserMessage(..) => return false,
662 AgentThreadEntry::AssistantMessage(..) => continue,
663 AgentThreadEntry::ToolCall(..) => return true,
664 }
665 }
666
667 false
668 }
669
670 pub fn handle_session_update(
671 &mut self,
672 update: acp::SessionUpdate,
673 cx: &mut Context<Self>,
674 ) -> Result<()> {
675 match update {
676 acp::SessionUpdate::UserMessageChunk { content } => {
677 self.push_user_content_block(content, cx);
678 }
679 acp::SessionUpdate::AgentMessageChunk { content } => {
680 self.push_assistant_content_block(content, false, cx);
681 }
682 acp::SessionUpdate::AgentThoughtChunk { content } => {
683 self.push_assistant_content_block(content, true, cx);
684 }
685 acp::SessionUpdate::ToolCall(tool_call) => {
686 self.upsert_tool_call(tool_call, cx);
687 }
688 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
689 self.update_tool_call(tool_call_update, cx)?;
690 }
691 acp::SessionUpdate::Plan(plan) => {
692 self.update_plan(plan, cx);
693 }
694 }
695 Ok(())
696 }
697
698 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
699 let language_registry = self.project.read(cx).languages().clone();
700 let entries_len = self.entries.len();
701
702 if let Some(last_entry) = self.entries.last_mut()
703 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
704 {
705 content.append(chunk, &language_registry, cx);
706 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
707 } else {
708 let content = ContentBlock::new(chunk, &language_registry, cx);
709 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
710 }
711 }
712
713 pub fn push_assistant_content_block(
714 &mut self,
715 chunk: acp::ContentBlock,
716 is_thought: bool,
717 cx: &mut Context<Self>,
718 ) {
719 let language_registry = self.project.read(cx).languages().clone();
720 let entries_len = self.entries.len();
721 if let Some(last_entry) = self.entries.last_mut()
722 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
723 {
724 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
725 match (chunks.last_mut(), is_thought) {
726 (Some(AssistantMessageChunk::Message { block }), false)
727 | (Some(AssistantMessageChunk::Thought { block }), true) => {
728 block.append(chunk, &language_registry, cx)
729 }
730 _ => {
731 let block = ContentBlock::new(chunk, &language_registry, cx);
732 if is_thought {
733 chunks.push(AssistantMessageChunk::Thought { block })
734 } else {
735 chunks.push(AssistantMessageChunk::Message { block })
736 }
737 }
738 }
739 } else {
740 let block = ContentBlock::new(chunk, &language_registry, cx);
741 let chunk = if is_thought {
742 AssistantMessageChunk::Thought { block }
743 } else {
744 AssistantMessageChunk::Message { block }
745 };
746
747 self.push_entry(
748 AgentThreadEntry::AssistantMessage(AssistantMessage {
749 chunks: vec![chunk],
750 }),
751 cx,
752 );
753 }
754 }
755
    /// Appends `entry` to the thread and emits `AcpThreadEvent::NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
760
761 pub fn update_tool_call(
762 &mut self,
763 update: impl Into<ToolCallUpdate>,
764 cx: &mut Context<Self>,
765 ) -> Result<()> {
766 let update = update.into();
767 let languages = self.project.read(cx).languages().clone();
768
769 let (ix, current_call) = self
770 .tool_call_mut(update.id())
771 .context("Tool call not found")?;
772 match update {
773 ToolCallUpdate::UpdateFields(update) => {
774 current_call.update_fields(update.fields, languages, cx);
775 }
776 ToolCallUpdate::UpdateDiff(update) => {
777 current_call.content.clear();
778 current_call
779 .content
780 .push(ToolCallContent::Diff(update.diff));
781 }
782 ToolCallUpdate::UpdateTerminal(update) => {
783 current_call.content.clear();
784 current_call
785 .content
786 .push(ToolCallContent::Terminal(update.terminal));
787 }
788 }
789
790 cx.emit(AcpThreadEvent::EntryUpdated(ix));
791
792 Ok(())
793 }
794
    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
    ///
    /// The call's status is wrapped in `ToolCallStatus::Allowed` using the status
    /// carried by `tool_call`.
    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
        let status = ToolCallStatus::Allowed {
            status: tool_call.status,
        };
        self.upsert_tool_call_inner(tool_call, status, cx)
    }
802
803 pub fn upsert_tool_call_inner(
804 &mut self,
805 tool_call: acp::ToolCall,
806 status: ToolCallStatus,
807 cx: &mut Context<Self>,
808 ) {
809 let language_registry = self.project.read(cx).languages().clone();
810 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
811
812 let location = call.locations.last().cloned();
813
814 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
815 *current_call = call;
816
817 cx.emit(AcpThreadEvent::EntryUpdated(ix));
818 } else {
819 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
820 }
821
822 if let Some(location) = location {
823 self.set_project_location(location, cx)
824 }
825 }
826
827 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
828 // The tool call we are looking for is typically the last one, or very close to the end.
829 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
830 self.entries
831 .iter_mut()
832 .enumerate()
833 .rev()
834 .find_map(|(index, tool_call)| {
835 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
836 && &tool_call.id == id
837 {
838 Some((index, tool_call))
839 } else {
840 None
841 }
842 })
843 }
844
845 pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
846 self.project.update(cx, |project, cx| {
847 let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
848 return;
849 };
850 let buffer = project.open_buffer(path, cx);
851 cx.spawn(async move |project, cx| {
852 let buffer = buffer.await?;
853
854 project.update(cx, |project, cx| {
855 let position = if let Some(line) = location.line {
856 let snapshot = buffer.read(cx).snapshot();
857 let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
858 snapshot.anchor_before(point)
859 } else {
860 Anchor::MIN
861 };
862
863 project.set_agent_location(
864 Some(AgentLocation {
865 buffer: buffer.downgrade(),
866 position,
867 }),
868 cx,
869 );
870 })
871 })
872 .detach_and_log_err(cx);
873 });
874 }
875
876 pub fn request_tool_call_authorization(
877 &mut self,
878 tool_call: acp::ToolCall,
879 options: Vec<acp::PermissionOption>,
880 cx: &mut Context<Self>,
881 ) -> oneshot::Receiver<acp::PermissionOptionId> {
882 let (tx, rx) = oneshot::channel();
883
884 let status = ToolCallStatus::WaitingForConfirmation {
885 options,
886 respond_tx: tx,
887 };
888
889 self.upsert_tool_call_inner(tool_call, status, cx);
890 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
891 rx
892 }
893
894 pub fn authorize_tool_call(
895 &mut self,
896 id: acp::ToolCallId,
897 option_id: acp::PermissionOptionId,
898 option_kind: acp::PermissionOptionKind,
899 cx: &mut Context<Self>,
900 ) {
901 let Some((ix, call)) = self.tool_call_mut(&id) else {
902 return;
903 };
904
905 let new_status = match option_kind {
906 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
907 ToolCallStatus::Rejected
908 }
909 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
910 ToolCallStatus::Allowed {
911 status: acp::ToolCallStatus::InProgress,
912 }
913 }
914 };
915
916 let curr_status = mem::replace(&mut call.status, new_status);
917
918 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
919 respond_tx.send(option_id).log_err();
920 } else if cfg!(debug_assertions) {
921 panic!("tried to authorize an already authorized tool call");
922 }
923
924 cx.emit(AcpThreadEvent::EntryUpdated(ix));
925 }
926
927 /// Returns true if the last turn is awaiting tool authorization
928 pub fn waiting_for_tool_confirmation(&self) -> bool {
929 for entry in self.entries.iter().rev() {
930 match &entry {
931 AgentThreadEntry::ToolCall(call) => match call.status {
932 ToolCallStatus::WaitingForConfirmation { .. } => return true,
933 ToolCallStatus::Allowed { .. }
934 | ToolCallStatus::Rejected
935 | ToolCallStatus::Canceled => continue,
936 },
937 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
938 // Reached the beginning of the turn
939 return false;
940 }
941 }
942 }
943 false
944 }
945
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
949
950 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
951 let new_entries_len = request.entries.len();
952 let mut new_entries = request.entries.into_iter();
953
954 // Reuse existing markdown to prevent flickering
955 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
956 let PlanEntry {
957 content,
958 priority,
959 status,
960 } = old;
961 content.update(cx, |old, cx| {
962 old.replace(new.content, cx);
963 });
964 *priority = new.priority;
965 *status = new.status;
966 }
967 for new in new_entries {
968 self.plan.entries.push(PlanEntry::from_acp(new, cx))
969 }
970 self.plan.entries.truncate(new_entries_len);
971
972 cx.notify();
973 }
974
975 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
976 self.plan
977 .entries
978 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
979 cx.notify();
980 }
981
982 #[cfg(any(test, feature = "test-support"))]
983 pub fn send_raw(
984 &mut self,
985 message: &str,
986 cx: &mut Context<Self>,
987 ) -> BoxFuture<'static, Result<()>> {
988 self.send(
989 vec![acp::ContentBlock::Text(acp::TextContent {
990 text: message.to_string(),
991 annotations: None,
992 })],
993 cx,
994 )
995 }
996
    /// Sends `message` to the agent as a new prompt.
    ///
    /// Pushes the message as a user entry, clears completed plan entries,
    /// cancels any prompt still in flight, and then issues the prompt over the
    /// connection. The returned future resolves once the prompt has finished:
    /// `Err` when the prompt returned an error (after emitting `Error`), `Ok`
    /// otherwise (after emitting `Stopped`).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx`/`rx` carry the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        // Takes the previous send task (if any); it is awaited below before prompting.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                // NOTE(review): `send_task` is not cleared on this path, so `status()`
                // appears to keep reporting a non-idle state after a prompt error —
                // confirm whether that is intended.
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1070
1071 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1072 let Some(send_task) = self.send_task.take() else {
1073 return Task::ready(());
1074 };
1075
1076 for entry in self.entries.iter_mut() {
1077 if let AgentThreadEntry::ToolCall(call) = entry {
1078 let cancel = matches!(
1079 call.status,
1080 ToolCallStatus::WaitingForConfirmation { .. }
1081 | ToolCallStatus::Allowed {
1082 status: acp::ToolCallStatus::InProgress
1083 }
1084 );
1085
1086 if cancel {
1087 call.status = ToolCallStatus::Canceled;
1088 }
1089 }
1090 }
1091
1092 self.connection.cancel(&self.session_id, cx);
1093
1094 // Wait for the send task to complete
1095 cx.foreground_executor().spawn(send_task)
1096 }
1097
1098 pub fn read_text_file(
1099 &self,
1100 path: PathBuf,
1101 line: Option<u32>,
1102 limit: Option<u32>,
1103 reuse_shared_snapshot: bool,
1104 cx: &mut Context<Self>,
1105 ) -> Task<Result<String>> {
1106 let project = self.project.clone();
1107 let action_log = self.action_log.clone();
1108 cx.spawn(async move |this, cx| {
1109 let load = project.update(cx, |project, cx| {
1110 let path = project
1111 .project_path_for_absolute_path(&path, cx)
1112 .context("invalid path")?;
1113 anyhow::Ok(project.open_buffer(path, cx))
1114 });
1115 let buffer = load??.await?;
1116
1117 let snapshot = if reuse_shared_snapshot {
1118 this.read_with(cx, |this, _| {
1119 this.shared_buffers.get(&buffer.clone()).cloned()
1120 })
1121 .log_err()
1122 .flatten()
1123 } else {
1124 None
1125 };
1126
1127 let snapshot = if let Some(snapshot) = snapshot {
1128 snapshot
1129 } else {
1130 action_log.update(cx, |action_log, cx| {
1131 action_log.buffer_read(buffer.clone(), cx);
1132 })?;
1133 project.update(cx, |project, cx| {
1134 let position = buffer
1135 .read(cx)
1136 .snapshot()
1137 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1138 project.set_agent_location(
1139 Some(AgentLocation {
1140 buffer: buffer.downgrade(),
1141 position,
1142 }),
1143 cx,
1144 );
1145 })?;
1146
1147 buffer.update(cx, |buffer, _| buffer.snapshot())?
1148 };
1149
1150 this.update(cx, |this, _| {
1151 let text = snapshot.text();
1152 this.shared_buffers.insert(buffer.clone(), snapshot);
1153 if line.is_none() && limit.is_none() {
1154 return Ok(text);
1155 }
1156 let limit = limit.unwrap_or(u32::MAX) as usize;
1157 let Some(line) = line else {
1158 return Ok(text.lines().take(limit).collect::<String>());
1159 };
1160
1161 let count = text.lines().count();
1162 if count < line as usize {
1163 anyhow::bail!("There are only {} lines", count);
1164 }
1165 Ok(text
1166 .lines()
1167 .skip(line as usize + 1)
1168 .take(limit)
1169 .collect::<String>())
1170 })?
1171 })
1172 }
1173
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent and saves the buffer.
    ///
    /// The new content is diffed against the snapshot stored for this buffer in
    /// `shared_buffers` when there is one, and against the buffer's current
    /// snapshot otherwise. The resulting edits are applied to the buffer, the
    /// read and the edit are recorded in the action log, and the agent location
    /// is moved to the end of the last edit (or `Anchor::MIN` when there are no
    /// edits).
    ///
    /// # Errors
    ///
    /// Fails when `path` is not inside the project, when the buffer cannot be
    /// opened, or when saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the stored shared snapshot as the diff base.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // The diff is computed on the background executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Order matters: record the read, apply the edits, then record the edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1241
1242 pub fn to_markdown(&self, cx: &App) -> String {
1243 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1244 }
1245
1246 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1247 cx.emit(AcpThreadEvent::ServerExited(status));
1248 }
1249}
1250
1251fn markdown_for_raw_output(
1252 raw_output: &serde_json::Value,
1253 language_registry: &Arc<LanguageRegistry>,
1254 cx: &mut App,
1255) -> Option<Entity<Markdown>> {
1256 match raw_output {
1257 serde_json::Value::Null => None,
1258 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1259 Markdown::new(
1260 value.to_string().into(),
1261 Some(language_registry.clone()),
1262 None,
1263 cx,
1264 )
1265 })),
1266 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1267 Markdown::new(
1268 value.to_string().into(),
1269 Some(language_registry.clone()),
1270 None,
1271 cx,
1272 )
1273 })),
1274 serde_json::Value::String(value) => Some(cx.new(|cx| {
1275 Markdown::new(
1276 value.clone().into(),
1277 Some(language_registry.clone()),
1278 None,
1279 cx,
1280 )
1281 })),
1282 value => Some(cx.new(|cx| {
1283 Markdown::new(
1284 format!("```json\n{}\n```", value).into(),
1285 Some(language_registry.clone()),
1286 None,
1287 cx,
1288 )
1289 })),
1290 }
1291}
1292
1293#[cfg(test)]
1294mod tests {
1295 use super::*;
1296 use anyhow::anyhow;
1297 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1298 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1299 use indoc::indoc;
1300 use project::FakeFs;
1301 use rand::Rng as _;
1302 use serde_json::json;
1303 use settings::SettingsStore;
1304 use smol::stream::StreamExt as _;
1305 use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1306
1307 use util::path;
1308
1309 fn init_test(cx: &mut TestAppContext) {
1310 env_logger::try_init().ok();
1311 cx.update(|cx| {
1312 let settings_store = SettingsStore::test(cx);
1313 cx.set_global(settings_store);
1314 Project::init_settings(cx);
1315 language::init(cx);
1316 });
1317 }
1318
1319 #[gpui::test]
1320 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1321 init_test(cx);
1322
1323 let fs = FakeFs::new(cx.executor());
1324 let project = Project::test(fs, [], cx).await;
1325 let connection = Rc::new(FakeAgentConnection::new());
1326 let thread = cx
1327 .spawn(async move |mut cx| {
1328 connection
1329 .new_thread(project, Path::new(path!("/test")), &mut cx)
1330 .await
1331 })
1332 .await
1333 .unwrap();
1334
1335 // Test creating a new user message
1336 thread.update(cx, |thread, cx| {
1337 thread.push_user_content_block(
1338 acp::ContentBlock::Text(acp::TextContent {
1339 annotations: None,
1340 text: "Hello, ".to_string(),
1341 }),
1342 cx,
1343 );
1344 });
1345
1346 thread.update(cx, |thread, cx| {
1347 assert_eq!(thread.entries.len(), 1);
1348 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1349 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1350 } else {
1351 panic!("Expected UserMessage");
1352 }
1353 });
1354
1355 // Test appending to existing user message
1356 thread.update(cx, |thread, cx| {
1357 thread.push_user_content_block(
1358 acp::ContentBlock::Text(acp::TextContent {
1359 annotations: None,
1360 text: "world!".to_string(),
1361 }),
1362 cx,
1363 );
1364 });
1365
1366 thread.update(cx, |thread, cx| {
1367 assert_eq!(thread.entries.len(), 1);
1368 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1369 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1370 } else {
1371 panic!("Expected UserMessage");
1372 }
1373 });
1374
1375 // Test creating new user message after assistant message
1376 thread.update(cx, |thread, cx| {
1377 thread.push_assistant_content_block(
1378 acp::ContentBlock::Text(acp::TextContent {
1379 annotations: None,
1380 text: "Assistant response".to_string(),
1381 }),
1382 false,
1383 cx,
1384 );
1385 });
1386
1387 thread.update(cx, |thread, cx| {
1388 thread.push_user_content_block(
1389 acp::ContentBlock::Text(acp::TextContent {
1390 annotations: None,
1391 text: "New user message".to_string(),
1392 }),
1393 cx,
1394 );
1395 });
1396
1397 thread.update(cx, |thread, cx| {
1398 assert_eq!(thread.entries.len(), 3);
1399 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1400 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1401 } else {
1402 panic!("Expected UserMessage at index 2");
1403 }
1404 });
1405 }
1406
1407 #[gpui::test]
1408 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1409 init_test(cx);
1410
1411 let fs = FakeFs::new(cx.executor());
1412 let project = Project::test(fs, [], cx).await;
1413 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1414 |_, thread, mut cx| {
1415 async move {
1416 thread.update(&mut cx, |thread, cx| {
1417 thread
1418 .handle_session_update(
1419 acp::SessionUpdate::AgentThoughtChunk {
1420 content: "Thinking ".into(),
1421 },
1422 cx,
1423 )
1424 .unwrap();
1425 thread
1426 .handle_session_update(
1427 acp::SessionUpdate::AgentThoughtChunk {
1428 content: "hard!".into(),
1429 },
1430 cx,
1431 )
1432 .unwrap();
1433 })?;
1434 Ok(acp::PromptResponse {
1435 stop_reason: acp::StopReason::EndTurn,
1436 })
1437 }
1438 .boxed_local()
1439 },
1440 ));
1441
1442 let thread = cx
1443 .spawn(async move |mut cx| {
1444 connection
1445 .new_thread(project, Path::new(path!("/test")), &mut cx)
1446 .await
1447 })
1448 .await
1449 .unwrap();
1450
1451 thread
1452 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1453 .await
1454 .unwrap();
1455
1456 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1457 assert_eq!(
1458 output,
1459 indoc! {r#"
1460 ## User
1461
1462 Hello from Zed!
1463
1464 ## Assistant
1465
1466 <thinking>
1467 Thinking hard!
1468 </thinking>
1469
1470 "#}
1471 );
1472 }
1473
/// Checks that a file write by the agent is combined with a user edit made after the agent
/// read the file.
///
/// The fake agent reads `/tmp/foo`, signals the test, and then writes the file with two
/// extra lines. Once signalled, the test inserts `zero\n` at the start of the buffer. Both
/// the buffer and the file on disk are then expected to contain the user's line followed by
/// the agent's content.
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
        .await;
    let project = Project::test(fs.clone(), [], cx).await;
    // Signals the test body once the fake agent has read the file.
    let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
    // The sender is shared behind `Rc<RefCell<Option<_>>>` so the handler can take it out
    // when it fires.
    let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
    let connection = Rc::new(FakeAgentConnection::new().on_user_message(
        move |_, thread, mut cx| {
            let read_file_tx = read_file_tx.clone();
            async move {
                // Step 1: the agent reads the original file contents.
                let content = thread
                    .update(&mut cx, |thread, cx| {
                        thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                    })
                    .unwrap()
                    .await
                    .unwrap();
                assert_eq!(content, "one\ntwo\nthree\n");
                // Step 2: let the test know the read has completed.
                read_file_tx.take().unwrap().send(()).unwrap();
                // Step 3: the agent writes the file with two additional lines.
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.write_text_file(
                            path!("/tmp/foo").into(),
                            "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            cx,
                        )
                    })
                    .unwrap()
                    .await
                    .unwrap();
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        },
    ));

    // Open the same file in a buffer so the test can edit it as the user.
    let (worktree, pathbuf) = project
        .update(cx, |project, cx| {
            project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
        })
        .await
        .unwrap();
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
        })
        .await
        .unwrap();

    let thread = cx
        .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Extend the count in /tmp/foo", cx)
    });
    // Wait for the agent's read, then make the user edit before draining pending work.
    read_file_rx.await.ok();
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "zero\n".to_string())], None, cx);
    });
    cx.run_until_parked();
    // Both the user's insertion and the agent's appended lines must be present.
    assert_eq!(
        buffer.read_with(cx, |buffer, _| buffer.text()),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    assert_eq!(
        String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
        "zero\none\ntwo\nthree\nfour\nfive\n"
    );
    request.await.unwrap();
}
1552
/// Checks that a tool call which was canceled can still be marked as completed by a later
/// `ToolCallUpdate`.
///
/// The fake agent reports an in-progress tool call. The test cancels the thread, asserts
/// the entry is `Canceled`, sends a `Completed` status update for the same id, and asserts
/// the entry ends up as `Allowed { status: Completed, .. }`.
#[gpui::test]
async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;
    let id = acp::ToolCallId("test".into());

    let connection = Rc::new(FakeAgentConnection::new().on_user_message({
        let id = id.clone();
        move |_, thread, mut cx| {
            let id = id.clone();
            async move {
                // Report a single in-progress tool call and end the turn.
                thread
                    .update(&mut cx, |thread, cx| {
                        thread.handle_session_update(
                            acp::SessionUpdate::ToolCall(acp::ToolCall {
                                id: id.clone(),
                                title: "Label".into(),
                                kind: acp::ToolKind::Fetch,
                                status: acp::ToolCallStatus::InProgress,
                                content: vec![],
                                locations: vec![],
                                raw_input: None,
                                raw_output: None,
                            }),
                            cx,
                        )
                    })
                    .unwrap()
                    .unwrap();
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            }
            .boxed_local()
        }
    }));

    let thread = cx
        .spawn(async move |mut cx| {
            connection
                .new_thread(project, Path::new(path!("/test")), &mut cx)
                .await
        })
        .await
        .unwrap();

    let request = thread.update(cx, |thread, cx| {
        thread.send_raw("Fetch https://example.com", cx)
    });

    run_until_first_tool_call(&thread, cx).await;

    // Entry 0 is the user message; entry 1 is the tool call, still in progress.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed {
                    status: acp::ToolCallStatus::InProgress,
                    ..
                },
                ..
            })
        ));
    });

    thread.update(cx, |thread, cx| thread.cancel(cx)).await;

    // Canceling the thread marks the tool call as canceled.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            &thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            })
        ));
    });

    // A status update for the same tool call id arrives after the cancellation.
    thread
        .update(cx, |thread, cx| {
            thread.handle_session_update(
                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                    id,
                    fields: acp::ToolCallUpdateFields {
                        status: Some(acp::ToolCallStatus::Completed),
                        ..Default::default()
                    },
                }),
                cx,
            )
        })
        .unwrap();

    request.await.unwrap();

    // The late update is expected to win: the tool call is now completed.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(
            thread.entries[1],
            AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed {
                    status: acp::ToolCallStatus::Completed,
                    ..
                },
                ..
            })
        ));
    });
}
1662
1663 #[gpui::test]
1664 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1665 init_test(cx);
1666 let fs = FakeFs::new(cx.background_executor.clone());
1667 fs.insert_tree(path!("/test"), json!({})).await;
1668 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1669
1670 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1671 move |_, thread, mut cx| {
1672 async move {
1673 thread
1674 .update(&mut cx, |thread, cx| {
1675 thread.handle_session_update(
1676 acp::SessionUpdate::ToolCall(acp::ToolCall {
1677 id: acp::ToolCallId("test".into()),
1678 title: "Label".into(),
1679 kind: acp::ToolKind::Edit,
1680 status: acp::ToolCallStatus::Completed,
1681 content: vec![acp::ToolCallContent::Diff {
1682 diff: acp::Diff {
1683 path: "/test/test.txt".into(),
1684 old_text: None,
1685 new_text: "foo".into(),
1686 },
1687 }],
1688 locations: vec![],
1689 raw_input: None,
1690 raw_output: None,
1691 }),
1692 cx,
1693 )
1694 })
1695 .unwrap()
1696 .unwrap();
1697 Ok(acp::PromptResponse {
1698 stop_reason: acp::StopReason::EndTurn,
1699 })
1700 }
1701 .boxed_local()
1702 }
1703 }));
1704
1705 let thread = connection
1706 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1707 .await
1708 .unwrap();
1709 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1710 .await
1711 .unwrap();
1712
1713 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1714 }
1715
1716 async fn run_until_first_tool_call(
1717 thread: &Entity<AcpThread>,
1718 cx: &mut TestAppContext,
1719 ) -> usize {
1720 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1721
1722 let subscription = cx.update(|cx| {
1723 cx.subscribe(thread, move |thread, _, cx| {
1724 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1725 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1726 return tx.try_send(ix).unwrap();
1727 }
1728 }
1729 })
1730 });
1731
1732 select! {
1733 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1734 panic!("Timeout waiting for tool call")
1735 }
1736 ix = rx.next().fuse() => {
1737 drop(subscription);
1738 ix.unwrap()
1739 }
1740 }
1741 }
1742
1743 #[derive(Clone, Default)]
1744 struct FakeAgentConnection {
1745 auth_methods: Vec<acp::AuthMethod>,
1746 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1747 on_user_message: Option<
1748 Rc<
1749 dyn Fn(
1750 acp::PromptRequest,
1751 WeakEntity<AcpThread>,
1752 AsyncApp,
1753 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1754 + 'static,
1755 >,
1756 >,
1757 }
1758
1759 impl FakeAgentConnection {
1760 fn new() -> Self {
1761 Self {
1762 auth_methods: Vec::new(),
1763 on_user_message: None,
1764 sessions: Arc::default(),
1765 }
1766 }
1767
1768 #[expect(unused)]
1769 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1770 self.auth_methods = auth_methods;
1771 self
1772 }
1773
1774 fn on_user_message(
1775 mut self,
1776 handler: impl Fn(
1777 acp::PromptRequest,
1778 WeakEntity<AcpThread>,
1779 AsyncApp,
1780 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1781 + 'static,
1782 ) -> Self {
1783 self.on_user_message.replace(Rc::new(handler));
1784 self
1785 }
1786 }
1787
1788 impl AgentConnection for FakeAgentConnection {
1789 fn auth_methods(&self) -> &[acp::AuthMethod] {
1790 &self.auth_methods
1791 }
1792
1793 fn new_thread(
1794 self: Rc<Self>,
1795 project: Entity<Project>,
1796 _cwd: &Path,
1797 cx: &mut gpui::AsyncApp,
1798 ) -> Task<gpui::Result<Entity<AcpThread>>> {
1799 let session_id = acp::SessionId(
1800 rand::thread_rng()
1801 .sample_iter(&rand::distributions::Alphanumeric)
1802 .take(7)
1803 .map(char::from)
1804 .collect::<String>()
1805 .into(),
1806 );
1807 let thread = cx
1808 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1809 .unwrap();
1810 self.sessions.lock().insert(session_id, thread.downgrade());
1811 Task::ready(Ok(thread))
1812 }
1813
1814 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1815 if self.auth_methods().iter().any(|m| m.id == method) {
1816 Task::ready(Ok(()))
1817 } else {
1818 Task::ready(Err(anyhow!("Invalid Auth Method")))
1819 }
1820 }
1821
1822 fn prompt(
1823 &self,
1824 params: acp::PromptRequest,
1825 cx: &mut App,
1826 ) -> Task<gpui::Result<acp::PromptResponse>> {
1827 let sessions = self.sessions.lock();
1828 let thread = sessions.get(¶ms.session_id).unwrap();
1829 if let Some(handler) = &self.on_user_message {
1830 let handler = handler.clone();
1831 let thread = thread.clone();
1832 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1833 } else {
1834 Task::ready(Ok(acp::PromptResponse {
1835 stop_reason: acp::StopReason::EndTurn,
1836 }))
1837 }
1838 }
1839
1840 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1841 let sessions = self.sessions.lock();
1842 let thread = sessions.get(&session_id).unwrap().clone();
1843
1844 cx.spawn(async move |cx| {
1845 thread
1846 .update(cx, |thread, cx| thread.cancel(cx))
1847 .unwrap()
1848 .await
1849 })
1850 .detach();
1851 }
1852 }
1853}