1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol::{self as acp};
13use anyhow::{Context as _, Result};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::Formatter;
24use std::process::ExitStatus;
25use std::rc::Rc;
26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
27use ui::App;
28use util::ResultExt;
29
/// A message authored by the user, built from one or more ACP content chunks
/// merged into a single [`ContentBlock`].
#[derive(Debug)]
pub struct UserMessage {
    /// The merged content of the message.
    pub content: ContentBlock,
}
34
35impl UserMessage {
36 pub fn from_acp(
37 message: impl IntoIterator<Item = acp::ContentBlock>,
38 language_registry: Arc<LanguageRegistry>,
39 cx: &mut App,
40 ) -> Self {
41 let mut content = ContentBlock::Empty;
42 for chunk in message {
43 content.append(chunk, &language_registry, cx)
44 }
45 Self { content: content }
46 }
47
48 fn to_markdown(&self, cx: &App) -> String {
49 format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
50 }
51}
52
/// A reply from the agent, made of ordered message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in the order they were received; consecutive chunks of the same
    /// kind are merged by `AcpThread::push_assistant_content_block`.
    pub chunks: Vec<AssistantMessageChunk>,
}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
/// One contiguous piece of an assistant reply.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular reply text.
    Message { block: ContentBlock },
    /// The agent's reasoning ("thought") text; rendered inside `<thinking>` tags.
    Thought { block: ContentBlock },
}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
/// A single entry in an agent thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
126 if let AgentThreadEntry::ToolCall(ToolCall {
127 locations,
128 resolved_locations,
129 ..
130 }) = self
131 {
132 Some((
133 locations.get(ix)?.clone(),
134 resolved_locations.get(ix)?.clone()?,
135 ))
136 } else {
137 None
138 }
139 }
140}
141
/// A tool invocation reported by the agent, with its lifecycle status and the
/// content rendered for it.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// The tool call's title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered output: content blocks, diffs and/or terminals.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// `locations` resolved to buffer anchors. Starts empty and is filled in
    /// asynchronously, index-aligned with `locations`; `None` marks a location
    /// that could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
154
155impl ToolCall {
156 fn from_acp(
157 tool_call: acp::ToolCall,
158 status: ToolCallStatus,
159 language_registry: Arc<LanguageRegistry>,
160 cx: &mut App,
161 ) -> Self {
162 Self {
163 id: tool_call.id,
164 label: cx.new(|cx| {
165 Markdown::new(
166 tool_call.title.into(),
167 Some(language_registry.clone()),
168 None,
169 cx,
170 )
171 }),
172 kind: tool_call.kind,
173 content: tool_call
174 .content
175 .into_iter()
176 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
177 .collect(),
178 locations: tool_call.locations,
179 resolved_locations: Vec::default(),
180 status,
181 raw_input: tool_call.raw_input,
182 raw_output: tool_call.raw_output,
183 }
184 }
185
186 fn update_fields(
187 &mut self,
188 fields: acp::ToolCallUpdateFields,
189 language_registry: Arc<LanguageRegistry>,
190 cx: &mut App,
191 ) {
192 let acp::ToolCallUpdateFields {
193 kind,
194 status,
195 title,
196 content,
197 locations,
198 raw_input,
199 raw_output,
200 } = fields;
201
202 if let Some(kind) = kind {
203 self.kind = kind;
204 }
205
206 if let Some(status) = status {
207 self.status = ToolCallStatus::Allowed { status };
208 }
209
210 if let Some(title) = title {
211 self.label.update(cx, |label, cx| {
212 label.replace(title, cx);
213 });
214 }
215
216 if let Some(content) = content {
217 self.content = content
218 .into_iter()
219 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
220 .collect();
221 }
222
223 if let Some(locations) = locations {
224 self.locations = locations;
225 }
226
227 if let Some(raw_input) = raw_input {
228 self.raw_input = Some(raw_input);
229 }
230
231 if let Some(raw_output) = raw_output {
232 if self.content.is_empty() {
233 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
234 {
235 self.content
236 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
237 markdown,
238 }));
239 }
240 }
241 self.raw_output = Some(raw_output);
242 }
243 }
244
245 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
246 self.content.iter().filter_map(|content| match content {
247 ToolCallContent::Diff(diff) => Some(diff),
248 ToolCallContent::ContentBlock(_) => None,
249 ToolCallContent::Terminal(_) => None,
250 })
251 }
252
253 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
254 self.content.iter().filter_map(|content| match content {
255 ToolCallContent::Terminal(terminal) => Some(terminal),
256 ToolCallContent::ContentBlock(_) => None,
257 ToolCallContent::Diff(_) => None,
258 })
259 }
260
261 fn to_markdown(&self, cx: &App) -> String {
262 let mut markdown = format!(
263 "**Tool Call: {}**\nStatus: {}\n\n",
264 self.label.read(cx).source(),
265 self.status
266 );
267 for content in &self.content {
268 markdown.push_str(content.to_markdown(cx).as_str());
269 markdown.push_str("\n\n");
270 }
271 markdown
272 }
273
274 async fn resolve_location(
275 location: acp::ToolCallLocation,
276 project: WeakEntity<Project>,
277 cx: &mut AsyncApp,
278 ) -> Option<AgentLocation> {
279 let buffer = project
280 .update(cx, |project, cx| {
281 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
282 Some(project.open_buffer(path, cx))
283 } else {
284 None
285 }
286 })
287 .ok()??;
288 let buffer = buffer.await.log_err()?;
289 let position = buffer
290 .update(cx, |buffer, _| {
291 if let Some(row) = location.line {
292 let snapshot = buffer.snapshot();
293 let column = snapshot.indent_size_for_line(row).len;
294 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
295 snapshot.anchor_before(point)
296 } else {
297 Anchor::MIN
298 }
299 })
300 .ok()?;
301
302 Some(AgentLocation {
303 buffer: buffer.downgrade(),
304 position,
305 })
306 }
307
308 fn resolve_locations(
309 &self,
310 project: Entity<Project>,
311 cx: &mut App,
312 ) -> Task<Vec<Option<AgentLocation>>> {
313 let locations = self.locations.clone();
314 project.update(cx, |_, cx| {
315 cx.spawn(async move |project, cx| {
316 let mut new_locations = Vec::new();
317 for location in locations {
318 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
319 }
320 new_locations
321 })
322 })
323 }
324}
325
/// Where a tool call is in its authorization and execution lifecycle.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// Waiting for the user to choose one of `options`; the chosen option's id
    /// is sent through `respond_tx` by `AcpThread::authorize_tool_call`.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The call may run; `status` is the progress reported by the agent.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user chose a reject option.
    Rejected,
    /// The turn was cancelled before the call finished.
    Canceled,
}
338
339impl Display for ToolCallStatus {
340 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
341 write!(
342 f,
343 "{}",
344 match self {
345 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
346 ToolCallStatus::Allowed { status } => match status {
347 acp::ToolCallStatus::Pending => "Pending",
348 acp::ToolCallStatus::InProgress => "In Progress",
349 acp::ToolCallStatus::Completed => "Completed",
350 acp::ToolCallStatus::Failed => "Failed",
351 },
352 ToolCallStatus::Rejected => "Rejected",
353 ToolCallStatus::Canceled => "Canceled",
354 }
355 )
356 }
357}
358
/// Displayable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link (only produced when a link is the first block appended).
    ResourceLink { resource_link: acp::ResourceLink },
}
365
366impl ContentBlock {
367 pub fn new(
368 block: acp::ContentBlock,
369 language_registry: &Arc<LanguageRegistry>,
370 cx: &mut App,
371 ) -> Self {
372 let mut this = Self::Empty;
373 this.append(block, language_registry, cx);
374 this
375 }
376
377 pub fn new_combined(
378 blocks: impl IntoIterator<Item = acp::ContentBlock>,
379 language_registry: Arc<LanguageRegistry>,
380 cx: &mut App,
381 ) -> Self {
382 let mut this = Self::Empty;
383 for block in blocks {
384 this.append(block, &language_registry, cx);
385 }
386 this
387 }
388
389 pub fn append(
390 &mut self,
391 block: acp::ContentBlock,
392 language_registry: &Arc<LanguageRegistry>,
393 cx: &mut App,
394 ) {
395 if matches!(self, ContentBlock::Empty) {
396 if let acp::ContentBlock::ResourceLink(resource_link) = block {
397 *self = ContentBlock::ResourceLink { resource_link };
398 return;
399 }
400 }
401
402 let new_content = self.extract_content_from_block(block);
403
404 match self {
405 ContentBlock::Empty => {
406 *self = Self::create_markdown_block(new_content, language_registry, cx);
407 }
408 ContentBlock::Markdown { markdown } => {
409 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
410 }
411 ContentBlock::ResourceLink { resource_link } => {
412 let existing_content = Self::resource_link_to_content(&resource_link.uri);
413 let combined = format!("{}\n{}", existing_content, new_content);
414
415 *self = Self::create_markdown_block(combined, language_registry, cx);
416 }
417 }
418 }
419
420 fn resource_link_to_content(uri: &str) -> String {
421 if let Some(uri) = MentionUri::parse(&uri).log_err() {
422 uri.to_link()
423 } else {
424 uri.to_string().clone()
425 }
426 }
427
428 fn create_markdown_block(
429 content: String,
430 language_registry: &Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) -> ContentBlock {
433 ContentBlock::Markdown {
434 markdown: cx
435 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
436 }
437 }
438
439 fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
440 match block {
441 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
442 acp::ContentBlock::ResourceLink(resource_link) => {
443 Self::resource_link_to_content(&resource_link.uri)
444 }
445 acp::ContentBlock::Resource(acp::EmbeddedResource {
446 resource:
447 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
448 uri,
449 ..
450 }),
451 ..
452 }) => Self::resource_link_to_content(&uri),
453 acp::ContentBlock::Image(_)
454 | acp::ContentBlock::Audio(_)
455 | acp::ContentBlock::Resource(_) => String::new(),
456 }
457 }
458
459 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
460 match self {
461 ContentBlock::Empty => "",
462 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
463 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
464 }
465 }
466
467 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
468 match self {
469 ContentBlock::Empty => None,
470 ContentBlock::Markdown { markdown } => Some(markdown),
471 ContentBlock::ResourceLink { .. } => None,
472 }
473 }
474
475 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
476 match self {
477 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
478 _ => None,
479 }
480 }
481}
482
/// One item of a tool call's displayed output.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text or link content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal.
    Terminal(Entity<Terminal>),
}
489
490impl ToolCallContent {
491 pub fn from_acp(
492 content: acp::ToolCallContent,
493 language_registry: Arc<LanguageRegistry>,
494 cx: &mut App,
495 ) -> Self {
496 match content {
497 acp::ToolCallContent::Content { content } => {
498 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
499 }
500 acp::ToolCallContent::Diff { diff } => {
501 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
502 }
503 }
504 }
505
506 pub fn to_markdown(&self, cx: &App) -> String {
507 match self {
508 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
509 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
510 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
511 }
512 }
513}
514
/// An update to an existing tool call, applied by `AcpThread::update_tool_call`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field updates sent by the agent over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the tool call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the tool call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
521
522impl ToolCallUpdate {
523 fn id(&self) -> &acp::ToolCallId {
524 match self {
525 Self::UpdateFields(update) => &update.id,
526 Self::UpdateDiff(diff) => &diff.id,
527 Self::UpdateTerminal(terminal) => &terminal.id,
528 }
529 }
530}
531
532impl From<acp::ToolCallUpdate> for ToolCallUpdate {
533 fn from(update: acp::ToolCallUpdate) -> Self {
534 Self::UpdateFields(update)
535 }
536}
537
538impl From<ToolCallUpdateDiff> for ToolCallUpdate {
539 fn from(diff: ToolCallUpdateDiff) -> Self {
540 Self::UpdateDiff(diff)
541 }
542}
543
/// Replaces the content of tool call `id` with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
549
550impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
551 fn from(terminal: ToolCallUpdateTerminal) -> Self {
552 Self::UpdateTerminal(terminal)
553 }
554}
555
/// Replaces the content of tool call `id` with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
561
/// The agent's plan: an ordered list of entries with priorities and statuses.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
566
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
573
574impl Plan {
575 pub fn is_empty(&self) -> bool {
576 self.entries.is_empty()
577 }
578
579 pub fn stats(&self) -> PlanStats<'_> {
580 let mut stats = PlanStats {
581 in_progress_entry: None,
582 pending: 0,
583 completed: 0,
584 };
585
586 for entry in &self.entries {
587 match &entry.status {
588 acp::PlanEntryStatus::Pending => {
589 stats.pending += 1;
590 }
591 acp::PlanEntryStatus::InProgress => {
592 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
593 }
594 acp::PlanEntryStatus::Completed => {
595 stats.completed += 1;
596 }
597 }
598 }
599
600 stats
601 }
602}
603
/// One step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// The step's text, rendered as markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
610
611impl PlanEntry {
612 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
613 Self {
614 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
615 priority: entry.priority,
616 status: entry.status,
617 }
618 }
619}
620
/// A conversation with an agent over an [`AgentConnection`] session: the
/// transcript entries, the agent's plan, and the in-flight prompt, if any.
pub struct AcpThread {
    title: SharedString,
    /// User messages, assistant messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// The latest plan reported by the agent.
    plan: Plan,
    project: Entity<Project>,
    /// Notified of buffers the agent reads (`read_text_file` calls `buffer_read`).
    action_log: Entity<ActionLog>,
    /// Buffer snapshots recorded by `read_text_file`, keyed by buffer, and
    /// reused when a caller asks for the shared snapshot.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The task driving the current prompt; `Some` means the thread is not idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
632
/// Events emitted by an [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the thread.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A prompt finished, including when it was cancelled.
    Stopped,
    /// A prompt failed with an error.
    Error,
    /// Not emitted in this part of the file; presumably sent when the agent
    /// server process exits — confirm at the emit site.
    ServerExited(ExitStatus),
}
641
// Lets `AcpThread` emit `AcpThreadEvent`s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
643
/// Coarse state of an [`AcpThread`], as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and the latest turn has a tool call awaiting confirmation.
    WaitingForToolConfirmation,
    /// A prompt is in flight.
    Generating,
}
650
/// Errors reported while loading an agent; `Display` renders a readable message.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent cannot be used as is. Carries the error text plus an upgrade
    /// message and command (presumably shown to the user — confirm at call sites).
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The agent server exited with this status code.
    Exited(i32),
    /// Any other failure, described by a message.
    Other(SharedString),
}
661
662impl Display for LoadError {
663 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
664 match self {
665 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
666 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
667 LoadError::Other(msg) => write!(f, "{}", msg),
668 }
669 }
670}
671
// Uses the default `Error` methods: no underlying source error is exposed.
impl Error for LoadError {}
673
674impl AcpThread {
675 pub fn new(
676 title: impl Into<SharedString>,
677 connection: Rc<dyn AgentConnection>,
678 project: Entity<Project>,
679 session_id: acp::SessionId,
680 cx: &mut Context<Self>,
681 ) -> Self {
682 let action_log = cx.new(|_| ActionLog::new(project.clone()));
683
684 Self {
685 action_log,
686 shared_buffers: Default::default(),
687 entries: Default::default(),
688 plan: Default::default(),
689 title: title.into(),
690 project,
691 send_task: None,
692 connection,
693 session_id,
694 }
695 }
696
/// The connection this thread sends prompts and cancellations through.
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
    &self.connection
}

/// The action log that `read_text_file` reports buffer reads to.
pub fn action_log(&self) -> &Entity<ActionLog> {
    &self.action_log
}

/// The project this thread operates on.
pub fn project(&self) -> &Entity<Project> {
    &self.project
}

/// The thread's title.
pub fn title(&self) -> SharedString {
    self.title.clone()
}

/// All transcript entries, oldest first.
pub fn entries(&self) -> &[AgentThreadEntry] {
    &self.entries
}

/// The ACP session this thread belongs to.
pub fn session_id(&self) -> &acp::SessionId {
    &self.session_id
}
720
721 pub fn status(&self) -> ThreadStatus {
722 if self.send_task.is_some() {
723 if self.waiting_for_tool_confirmation() {
724 ThreadStatus::WaitingForToolConfirmation
725 } else {
726 ThreadStatus::Generating
727 }
728 } else {
729 ThreadStatus::Idle
730 }
731 }
732
733 pub fn has_pending_edit_tool_calls(&self) -> bool {
734 for entry in self.entries.iter().rev() {
735 match entry {
736 AgentThreadEntry::UserMessage(_) => return false,
737 AgentThreadEntry::ToolCall(
738 call @ ToolCall {
739 status:
740 ToolCallStatus::Allowed {
741 status:
742 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
743 },
744 ..
745 },
746 ) if call.diffs().next().is_some() => {
747 return true;
748 }
749 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
750 }
751 }
752
753 false
754 }
755
756 pub fn used_tools_since_last_user_message(&self) -> bool {
757 for entry in self.entries.iter().rev() {
758 match entry {
759 AgentThreadEntry::UserMessage(..) => return false,
760 AgentThreadEntry::AssistantMessage(..) => continue,
761 AgentThreadEntry::ToolCall(..) => return true,
762 }
763 }
764
765 false
766 }
767
768 pub fn handle_session_update(
769 &mut self,
770 update: acp::SessionUpdate,
771 cx: &mut Context<Self>,
772 ) -> Result<()> {
773 match update {
774 acp::SessionUpdate::UserMessageChunk { content } => {
775 self.push_user_content_block(content, cx);
776 }
777 acp::SessionUpdate::AgentMessageChunk { content } => {
778 self.push_assistant_content_block(content, false, cx);
779 }
780 acp::SessionUpdate::AgentThoughtChunk { content } => {
781 self.push_assistant_content_block(content, true, cx);
782 }
783 acp::SessionUpdate::ToolCall(tool_call) => {
784 self.upsert_tool_call(tool_call, cx);
785 }
786 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
787 self.update_tool_call(tool_call_update, cx)?;
788 }
789 acp::SessionUpdate::Plan(plan) => {
790 self.update_plan(plan, cx);
791 }
792 }
793 Ok(())
794 }
795
796 pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
797 let language_registry = self.project.read(cx).languages().clone();
798 let entries_len = self.entries.len();
799
800 if let Some(last_entry) = self.entries.last_mut()
801 && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
802 {
803 content.append(chunk, &language_registry, cx);
804 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
805 } else {
806 let content = ContentBlock::new(chunk, &language_registry, cx);
807 self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
808 }
809 }
810
811 pub fn push_assistant_content_block(
812 &mut self,
813 chunk: acp::ContentBlock,
814 is_thought: bool,
815 cx: &mut Context<Self>,
816 ) {
817 let language_registry = self.project.read(cx).languages().clone();
818 let entries_len = self.entries.len();
819 if let Some(last_entry) = self.entries.last_mut()
820 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
821 {
822 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
823 match (chunks.last_mut(), is_thought) {
824 (Some(AssistantMessageChunk::Message { block }), false)
825 | (Some(AssistantMessageChunk::Thought { block }), true) => {
826 block.append(chunk, &language_registry, cx)
827 }
828 _ => {
829 let block = ContentBlock::new(chunk, &language_registry, cx);
830 if is_thought {
831 chunks.push(AssistantMessageChunk::Thought { block })
832 } else {
833 chunks.push(AssistantMessageChunk::Message { block })
834 }
835 }
836 }
837 } else {
838 let block = ContentBlock::new(chunk, &language_registry, cx);
839 let chunk = if is_thought {
840 AssistantMessageChunk::Thought { block }
841 } else {
842 AssistantMessageChunk::Message { block }
843 };
844
845 self.push_entry(
846 AgentThreadEntry::AssistantMessage(AssistantMessage {
847 chunks: vec![chunk],
848 }),
849 cx,
850 );
851 }
852 }
853
/// Appends `entry` to the transcript and emits `AcpThreadEvent::NewEntry`.
fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
    self.entries.push(entry);
    cx.emit(AcpThreadEvent::NewEntry);
}
858
859 pub fn update_tool_call(
860 &mut self,
861 update: impl Into<ToolCallUpdate>,
862 cx: &mut Context<Self>,
863 ) -> Result<()> {
864 let update = update.into();
865 let languages = self.project.read(cx).languages().clone();
866
867 let (ix, current_call) = self
868 .tool_call_mut(update.id())
869 .context("Tool call not found")?;
870 match update {
871 ToolCallUpdate::UpdateFields(update) => {
872 let location_updated = update.fields.locations.is_some();
873 current_call.update_fields(update.fields, languages, cx);
874 if location_updated {
875 self.resolve_locations(update.id.clone(), cx);
876 }
877 }
878 ToolCallUpdate::UpdateDiff(update) => {
879 current_call.content.clear();
880 current_call
881 .content
882 .push(ToolCallContent::Diff(update.diff));
883 }
884 ToolCallUpdate::UpdateTerminal(update) => {
885 current_call.content.clear();
886 current_call
887 .content
888 .push(ToolCallContent::Terminal(update.terminal));
889 }
890 }
891
892 cx.emit(AcpThreadEvent::EntryUpdated(ix));
893
894 Ok(())
895 }
896
897 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
898 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
899 let status = ToolCallStatus::Allowed {
900 status: tool_call.status,
901 };
902 self.upsert_tool_call_inner(tool_call, status, cx)
903 }
904
905 pub fn upsert_tool_call_inner(
906 &mut self,
907 tool_call: acp::ToolCall,
908 status: ToolCallStatus,
909 cx: &mut Context<Self>,
910 ) {
911 let language_registry = self.project.read(cx).languages().clone();
912 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
913 let id = call.id.clone();
914
915 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
916 *current_call = call;
917
918 cx.emit(AcpThreadEvent::EntryUpdated(ix));
919 } else {
920 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
921 };
922
923 self.resolve_locations(id, cx);
924 }
925
926 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
927 // The tool call we are looking for is typically the last one, or very close to the end.
928 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
929 self.entries
930 .iter_mut()
931 .enumerate()
932 .rev()
933 .find_map(|(index, tool_call)| {
934 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
935 && &tool_call.id == id
936 {
937 Some((index, tool_call))
938 } else {
939 None
940 }
941 })
942 }
943
944 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
945 let project = self.project.clone();
946 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
947 return;
948 };
949 let task = tool_call.resolve_locations(project, cx);
950 cx.spawn(async move |this, cx| {
951 let resolved_locations = task.await;
952 this.update(cx, |this, cx| {
953 let project = this.project.clone();
954 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
955 return;
956 };
957 if let Some(Some(location)) = resolved_locations.last() {
958 project.update(cx, |project, cx| {
959 if let Some(agent_location) = project.agent_location() {
960 let should_ignore = agent_location.buffer == location.buffer
961 && location
962 .buffer
963 .update(cx, |buffer, _| {
964 let snapshot = buffer.snapshot();
965 let old_position =
966 agent_location.position.to_point(&snapshot);
967 let new_position = location.position.to_point(&snapshot);
968 // ignore this so that when we get updates from the edit tool
969 // the position doesn't reset to the startof line
970 old_position.row == new_position.row
971 && old_position.column > new_position.column
972 })
973 .ok()
974 .unwrap_or_default();
975 if !should_ignore {
976 project.set_agent_location(Some(location.clone()), cx);
977 }
978 }
979 });
980 }
981 if tool_call.resolved_locations != resolved_locations {
982 tool_call.resolved_locations = resolved_locations;
983 cx.emit(AcpThreadEvent::EntryUpdated(ix));
984 }
985 })
986 })
987 .detach();
988 }
989
990 pub fn request_tool_call_authorization(
991 &mut self,
992 tool_call: acp::ToolCall,
993 options: Vec<acp::PermissionOption>,
994 cx: &mut Context<Self>,
995 ) -> oneshot::Receiver<acp::PermissionOptionId> {
996 let (tx, rx) = oneshot::channel();
997
998 let status = ToolCallStatus::WaitingForConfirmation {
999 options,
1000 respond_tx: tx,
1001 };
1002
1003 self.upsert_tool_call_inner(tool_call, status, cx);
1004 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1005 rx
1006 }
1007
1008 pub fn authorize_tool_call(
1009 &mut self,
1010 id: acp::ToolCallId,
1011 option_id: acp::PermissionOptionId,
1012 option_kind: acp::PermissionOptionKind,
1013 cx: &mut Context<Self>,
1014 ) {
1015 let Some((ix, call)) = self.tool_call_mut(&id) else {
1016 return;
1017 };
1018
1019 let new_status = match option_kind {
1020 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1021 ToolCallStatus::Rejected
1022 }
1023 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1024 ToolCallStatus::Allowed {
1025 status: acp::ToolCallStatus::InProgress,
1026 }
1027 }
1028 };
1029
1030 let curr_status = mem::replace(&mut call.status, new_status);
1031
1032 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1033 respond_tx.send(option_id).log_err();
1034 } else if cfg!(debug_assertions) {
1035 panic!("tried to authorize an already authorized tool call");
1036 }
1037
1038 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1039 }
1040
1041 /// Returns true if the last turn is awaiting tool authorization
1042 pub fn waiting_for_tool_confirmation(&self) -> bool {
1043 for entry in self.entries.iter().rev() {
1044 match &entry {
1045 AgentThreadEntry::ToolCall(call) => match call.status {
1046 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1047 ToolCallStatus::Allowed { .. }
1048 | ToolCallStatus::Rejected
1049 | ToolCallStatus::Canceled => continue,
1050 },
1051 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1052 // Reached the beginning of the turn
1053 return false;
1054 }
1055 }
1056 }
1057 false
1058 }
1059
/// The plan most recently reported by the agent.
pub fn plan(&self) -> &Plan {
    &self.plan
}
1063
1064 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1065 let new_entries_len = request.entries.len();
1066 let mut new_entries = request.entries.into_iter();
1067
1068 // Reuse existing markdown to prevent flickering
1069 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1070 let PlanEntry {
1071 content,
1072 priority,
1073 status,
1074 } = old;
1075 content.update(cx, |old, cx| {
1076 old.replace(new.content, cx);
1077 });
1078 *priority = new.priority;
1079 *status = new.status;
1080 }
1081 for new in new_entries {
1082 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1083 }
1084 self.plan.entries.truncate(new_entries_len);
1085
1086 cx.notify();
1087 }
1088
1089 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1090 self.plan
1091 .entries
1092 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1093 cx.notify();
1094 }
1095
/// Test helper: sends `message` as a single plain-text content block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: String::from(message),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1110
/// Sends a user prompt to the agent.
///
/// The message is first recorded as a user entry and completed plan entries
/// are cleared. Any prompt still in flight is cancelled, and the new prompt
/// starts once that cancellation finishes. The returned future resolves when
/// the prompt finishes: with the agent's error if the prompt failed, otherwise
/// `Ok(())` after emitting `Stopped`.
pub fn send(
    &mut self,
    message: Vec<acp::ContentBlock>,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    // Record the user's message in the transcript before prompting the agent.
    let block = ContentBlock::new_combined(
        message.clone(),
        self.project.read(cx).languages().clone(),
        cx,
    );
    self.push_entry(
        AgentThreadEntry::UserMessage(UserMessage { content: block }),
        cx,
    );
    self.clear_completed_plan_entries(cx);

    // `tx` hands the prompt result from `send_task` to the future returned below.
    let (tx, rx) = oneshot::channel();
    // Cancel any prompt that is still running; the new one waits for it to finish.
    let cancel_task = self.cancel(cx);

    self.send_task = Some(cx.spawn(async move |this, cx| {
        async {
            cancel_task.await;

            let result = this
                .update(cx, |this, cx| {
                    this.connection.prompt(
                        acp::PromptRequest {
                            prompt: message,
                            session_id: this.session_id.clone(),
                        },
                        cx,
                    )
                })?
                .await;

            tx.send(result).log_err();

            anyhow::Ok(())
        }
        .await
        .log_err();
    }));

    cx.spawn(async move |this, cx| match rx.await {
        // The agent reported an error: clear the in-flight task and surface it.
        Ok(Err(e)) => {
            this.update(cx, |this, cx| {
                this.send_task.take();
                cx.emit(AcpThreadEvent::Error)
            })
            .log_err();
            Err(e)?
        }
        result => {
            let cancelled = matches!(
                result,
                Ok(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::Cancelled
                }))
            );

            // We only take the task if the current prompt wasn't cancelled.
            //
            // This prompt may have been cancelled because another one was sent
            // while it was still generating. In these cases, dropping `send_task`
            // would cause the next generation to be cancelled.
            if !cancelled {
                this.update(cx, |this, _cx| this.send_task.take()).ok();
            }

            this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                .log_err();
            Ok(())
        }
    })
    .boxed()
}
1187
1188 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1189 let Some(send_task) = self.send_task.take() else {
1190 return Task::ready(());
1191 };
1192
1193 for entry in self.entries.iter_mut() {
1194 if let AgentThreadEntry::ToolCall(call) = entry {
1195 let cancel = matches!(
1196 call.status,
1197 ToolCallStatus::WaitingForConfirmation { .. }
1198 | ToolCallStatus::Allowed {
1199 status: acp::ToolCallStatus::InProgress
1200 }
1201 );
1202
1203 if cancel {
1204 call.status = ToolCallStatus::Canceled;
1205 }
1206 }
1207 }
1208
1209 self.connection.cancel(&self.session_id, cx);
1210
1211 // Wait for the send task to complete
1212 cx.foreground_executor().spawn(send_task)
1213 }
1214
1215 pub fn read_text_file(
1216 &self,
1217 path: PathBuf,
1218 line: Option<u32>,
1219 limit: Option<u32>,
1220 reuse_shared_snapshot: bool,
1221 cx: &mut Context<Self>,
1222 ) -> Task<Result<String>> {
1223 let project = self.project.clone();
1224 let action_log = self.action_log.clone();
1225 cx.spawn(async move |this, cx| {
1226 let load = project.update(cx, |project, cx| {
1227 let path = project
1228 .project_path_for_absolute_path(&path, cx)
1229 .context("invalid path")?;
1230 anyhow::Ok(project.open_buffer(path, cx))
1231 });
1232 let buffer = load??.await?;
1233
1234 let snapshot = if reuse_shared_snapshot {
1235 this.read_with(cx, |this, _| {
1236 this.shared_buffers.get(&buffer.clone()).cloned()
1237 })
1238 .log_err()
1239 .flatten()
1240 } else {
1241 None
1242 };
1243
1244 let snapshot = if let Some(snapshot) = snapshot {
1245 snapshot
1246 } else {
1247 action_log.update(cx, |action_log, cx| {
1248 action_log.buffer_read(buffer.clone(), cx);
1249 })?;
1250 project.update(cx, |project, cx| {
1251 let position = buffer
1252 .read(cx)
1253 .snapshot()
1254 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1255 project.set_agent_location(
1256 Some(AgentLocation {
1257 buffer: buffer.downgrade(),
1258 position,
1259 }),
1260 cx,
1261 );
1262 })?;
1263
1264 buffer.update(cx, |buffer, _| buffer.snapshot())?
1265 };
1266
1267 this.update(cx, |this, _| {
1268 let text = snapshot.text();
1269 this.shared_buffers.insert(buffer.clone(), snapshot);
1270 if line.is_none() && limit.is_none() {
1271 return Ok(text);
1272 }
1273 let limit = limit.unwrap_or(u32::MAX) as usize;
1274 let Some(line) = line else {
1275 return Ok(text.lines().take(limit).collect::<String>());
1276 };
1277
1278 let count = text.lines().count();
1279 if count < line as usize {
1280 anyhow::bail!("There are only {} lines", count);
1281 }
1282 Ok(text
1283 .lines()
1284 .skip(line as usize + 1)
1285 .take(limit)
1286 .collect::<String>())
1287 })?
1288 })
1289 }
1290
1291 pub fn write_text_file(
1292 &self,
1293 path: PathBuf,
1294 content: String,
1295 cx: &mut Context<Self>,
1296 ) -> Task<Result<()>> {
1297 let project = self.project.clone();
1298 let action_log = self.action_log.clone();
1299 cx.spawn(async move |this, cx| {
1300 let load = project.update(cx, |project, cx| {
1301 let path = project
1302 .project_path_for_absolute_path(&path, cx)
1303 .context("invalid path")?;
1304 anyhow::Ok(project.open_buffer(path, cx))
1305 });
1306 let buffer = load??.await?;
1307 let snapshot = this.update(cx, |this, cx| {
1308 this.shared_buffers
1309 .get(&buffer)
1310 .cloned()
1311 .unwrap_or_else(|| buffer.read(cx).snapshot())
1312 })?;
1313 let edits = cx
1314 .background_executor()
1315 .spawn(async move {
1316 let old_text = snapshot.text();
1317 text_diff(old_text.as_str(), &content)
1318 .into_iter()
1319 .map(|(range, replacement)| {
1320 (
1321 snapshot.anchor_after(range.start)
1322 ..snapshot.anchor_before(range.end),
1323 replacement,
1324 )
1325 })
1326 .collect::<Vec<_>>()
1327 })
1328 .await;
1329 cx.update(|cx| {
1330 project.update(cx, |project, cx| {
1331 project.set_agent_location(
1332 Some(AgentLocation {
1333 buffer: buffer.downgrade(),
1334 position: edits
1335 .last()
1336 .map(|(range, _)| range.end)
1337 .unwrap_or(Anchor::MIN),
1338 }),
1339 cx,
1340 );
1341 });
1342
1343 action_log.update(cx, |action_log, cx| {
1344 action_log.buffer_read(buffer.clone(), cx);
1345 });
1346 buffer.update(cx, |buffer, cx| {
1347 buffer.edit(edits, None, cx);
1348 });
1349 action_log.update(cx, |action_log, cx| {
1350 action_log.buffer_edited(buffer.clone(), cx);
1351 });
1352 })?;
1353 project
1354 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1355 .await
1356 })
1357 }
1358
1359 pub fn to_markdown(&self, cx: &App) -> String {
1360 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1361 }
1362
1363 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1364 cx.emit(AcpThreadEvent::ServerExited(status));
1365 }
1366}
1367
1368fn markdown_for_raw_output(
1369 raw_output: &serde_json::Value,
1370 language_registry: &Arc<LanguageRegistry>,
1371 cx: &mut App,
1372) -> Option<Entity<Markdown>> {
1373 match raw_output {
1374 serde_json::Value::Null => None,
1375 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1376 Markdown::new(
1377 value.to_string().into(),
1378 Some(language_registry.clone()),
1379 None,
1380 cx,
1381 )
1382 })),
1383 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1384 Markdown::new(
1385 value.to_string().into(),
1386 Some(language_registry.clone()),
1387 None,
1388 cx,
1389 )
1390 })),
1391 serde_json::Value::String(value) => Some(cx.new(|cx| {
1392 Markdown::new(
1393 value.clone().into(),
1394 Some(language_registry.clone()),
1395 None,
1396 cx,
1397 )
1398 })),
1399 value => Some(cx.new(|cx| {
1400 Markdown::new(
1401 format!("```json\n{}\n```", value).into(),
1402 Some(language_registry.clone()),
1403 None,
1404 cx,
1405 )
1406 })),
1407 }
1408}
1409
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};

    use util::path;

    /// Shared setup: logging, a test settings store, project settings and the language
    /// subsystem.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user content blocks are merged into one user message. A block pushed
    /// after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive agent thought chunks end up in a single `<thinking>` section of the
    /// markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// The user inserts "zero" between the agent's read and its write. That edit
    /// survives in both the buffer and the saved file.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call marked `Canceled` by `cancel` can still be moved to `Completed` by a
    /// later update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call that arrives already `Completed` does not leave the thread with
    /// pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits for the first thread event after which the thread contains a tool call,
    /// and returns that entry's index. Panics if nothing arrives within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // Only the first index is consumed; if several events fire before the
                        // receiver runs, the bounded channel may be full, which is fine.
                        tx.try_send(ix).ok();
                        return;
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory [`AgentConnection`] used by the tests.
    ///
    /// It keeps a weak handle to every thread it creates, keyed by session id.
    /// `on_user_message` can script the agent's reply to a prompt.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs the handler that `prompt` runs for every prompt request.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread with a random 7-character alphanumeric session id and
        /// registers it in `sessions`.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        /// Succeeds only for an id listed in `auth_methods`.
        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Runs the `on_user_message` handler for the session's thread if one is
        /// installed; otherwise ends the turn immediately. Panics on an unknown session.
        fn prompt(
            &self,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        /// Cancels the session's thread on a detached task. Panics on an unknown session.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}