1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
28use ui::App;
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub id: Option<UserMessageId>,
34 pub content: ContentBlock,
35 pub chunks: Vec<acp::ContentBlock>,
36 pub checkpoint: Option<GitStoreCheckpoint>,
37}
38
39impl UserMessage {
40 fn to_markdown(&self, cx: &App) -> String {
41 let mut markdown = String::new();
42 if let Some(_) = self.checkpoint {
43 writeln!(markdown, "## User (checkpoint)").unwrap();
44 } else {
45 writeln!(markdown, "## User").unwrap();
46 }
47 writeln!(markdown).unwrap();
48 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
49 writeln!(markdown).unwrap();
50 markdown
51 }
52}
53
54#[derive(Debug, PartialEq)]
55pub struct AssistantMessage {
56 pub chunks: Vec<AssistantMessageChunk>,
57}
58
59impl AssistantMessage {
60 pub fn to_markdown(&self, cx: &App) -> String {
61 format!(
62 "## Assistant\n\n{}\n\n",
63 self.chunks
64 .iter()
65 .map(|chunk| chunk.to_markdown(cx))
66 .join("\n\n")
67 )
68 }
69}
70
71#[derive(Debug, PartialEq)]
72pub enum AssistantMessageChunk {
73 Message { block: ContentBlock },
74 Thought { block: ContentBlock },
75}
76
77impl AssistantMessageChunk {
78 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
79 Self::Message {
80 block: ContentBlock::new(chunk.into(), language_registry, cx),
81 }
82 }
83
84 fn to_markdown(&self, cx: &App) -> String {
85 match self {
86 Self::Message { block } => block.to_markdown(cx).to_string(),
87 Self::Thought { block } => {
88 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
89 }
90 }
91 }
92}
93
94#[derive(Debug)]
95pub enum AgentThreadEntry {
96 UserMessage(UserMessage),
97 AssistantMessage(AssistantMessage),
98 ToolCall(ToolCall),
99}
100
101impl AgentThreadEntry {
102 fn to_markdown(&self, cx: &App) -> String {
103 match self {
104 Self::UserMessage(message) => message.to_markdown(cx),
105 Self::AssistantMessage(message) => message.to_markdown(cx),
106 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
107 }
108 }
109
110 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
111 if let AgentThreadEntry::ToolCall(call) = self {
112 itertools::Either::Left(call.diffs())
113 } else {
114 itertools::Either::Right(std::iter::empty())
115 }
116 }
117
118 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
119 if let AgentThreadEntry::ToolCall(call) = self {
120 itertools::Either::Left(call.terminals())
121 } else {
122 itertools::Either::Right(std::iter::empty())
123 }
124 }
125
126 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
127 if let AgentThreadEntry::ToolCall(ToolCall {
128 locations,
129 resolved_locations,
130 ..
131 }) = self
132 {
133 Some((
134 locations.get(ix)?.clone(),
135 resolved_locations.get(ix)?.clone()?,
136 ))
137 } else {
138 None
139 }
140 }
141}
142
143#[derive(Debug)]
144pub struct ToolCall {
145 pub id: acp::ToolCallId,
146 pub label: Entity<Markdown>,
147 pub kind: acp::ToolKind,
148 pub content: Vec<ToolCallContent>,
149 pub status: ToolCallStatus,
150 pub locations: Vec<acp::ToolCallLocation>,
151 pub resolved_locations: Vec<Option<AgentLocation>>,
152 pub raw_input: Option<serde_json::Value>,
153 pub raw_output: Option<serde_json::Value>,
154}
155
156impl ToolCall {
157 fn from_acp(
158 tool_call: acp::ToolCall,
159 status: ToolCallStatus,
160 language_registry: Arc<LanguageRegistry>,
161 cx: &mut App,
162 ) -> Self {
163 Self {
164 id: tool_call.id,
165 label: cx.new(|cx| {
166 Markdown::new(
167 tool_call.title.into(),
168 Some(language_registry.clone()),
169 None,
170 cx,
171 )
172 }),
173 kind: tool_call.kind,
174 content: tool_call
175 .content
176 .into_iter()
177 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
178 .collect(),
179 locations: tool_call.locations,
180 resolved_locations: Vec::default(),
181 status,
182 raw_input: tool_call.raw_input,
183 raw_output: tool_call.raw_output,
184 }
185 }
186
187 fn update_fields(
188 &mut self,
189 fields: acp::ToolCallUpdateFields,
190 language_registry: Arc<LanguageRegistry>,
191 cx: &mut App,
192 ) {
193 let acp::ToolCallUpdateFields {
194 kind,
195 status,
196 title,
197 content,
198 locations,
199 raw_input,
200 raw_output,
201 } = fields;
202
203 if let Some(kind) = kind {
204 self.kind = kind;
205 }
206
207 if let Some(status) = status {
208 self.status = ToolCallStatus::Allowed { status };
209 }
210
211 if let Some(title) = title {
212 self.label.update(cx, |label, cx| {
213 label.replace(title, cx);
214 });
215 }
216
217 if let Some(content) = content {
218 self.content = content
219 .into_iter()
220 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
221 .collect();
222 }
223
224 if let Some(locations) = locations {
225 self.locations = locations;
226 }
227
228 if let Some(raw_input) = raw_input {
229 self.raw_input = Some(raw_input);
230 }
231
232 if let Some(raw_output) = raw_output {
233 if self.content.is_empty() {
234 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
235 {
236 self.content
237 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
238 markdown,
239 }));
240 }
241 }
242 self.raw_output = Some(raw_output);
243 }
244 }
245
246 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
247 self.content.iter().filter_map(|content| match content {
248 ToolCallContent::Diff(diff) => Some(diff),
249 ToolCallContent::ContentBlock(_) => None,
250 ToolCallContent::Terminal(_) => None,
251 })
252 }
253
254 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
255 self.content.iter().filter_map(|content| match content {
256 ToolCallContent::Terminal(terminal) => Some(terminal),
257 ToolCallContent::ContentBlock(_) => None,
258 ToolCallContent::Diff(_) => None,
259 })
260 }
261
262 fn to_markdown(&self, cx: &App) -> String {
263 let mut markdown = format!(
264 "**Tool Call: {}**\nStatus: {}\n\n",
265 self.label.read(cx).source(),
266 self.status
267 );
268 for content in &self.content {
269 markdown.push_str(content.to_markdown(cx).as_str());
270 markdown.push_str("\n\n");
271 }
272 markdown
273 }
274
275 async fn resolve_location(
276 location: acp::ToolCallLocation,
277 project: WeakEntity<Project>,
278 cx: &mut AsyncApp,
279 ) -> Option<AgentLocation> {
280 let buffer = project
281 .update(cx, |project, cx| {
282 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
283 Some(project.open_buffer(path, cx))
284 } else {
285 None
286 }
287 })
288 .ok()??;
289 let buffer = buffer.await.log_err()?;
290 let position = buffer
291 .update(cx, |buffer, _| {
292 if let Some(row) = location.line {
293 let snapshot = buffer.snapshot();
294 let column = snapshot.indent_size_for_line(row).len;
295 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
296 snapshot.anchor_before(point)
297 } else {
298 Anchor::MIN
299 }
300 })
301 .ok()?;
302
303 Some(AgentLocation {
304 buffer: buffer.downgrade(),
305 position,
306 })
307 }
308
309 fn resolve_locations(
310 &self,
311 project: Entity<Project>,
312 cx: &mut App,
313 ) -> Task<Vec<Option<AgentLocation>>> {
314 let locations = self.locations.clone();
315 project.update(cx, |_, cx| {
316 cx.spawn(async move |project, cx| {
317 let mut new_locations = Vec::new();
318 for location in locations {
319 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
320 }
321 new_locations
322 })
323 })
324 }
325}
326
327#[derive(Debug)]
328pub enum ToolCallStatus {
329 WaitingForConfirmation {
330 options: Vec<acp::PermissionOption>,
331 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
332 },
333 Allowed {
334 status: acp::ToolCallStatus,
335 },
336 Rejected,
337 Canceled,
338}
339
340impl Display for ToolCallStatus {
341 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
342 write!(
343 f,
344 "{}",
345 match self {
346 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
347 ToolCallStatus::Allowed { status } => match status {
348 acp::ToolCallStatus::Pending => "Pending",
349 acp::ToolCallStatus::InProgress => "In Progress",
350 acp::ToolCallStatus::Completed => "Completed",
351 acp::ToolCallStatus::Failed => "Failed",
352 },
353 ToolCallStatus::Rejected => "Rejected",
354 ToolCallStatus::Canceled => "Canceled",
355 }
356 )
357 }
358}
359
360#[derive(Debug, PartialEq, Clone)]
361pub enum ContentBlock {
362 Empty,
363 Markdown { markdown: Entity<Markdown> },
364 ResourceLink { resource_link: acp::ResourceLink },
365}
366
367impl ContentBlock {
368 pub fn new(
369 block: acp::ContentBlock,
370 language_registry: &Arc<LanguageRegistry>,
371 cx: &mut App,
372 ) -> Self {
373 let mut this = Self::Empty;
374 this.append(block, language_registry, cx);
375 this
376 }
377
378 pub fn new_combined(
379 blocks: impl IntoIterator<Item = acp::ContentBlock>,
380 language_registry: Arc<LanguageRegistry>,
381 cx: &mut App,
382 ) -> Self {
383 let mut this = Self::Empty;
384 for block in blocks {
385 this.append(block, &language_registry, cx);
386 }
387 this
388 }
389
390 pub fn append(
391 &mut self,
392 block: acp::ContentBlock,
393 language_registry: &Arc<LanguageRegistry>,
394 cx: &mut App,
395 ) {
396 if matches!(self, ContentBlock::Empty) {
397 if let acp::ContentBlock::ResourceLink(resource_link) = block {
398 *self = ContentBlock::ResourceLink { resource_link };
399 return;
400 }
401 }
402
403 let new_content = self.block_string_contents(block);
404
405 match self {
406 ContentBlock::Empty => {
407 *self = Self::create_markdown_block(new_content, language_registry, cx);
408 }
409 ContentBlock::Markdown { markdown } => {
410 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
411 }
412 ContentBlock::ResourceLink { resource_link } => {
413 let existing_content = Self::resource_link_md(&resource_link.uri);
414 let combined = format!("{}\n{}", existing_content, new_content);
415
416 *self = Self::create_markdown_block(combined, language_registry, cx);
417 }
418 }
419 }
420
421 fn create_markdown_block(
422 content: String,
423 language_registry: &Arc<LanguageRegistry>,
424 cx: &mut App,
425 ) -> ContentBlock {
426 ContentBlock::Markdown {
427 markdown: cx
428 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
429 }
430 }
431
432 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
433 match block {
434 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
435 acp::ContentBlock::ResourceLink(resource_link) => {
436 Self::resource_link_md(&resource_link.uri)
437 }
438 acp::ContentBlock::Resource(acp::EmbeddedResource {
439 resource:
440 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
441 uri,
442 ..
443 }),
444 ..
445 }) => Self::resource_link_md(&uri),
446 acp::ContentBlock::Image(_)
447 | acp::ContentBlock::Audio(_)
448 | acp::ContentBlock::Resource(_) => String::new(),
449 }
450 }
451
452 fn resource_link_md(uri: &str) -> String {
453 if let Some(uri) = MentionUri::parse(&uri).log_err() {
454 uri.as_link().to_string()
455 } else {
456 uri.to_string()
457 }
458 }
459
460 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
461 match self {
462 ContentBlock::Empty => "",
463 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
464 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
465 }
466 }
467
468 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
469 match self {
470 ContentBlock::Empty => None,
471 ContentBlock::Markdown { markdown } => Some(markdown),
472 ContentBlock::ResourceLink { .. } => None,
473 }
474 }
475
476 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
477 match self {
478 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
479 _ => None,
480 }
481 }
482}
483
484#[derive(Debug)]
485pub enum ToolCallContent {
486 ContentBlock(ContentBlock),
487 Diff(Entity<Diff>),
488 Terminal(Entity<Terminal>),
489}
490
491impl ToolCallContent {
492 pub fn from_acp(
493 content: acp::ToolCallContent,
494 language_registry: Arc<LanguageRegistry>,
495 cx: &mut App,
496 ) -> Self {
497 match content {
498 acp::ToolCallContent::Content { content } => {
499 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
500 }
501 acp::ToolCallContent::Diff { diff } => {
502 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
503 }
504 }
505 }
506
507 pub fn to_markdown(&self, cx: &App) -> String {
508 match self {
509 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
510 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
511 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
512 }
513 }
514}
515
516#[derive(Debug, PartialEq)]
517pub enum ToolCallUpdate {
518 UpdateFields(acp::ToolCallUpdate),
519 UpdateDiff(ToolCallUpdateDiff),
520 UpdateTerminal(ToolCallUpdateTerminal),
521}
522
523impl ToolCallUpdate {
524 fn id(&self) -> &acp::ToolCallId {
525 match self {
526 Self::UpdateFields(update) => &update.id,
527 Self::UpdateDiff(diff) => &diff.id,
528 Self::UpdateTerminal(terminal) => &terminal.id,
529 }
530 }
531}
532
533impl From<acp::ToolCallUpdate> for ToolCallUpdate {
534 fn from(update: acp::ToolCallUpdate) -> Self {
535 Self::UpdateFields(update)
536 }
537}
538
539impl From<ToolCallUpdateDiff> for ToolCallUpdate {
540 fn from(diff: ToolCallUpdateDiff) -> Self {
541 Self::UpdateDiff(diff)
542 }
543}
544
545#[derive(Debug, PartialEq)]
546pub struct ToolCallUpdateDiff {
547 pub id: acp::ToolCallId,
548 pub diff: Entity<Diff>,
549}
550
551impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
552 fn from(terminal: ToolCallUpdateTerminal) -> Self {
553 Self::UpdateTerminal(terminal)
554 }
555}
556
557#[derive(Debug, PartialEq)]
558pub struct ToolCallUpdateTerminal {
559 pub id: acp::ToolCallId,
560 pub terminal: Entity<Terminal>,
561}
562
563#[derive(Debug, Default)]
564pub struct Plan {
565 pub entries: Vec<PlanEntry>,
566}
567
568#[derive(Debug)]
569pub struct PlanStats<'a> {
570 pub in_progress_entry: Option<&'a PlanEntry>,
571 pub pending: u32,
572 pub completed: u32,
573}
574
575impl Plan {
576 pub fn is_empty(&self) -> bool {
577 self.entries.is_empty()
578 }
579
580 pub fn stats(&self) -> PlanStats<'_> {
581 let mut stats = PlanStats {
582 in_progress_entry: None,
583 pending: 0,
584 completed: 0,
585 };
586
587 for entry in &self.entries {
588 match &entry.status {
589 acp::PlanEntryStatus::Pending => {
590 stats.pending += 1;
591 }
592 acp::PlanEntryStatus::InProgress => {
593 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
594 }
595 acp::PlanEntryStatus::Completed => {
596 stats.completed += 1;
597 }
598 }
599 }
600
601 stats
602 }
603}
604
605#[derive(Debug)]
606pub struct PlanEntry {
607 pub content: Entity<Markdown>,
608 pub priority: acp::PlanEntryPriority,
609 pub status: acp::PlanEntryStatus,
610}
611
612impl PlanEntry {
613 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
614 Self {
615 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
616 priority: entry.priority,
617 status: entry.status,
618 }
619 }
620}
621
622pub struct AcpThread {
623 title: SharedString,
624 entries: Vec<AgentThreadEntry>,
625 plan: Plan,
626 project: Entity<Project>,
627 action_log: Entity<ActionLog>,
628 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
629 send_task: Option<Task<()>>,
630 connection: Rc<dyn AgentConnection>,
631 session_id: acp::SessionId,
632}
633
/// Events emitted by an [`AcpThread`].
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// The entries in the given index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for authorization.
    ToolAuthorizationRequired,
    /// A prompt finished without reporting an error.
    Stopped,
    /// A prompt finished with an error.
    Error,
    /// The agent server process exited with the given status.
    ServerExited(ExitStatus),
}
643
644impl EventEmitter<AcpThreadEvent> for AcpThread {}
645
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is being sent.
    Idle,
    /// A prompt is in flight and a tool call awaits confirmation.
    WaitingForToolConfirmation,
    /// A prompt is in flight.
    Generating,
}
652
653#[derive(Debug, Clone)]
654pub enum LoadError {
655 Unsupported {
656 error_message: SharedString,
657 upgrade_message: SharedString,
658 upgrade_command: String,
659 },
660 Exited(i32),
661 Other(SharedString),
662}
663
664impl Display for LoadError {
665 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
666 match self {
667 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
668 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
669 LoadError::Other(msg) => write!(f, "{}", msg),
670 }
671 }
672}
673
674impl Error for LoadError {}
675
676impl AcpThread {
677 pub fn new(
678 title: impl Into<SharedString>,
679 connection: Rc<dyn AgentConnection>,
680 project: Entity<Project>,
681 session_id: acp::SessionId,
682 cx: &mut Context<Self>,
683 ) -> Self {
684 let action_log = cx.new(|_| ActionLog::new(project.clone()));
685
686 Self {
687 action_log,
688 shared_buffers: Default::default(),
689 entries: Default::default(),
690 plan: Default::default(),
691 title: title.into(),
692 project,
693 send_task: None,
694 connection,
695 session_id,
696 }
697 }
698
699 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
700 &self.connection
701 }
702
703 pub fn action_log(&self) -> &Entity<ActionLog> {
704 &self.action_log
705 }
706
707 pub fn project(&self) -> &Entity<Project> {
708 &self.project
709 }
710
711 pub fn title(&self) -> SharedString {
712 self.title.clone()
713 }
714
715 pub fn entries(&self) -> &[AgentThreadEntry] {
716 &self.entries
717 }
718
719 pub fn session_id(&self) -> &acp::SessionId {
720 &self.session_id
721 }
722
723 pub fn status(&self) -> ThreadStatus {
724 if self.send_task.is_some() {
725 if self.waiting_for_tool_confirmation() {
726 ThreadStatus::WaitingForToolConfirmation
727 } else {
728 ThreadStatus::Generating
729 }
730 } else {
731 ThreadStatus::Idle
732 }
733 }
734
735 pub fn has_pending_edit_tool_calls(&self) -> bool {
736 for entry in self.entries.iter().rev() {
737 match entry {
738 AgentThreadEntry::UserMessage(_) => return false,
739 AgentThreadEntry::ToolCall(
740 call @ ToolCall {
741 status:
742 ToolCallStatus::Allowed {
743 status:
744 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
745 },
746 ..
747 },
748 ) if call.diffs().next().is_some() => {
749 return true;
750 }
751 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
752 }
753 }
754
755 false
756 }
757
758 pub fn used_tools_since_last_user_message(&self) -> bool {
759 for entry in self.entries.iter().rev() {
760 match entry {
761 AgentThreadEntry::UserMessage(..) => return false,
762 AgentThreadEntry::AssistantMessage(..) => continue,
763 AgentThreadEntry::ToolCall(..) => return true,
764 }
765 }
766
767 false
768 }
769
770 pub fn handle_session_update(
771 &mut self,
772 update: acp::SessionUpdate,
773 cx: &mut Context<Self>,
774 ) -> Result<()> {
775 match update {
776 acp::SessionUpdate::UserMessageChunk { content } => {
777 self.push_user_content_block(None, content, cx);
778 }
779 acp::SessionUpdate::AgentMessageChunk { content } => {
780 self.push_assistant_content_block(content, false, cx);
781 }
782 acp::SessionUpdate::AgentThoughtChunk { content } => {
783 self.push_assistant_content_block(content, true, cx);
784 }
785 acp::SessionUpdate::ToolCall(tool_call) => {
786 self.upsert_tool_call(tool_call, cx);
787 }
788 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
789 self.update_tool_call(tool_call_update, cx)?;
790 }
791 acp::SessionUpdate::Plan(plan) => {
792 self.update_plan(plan, cx);
793 }
794 }
795 Ok(())
796 }
797
798 pub fn push_user_content_block(
799 &mut self,
800 message_id: Option<UserMessageId>,
801 chunk: acp::ContentBlock,
802 cx: &mut Context<Self>,
803 ) {
804 let language_registry = self.project.read(cx).languages().clone();
805 let entries_len = self.entries.len();
806
807 if let Some(last_entry) = self.entries.last_mut()
808 && let AgentThreadEntry::UserMessage(UserMessage {
809 id,
810 content,
811 chunks,
812 ..
813 }) = last_entry
814 {
815 *id = message_id.or(id.take());
816 content.append(chunk.clone(), &language_registry, cx);
817 chunks.push(chunk);
818 let idx = entries_len - 1;
819 cx.emit(AcpThreadEvent::EntryUpdated(idx));
820 } else {
821 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
822 self.push_entry(
823 AgentThreadEntry::UserMessage(UserMessage {
824 id: message_id,
825 content,
826 chunks: vec![chunk],
827 checkpoint: None,
828 }),
829 cx,
830 );
831 }
832 }
833
834 pub fn push_assistant_content_block(
835 &mut self,
836 chunk: acp::ContentBlock,
837 is_thought: bool,
838 cx: &mut Context<Self>,
839 ) {
840 let language_registry = self.project.read(cx).languages().clone();
841 let entries_len = self.entries.len();
842 if let Some(last_entry) = self.entries.last_mut()
843 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
844 {
845 let idx = entries_len - 1;
846 cx.emit(AcpThreadEvent::EntryUpdated(idx));
847 match (chunks.last_mut(), is_thought) {
848 (Some(AssistantMessageChunk::Message { block }), false)
849 | (Some(AssistantMessageChunk::Thought { block }), true) => {
850 block.append(chunk, &language_registry, cx)
851 }
852 _ => {
853 let block = ContentBlock::new(chunk, &language_registry, cx);
854 if is_thought {
855 chunks.push(AssistantMessageChunk::Thought { block })
856 } else {
857 chunks.push(AssistantMessageChunk::Message { block })
858 }
859 }
860 }
861 } else {
862 let block = ContentBlock::new(chunk, &language_registry, cx);
863 let chunk = if is_thought {
864 AssistantMessageChunk::Thought { block }
865 } else {
866 AssistantMessageChunk::Message { block }
867 };
868
869 self.push_entry(
870 AgentThreadEntry::AssistantMessage(AssistantMessage {
871 chunks: vec![chunk],
872 }),
873 cx,
874 );
875 }
876 }
877
878 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
879 self.entries.push(entry);
880 cx.emit(AcpThreadEvent::NewEntry);
881 }
882
883 pub fn update_tool_call(
884 &mut self,
885 update: impl Into<ToolCallUpdate>,
886 cx: &mut Context<Self>,
887 ) -> Result<()> {
888 let update = update.into();
889 let languages = self.project.read(cx).languages().clone();
890
891 let (ix, current_call) = self
892 .tool_call_mut(update.id())
893 .context("Tool call not found")?;
894 match update {
895 ToolCallUpdate::UpdateFields(update) => {
896 let location_updated = update.fields.locations.is_some();
897 current_call.update_fields(update.fields, languages, cx);
898 if location_updated {
899 self.resolve_locations(update.id.clone(), cx);
900 }
901 }
902 ToolCallUpdate::UpdateDiff(update) => {
903 current_call.content.clear();
904 current_call
905 .content
906 .push(ToolCallContent::Diff(update.diff));
907 }
908 ToolCallUpdate::UpdateTerminal(update) => {
909 current_call.content.clear();
910 current_call
911 .content
912 .push(ToolCallContent::Terminal(update.terminal));
913 }
914 }
915
916 cx.emit(AcpThreadEvent::EntryUpdated(ix));
917
918 Ok(())
919 }
920
921 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
922 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
923 let status = ToolCallStatus::Allowed {
924 status: tool_call.status,
925 };
926 self.upsert_tool_call_inner(tool_call, status, cx)
927 }
928
929 pub fn upsert_tool_call_inner(
930 &mut self,
931 tool_call: acp::ToolCall,
932 status: ToolCallStatus,
933 cx: &mut Context<Self>,
934 ) {
935 let language_registry = self.project.read(cx).languages().clone();
936 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
937 let id = call.id.clone();
938
939 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
940 *current_call = call;
941
942 cx.emit(AcpThreadEvent::EntryUpdated(ix));
943 } else {
944 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
945 };
946
947 self.resolve_locations(id, cx);
948 }
949
950 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
951 // The tool call we are looking for is typically the last one, or very close to the end.
952 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
953 self.entries
954 .iter_mut()
955 .enumerate()
956 .rev()
957 .find_map(|(index, tool_call)| {
958 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
959 && &tool_call.id == id
960 {
961 Some((index, tool_call))
962 } else {
963 None
964 }
965 })
966 }
967
968 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
969 let project = self.project.clone();
970 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
971 return;
972 };
973 let task = tool_call.resolve_locations(project, cx);
974 cx.spawn(async move |this, cx| {
975 let resolved_locations = task.await;
976 this.update(cx, |this, cx| {
977 let project = this.project.clone();
978 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
979 return;
980 };
981 if let Some(Some(location)) = resolved_locations.last() {
982 project.update(cx, |project, cx| {
983 if let Some(agent_location) = project.agent_location() {
984 let should_ignore = agent_location.buffer == location.buffer
985 && location
986 .buffer
987 .update(cx, |buffer, _| {
988 let snapshot = buffer.snapshot();
989 let old_position =
990 agent_location.position.to_point(&snapshot);
991 let new_position = location.position.to_point(&snapshot);
992 // ignore this so that when we get updates from the edit tool
993 // the position doesn't reset to the startof line
994 old_position.row == new_position.row
995 && old_position.column > new_position.column
996 })
997 .ok()
998 .unwrap_or_default();
999 if !should_ignore {
1000 project.set_agent_location(Some(location.clone()), cx);
1001 }
1002 }
1003 });
1004 }
1005 if tool_call.resolved_locations != resolved_locations {
1006 tool_call.resolved_locations = resolved_locations;
1007 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1008 }
1009 })
1010 })
1011 .detach();
1012 }
1013
1014 pub fn request_tool_call_authorization(
1015 &mut self,
1016 tool_call: acp::ToolCall,
1017 options: Vec<acp::PermissionOption>,
1018 cx: &mut Context<Self>,
1019 ) -> oneshot::Receiver<acp::PermissionOptionId> {
1020 let (tx, rx) = oneshot::channel();
1021
1022 let status = ToolCallStatus::WaitingForConfirmation {
1023 options,
1024 respond_tx: tx,
1025 };
1026
1027 self.upsert_tool_call_inner(tool_call, status, cx);
1028 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1029 rx
1030 }
1031
1032 pub fn authorize_tool_call(
1033 &mut self,
1034 id: acp::ToolCallId,
1035 option_id: acp::PermissionOptionId,
1036 option_kind: acp::PermissionOptionKind,
1037 cx: &mut Context<Self>,
1038 ) {
1039 let Some((ix, call)) = self.tool_call_mut(&id) else {
1040 return;
1041 };
1042
1043 let new_status = match option_kind {
1044 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1045 ToolCallStatus::Rejected
1046 }
1047 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1048 ToolCallStatus::Allowed {
1049 status: acp::ToolCallStatus::InProgress,
1050 }
1051 }
1052 };
1053
1054 let curr_status = mem::replace(&mut call.status, new_status);
1055
1056 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1057 respond_tx.send(option_id).log_err();
1058 } else if cfg!(debug_assertions) {
1059 panic!("tried to authorize an already authorized tool call");
1060 }
1061
1062 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1063 }
1064
1065 /// Returns true if the last turn is awaiting tool authorization
1066 pub fn waiting_for_tool_confirmation(&self) -> bool {
1067 for entry in self.entries.iter().rev() {
1068 match &entry {
1069 AgentThreadEntry::ToolCall(call) => match call.status {
1070 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1071 ToolCallStatus::Allowed { .. }
1072 | ToolCallStatus::Rejected
1073 | ToolCallStatus::Canceled => continue,
1074 },
1075 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1076 // Reached the beginning of the turn
1077 return false;
1078 }
1079 }
1080 }
1081 false
1082 }
1083
1084 pub fn plan(&self) -> &Plan {
1085 &self.plan
1086 }
1087
1088 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1089 let new_entries_len = request.entries.len();
1090 let mut new_entries = request.entries.into_iter();
1091
1092 // Reuse existing markdown to prevent flickering
1093 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1094 let PlanEntry {
1095 content,
1096 priority,
1097 status,
1098 } = old;
1099 content.update(cx, |old, cx| {
1100 old.replace(new.content, cx);
1101 });
1102 *priority = new.priority;
1103 *status = new.status;
1104 }
1105 for new in new_entries {
1106 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1107 }
1108 self.plan.entries.truncate(new_entries_len);
1109
1110 cx.notify();
1111 }
1112
1113 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1114 self.plan
1115 .entries
1116 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1117 cx.notify();
1118 }
1119
/// Test helper: sends `message` as a single text content block without
/// annotations.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text = acp::TextContent {
        text: message.to_string(),
        annotations: None,
    };
    self.send(vec![acp::ContentBlock::Text(text)], cx)
}
1134
1135 pub fn send(
1136 &mut self,
1137 message: Vec<acp::ContentBlock>,
1138 cx: &mut Context<Self>,
1139 ) -> BoxFuture<'static, Result<()>> {
1140 let block = ContentBlock::new_combined(
1141 message.clone(),
1142 self.project.read(cx).languages().clone(),
1143 cx,
1144 );
1145 let git_store = self.project.read(cx).git_store().clone();
1146
1147 let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1148 let message_id = if self
1149 .connection
1150 .session_editor(&self.session_id, cx)
1151 .is_some()
1152 {
1153 Some(UserMessageId::new())
1154 } else {
1155 None
1156 };
1157 self.push_entry(
1158 AgentThreadEntry::UserMessage(UserMessage {
1159 id: message_id.clone(),
1160 content: block,
1161 chunks: message.clone(),
1162 checkpoint: None,
1163 }),
1164 cx,
1165 );
1166 self.clear_completed_plan_entries(cx);
1167
1168 let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
1169 let (tx, rx) = oneshot::channel();
1170 let cancel_task = self.cancel(cx);
1171 let request = acp::PromptRequest {
1172 prompt: message,
1173 session_id: self.session_id.clone(),
1174 };
1175
1176 self.send_task = Some(cx.spawn({
1177 let message_id = message_id.clone();
1178 async move |this, cx| {
1179 cancel_task.await;
1180
1181 old_checkpoint_tx.send(old_checkpoint.await).ok();
1182 if let Ok(result) = this.update(cx, |this, cx| {
1183 this.connection.prompt(message_id, request, cx)
1184 }) {
1185 tx.send(result.await).log_err();
1186 }
1187 }
1188 }));
1189
1190 cx.spawn(async move |this, cx| {
1191 let old_checkpoint = old_checkpoint_rx
1192 .await
1193 .map_err(|_| anyhow!("send canceled"))
1194 .flatten()
1195 .context("failed to get old checkpoint")
1196 .log_err();
1197
1198 let response = rx.await;
1199
1200 if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
1201 let new_checkpoint = git_store
1202 .update(cx, |git, cx| git.checkpoint(cx))?
1203 .await
1204 .context("failed to get new checkpoint")
1205 .log_err();
1206 if let Some(new_checkpoint) = new_checkpoint {
1207 let equal = git_store
1208 .update(cx, |git, cx| {
1209 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1210 })?
1211 .await
1212 .unwrap_or(true);
1213 if !equal {
1214 this.update(cx, |this, cx| {
1215 if let Some((ix, message)) = this.user_message_mut(&message_id) {
1216 message.checkpoint = Some(old_checkpoint);
1217 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1218 }
1219 })?;
1220 }
1221 }
1222 }
1223
1224 this.update(cx, |this, cx| {
1225 match response {
1226 Ok(Err(e)) => {
1227 this.send_task.take();
1228 cx.emit(AcpThreadEvent::Error);
1229 Err(e)
1230 }
1231 result => {
1232 let cancelled = matches!(
1233 result,
1234 Ok(Ok(acp::PromptResponse {
1235 stop_reason: acp::StopReason::Cancelled
1236 }))
1237 );
1238
1239 // We only take the task if the current prompt wasn't cancelled.
1240 //
1241 // This prompt may have been cancelled because another one was sent
1242 // while it was still generating. In these cases, dropping `send_task`
1243 // would cause the next generation to be cancelled.
1244 if !cancelled {
1245 this.send_task.take();
1246 }
1247
1248 cx.emit(AcpThreadEvent::Stopped);
1249 Ok(())
1250 }
1251 }
1252 })?
1253 })
1254 .boxed()
1255 }
1256
1257 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1258 let Some(send_task) = self.send_task.take() else {
1259 return Task::ready(());
1260 };
1261
1262 for entry in self.entries.iter_mut() {
1263 if let AgentThreadEntry::ToolCall(call) = entry {
1264 let cancel = matches!(
1265 call.status,
1266 ToolCallStatus::WaitingForConfirmation { .. }
1267 | ToolCallStatus::Allowed {
1268 status: acp::ToolCallStatus::InProgress
1269 }
1270 );
1271
1272 if cancel {
1273 call.status = ToolCallStatus::Canceled;
1274 }
1275 }
1276 }
1277
1278 self.connection.cancel(&self.session_id, cx);
1279
1280 // Wait for the send task to complete
1281 cx.foreground_executor().spawn(send_task)
1282 }
1283
1284 /// Rewinds this thread to before the entry at `index`, removing it and all
1285 /// subsequent entries while reverting any changes made from that point.
1286 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1287 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1288 return Task::ready(Err(anyhow!("not supported")));
1289 };
1290 let Some(message) = self.user_message(&id) else {
1291 return Task::ready(Err(anyhow!("message not found")));
1292 };
1293
1294 let checkpoint = message.checkpoint.clone();
1295
1296 let git_store = self.project.read(cx).git_store().clone();
1297 cx.spawn(async move |this, cx| {
1298 if let Some(checkpoint) = checkpoint {
1299 git_store
1300 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1301 .await?;
1302 }
1303
1304 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1305 .await?;
1306 this.update(cx, |this, cx| {
1307 if let Some((ix, _)) = this.user_message_mut(&id) {
1308 let range = ix..this.entries.len();
1309 this.entries.truncate(ix);
1310 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1311 }
1312 })
1313 })
1314 }
1315
1316 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1317 self.entries.iter().find_map(|entry| {
1318 if let AgentThreadEntry::UserMessage(message) = entry {
1319 if message.id.as_ref() == Some(&id) {
1320 Some(message)
1321 } else {
1322 None
1323 }
1324 } else {
1325 None
1326 }
1327 })
1328 }
1329
1330 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1331 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1332 if let AgentThreadEntry::UserMessage(message) = entry {
1333 if message.id.as_ref() == Some(&id) {
1334 Some((ix, message))
1335 } else {
1336 None
1337 }
1338 } else {
1339 None
1340 }
1341 })
1342 }
1343
1344 pub fn read_text_file(
1345 &self,
1346 path: PathBuf,
1347 line: Option<u32>,
1348 limit: Option<u32>,
1349 reuse_shared_snapshot: bool,
1350 cx: &mut Context<Self>,
1351 ) -> Task<Result<String>> {
1352 let project = self.project.clone();
1353 let action_log = self.action_log.clone();
1354 cx.spawn(async move |this, cx| {
1355 let load = project.update(cx, |project, cx| {
1356 let path = project
1357 .project_path_for_absolute_path(&path, cx)
1358 .context("invalid path")?;
1359 anyhow::Ok(project.open_buffer(path, cx))
1360 });
1361 let buffer = load??.await?;
1362
1363 let snapshot = if reuse_shared_snapshot {
1364 this.read_with(cx, |this, _| {
1365 this.shared_buffers.get(&buffer.clone()).cloned()
1366 })
1367 .log_err()
1368 .flatten()
1369 } else {
1370 None
1371 };
1372
1373 let snapshot = if let Some(snapshot) = snapshot {
1374 snapshot
1375 } else {
1376 action_log.update(cx, |action_log, cx| {
1377 action_log.buffer_read(buffer.clone(), cx);
1378 })?;
1379 project.update(cx, |project, cx| {
1380 let position = buffer
1381 .read(cx)
1382 .snapshot()
1383 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1384 project.set_agent_location(
1385 Some(AgentLocation {
1386 buffer: buffer.downgrade(),
1387 position,
1388 }),
1389 cx,
1390 );
1391 })?;
1392
1393 buffer.update(cx, |buffer, _| buffer.snapshot())?
1394 };
1395
1396 this.update(cx, |this, _| {
1397 let text = snapshot.text();
1398 this.shared_buffers.insert(buffer.clone(), snapshot);
1399 if line.is_none() && limit.is_none() {
1400 return Ok(text);
1401 }
1402 let limit = limit.unwrap_or(u32::MAX) as usize;
1403 let Some(line) = line else {
1404 return Ok(text.lines().take(limit).collect::<String>());
1405 };
1406
1407 let count = text.lines().count();
1408 if count < line as usize {
1409 anyhow::bail!("There are only {} lines", count);
1410 }
1411 Ok(text
1412 .lines()
1413 .skip(line as usize + 1)
1414 .take(limit)
1415 .collect::<String>())
1416 })?
1417 })
1418 }
1419
1420 pub fn write_text_file(
1421 &self,
1422 path: PathBuf,
1423 content: String,
1424 cx: &mut Context<Self>,
1425 ) -> Task<Result<()>> {
1426 let project = self.project.clone();
1427 let action_log = self.action_log.clone();
1428 cx.spawn(async move |this, cx| {
1429 let load = project.update(cx, |project, cx| {
1430 let path = project
1431 .project_path_for_absolute_path(&path, cx)
1432 .context("invalid path")?;
1433 anyhow::Ok(project.open_buffer(path, cx))
1434 });
1435 let buffer = load??.await?;
1436 let snapshot = this.update(cx, |this, cx| {
1437 this.shared_buffers
1438 .get(&buffer)
1439 .cloned()
1440 .unwrap_or_else(|| buffer.read(cx).snapshot())
1441 })?;
1442 let edits = cx
1443 .background_executor()
1444 .spawn(async move {
1445 let old_text = snapshot.text();
1446 text_diff(old_text.as_str(), &content)
1447 .into_iter()
1448 .map(|(range, replacement)| {
1449 (
1450 snapshot.anchor_after(range.start)
1451 ..snapshot.anchor_before(range.end),
1452 replacement,
1453 )
1454 })
1455 .collect::<Vec<_>>()
1456 })
1457 .await;
1458 cx.update(|cx| {
1459 project.update(cx, |project, cx| {
1460 project.set_agent_location(
1461 Some(AgentLocation {
1462 buffer: buffer.downgrade(),
1463 position: edits
1464 .last()
1465 .map(|(range, _)| range.end)
1466 .unwrap_or(Anchor::MIN),
1467 }),
1468 cx,
1469 );
1470 });
1471
1472 action_log.update(cx, |action_log, cx| {
1473 action_log.buffer_read(buffer.clone(), cx);
1474 });
1475 buffer.update(cx, |buffer, cx| {
1476 buffer.edit(edits, None, cx);
1477 });
1478 action_log.update(cx, |action_log, cx| {
1479 action_log.buffer_edited(buffer.clone(), cx);
1480 });
1481 })?;
1482 project
1483 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1484 .await
1485 })
1486 }
1487
1488 pub fn to_markdown(&self, cx: &App) -> String {
1489 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1490 }
1491
1492 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1493 cx.emit(AcpThreadEvent::ServerExited(status));
1494 }
1495}
1496
1497fn markdown_for_raw_output(
1498 raw_output: &serde_json::Value,
1499 language_registry: &Arc<LanguageRegistry>,
1500 cx: &mut App,
1501) -> Option<Entity<Markdown>> {
1502 match raw_output {
1503 serde_json::Value::Null => None,
1504 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1505 Markdown::new(
1506 value.to_string().into(),
1507 Some(language_registry.clone()),
1508 None,
1509 cx,
1510 )
1511 })),
1512 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1513 Markdown::new(
1514 value.to_string().into(),
1515 Some(language_registry.clone()),
1516 None,
1517 cx,
1518 )
1519 })),
1520 serde_json::Value::String(value) => Some(cx.new(|cx| {
1521 Markdown::new(
1522 value.clone().into(),
1523 Some(language_registry.clone()),
1524 None,
1525 cx,
1526 )
1527 })),
1528 value => Some(cx.new(|cx| {
1529 Markdown::new(
1530 format!("```json\n{}\n```", value).into(),
1531 Some(language_registry.clone()),
1532 None,
1533 cx,
1534 )
1535 })),
1536 }
1537}
1538
1539#[cfg(test)]
1540mod tests {
1541 use super::*;
1542 use anyhow::anyhow;
1543 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1544 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1545 use indoc::indoc;
1546 use project::{FakeFs, Fs};
1547 use rand::Rng as _;
1548 use serde_json::json;
1549 use settings::SettingsStore;
1550 use smol::stream::StreamExt as _;
1551 use std::{
1552 cell::RefCell,
1553 path::Path,
1554 rc::Rc,
1555 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1556 time::Duration,
1557 };
1558 use util::path;
1559
1560 fn init_test(cx: &mut TestAppContext) {
1561 env_logger::try_init().ok();
1562 cx.update(|cx| {
1563 let settings_store = SettingsStore::test(cx);
1564 cx.set_global(settings_store);
1565 Project::init_settings(cx);
1566 language::init(cx);
1567 });
1568 }
1569
1570 #[gpui::test]
1571 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1572 init_test(cx);
1573
1574 let fs = FakeFs::new(cx.executor());
1575 let project = Project::test(fs, [], cx).await;
1576 let connection = Rc::new(FakeAgentConnection::new());
1577 let thread = cx
1578 .spawn(async move |mut cx| {
1579 connection
1580 .new_thread(project, Path::new(path!("/test")), &mut cx)
1581 .await
1582 })
1583 .await
1584 .unwrap();
1585
1586 // Test creating a new user message
1587 thread.update(cx, |thread, cx| {
1588 thread.push_user_content_block(
1589 None,
1590 acp::ContentBlock::Text(acp::TextContent {
1591 annotations: None,
1592 text: "Hello, ".to_string(),
1593 }),
1594 cx,
1595 );
1596 });
1597
1598 thread.update(cx, |thread, cx| {
1599 assert_eq!(thread.entries.len(), 1);
1600 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1601 assert_eq!(user_msg.id, None);
1602 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1603 } else {
1604 panic!("Expected UserMessage");
1605 }
1606 });
1607
1608 // Test appending to existing user message
1609 let message_1_id = UserMessageId::new();
1610 thread.update(cx, |thread, cx| {
1611 thread.push_user_content_block(
1612 Some(message_1_id.clone()),
1613 acp::ContentBlock::Text(acp::TextContent {
1614 annotations: None,
1615 text: "world!".to_string(),
1616 }),
1617 cx,
1618 );
1619 });
1620
1621 thread.update(cx, |thread, cx| {
1622 assert_eq!(thread.entries.len(), 1);
1623 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1624 assert_eq!(user_msg.id, Some(message_1_id));
1625 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1626 } else {
1627 panic!("Expected UserMessage");
1628 }
1629 });
1630
1631 // Test creating new user message after assistant message
1632 thread.update(cx, |thread, cx| {
1633 thread.push_assistant_content_block(
1634 acp::ContentBlock::Text(acp::TextContent {
1635 annotations: None,
1636 text: "Assistant response".to_string(),
1637 }),
1638 false,
1639 cx,
1640 );
1641 });
1642
1643 let message_2_id = UserMessageId::new();
1644 thread.update(cx, |thread, cx| {
1645 thread.push_user_content_block(
1646 Some(message_2_id.clone()),
1647 acp::ContentBlock::Text(acp::TextContent {
1648 annotations: None,
1649 text: "New user message".to_string(),
1650 }),
1651 cx,
1652 );
1653 });
1654
1655 thread.update(cx, |thread, cx| {
1656 assert_eq!(thread.entries.len(), 3);
1657 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1658 assert_eq!(user_msg.id, Some(message_2_id));
1659 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1660 } else {
1661 panic!("Expected UserMessage at index 2");
1662 }
1663 });
1664 }
1665
1666 #[gpui::test]
1667 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1668 init_test(cx);
1669
1670 let fs = FakeFs::new(cx.executor());
1671 let project = Project::test(fs, [], cx).await;
1672 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1673 |_, thread, mut cx| {
1674 async move {
1675 thread.update(&mut cx, |thread, cx| {
1676 thread
1677 .handle_session_update(
1678 acp::SessionUpdate::AgentThoughtChunk {
1679 content: "Thinking ".into(),
1680 },
1681 cx,
1682 )
1683 .unwrap();
1684 thread
1685 .handle_session_update(
1686 acp::SessionUpdate::AgentThoughtChunk {
1687 content: "hard!".into(),
1688 },
1689 cx,
1690 )
1691 .unwrap();
1692 })?;
1693 Ok(acp::PromptResponse {
1694 stop_reason: acp::StopReason::EndTurn,
1695 })
1696 }
1697 .boxed_local()
1698 },
1699 ));
1700
1701 let thread = cx
1702 .spawn(async move |mut cx| {
1703 connection
1704 .new_thread(project, Path::new(path!("/test")), &mut cx)
1705 .await
1706 })
1707 .await
1708 .unwrap();
1709
1710 thread
1711 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1712 .await
1713 .unwrap();
1714
1715 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1716 assert_eq!(
1717 output,
1718 indoc! {r#"
1719 ## User
1720
1721 Hello from Zed!
1722
1723 ## Assistant
1724
1725 <thinking>
1726 Thinking hard!
1727 </thinking>
1728
1729 "#}
1730 );
1731 }
1732
1733 #[gpui::test]
1734 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1735 init_test(cx);
1736
1737 let fs = FakeFs::new(cx.executor());
1738 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1739 .await;
1740 let project = Project::test(fs.clone(), [], cx).await;
1741 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1742 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1743 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1744 move |_, thread, mut cx| {
1745 let read_file_tx = read_file_tx.clone();
1746 async move {
1747 let content = thread
1748 .update(&mut cx, |thread, cx| {
1749 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1750 })
1751 .unwrap()
1752 .await
1753 .unwrap();
1754 assert_eq!(content, "one\ntwo\nthree\n");
1755 read_file_tx.take().unwrap().send(()).unwrap();
1756 thread
1757 .update(&mut cx, |thread, cx| {
1758 thread.write_text_file(
1759 path!("/tmp/foo").into(),
1760 "one\ntwo\nthree\nfour\nfive\n".to_string(),
1761 cx,
1762 )
1763 })
1764 .unwrap()
1765 .await
1766 .unwrap();
1767 Ok(acp::PromptResponse {
1768 stop_reason: acp::StopReason::EndTurn,
1769 })
1770 }
1771 .boxed_local()
1772 },
1773 ));
1774
1775 let (worktree, pathbuf) = project
1776 .update(cx, |project, cx| {
1777 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1778 })
1779 .await
1780 .unwrap();
1781 let buffer = project
1782 .update(cx, |project, cx| {
1783 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1784 })
1785 .await
1786 .unwrap();
1787
1788 let thread = cx
1789 .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1790 .await
1791 .unwrap();
1792
1793 let request = thread.update(cx, |thread, cx| {
1794 thread.send_raw("Extend the count in /tmp/foo", cx)
1795 });
1796 read_file_rx.await.ok();
1797 buffer.update(cx, |buffer, cx| {
1798 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1799 });
1800 cx.run_until_parked();
1801 assert_eq!(
1802 buffer.read_with(cx, |buffer, _| buffer.text()),
1803 "zero\none\ntwo\nthree\nfour\nfive\n"
1804 );
1805 assert_eq!(
1806 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1807 "zero\none\ntwo\nthree\nfour\nfive\n"
1808 );
1809 request.await.unwrap();
1810 }
1811
1812 #[gpui::test]
1813 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1814 init_test(cx);
1815
1816 let fs = FakeFs::new(cx.executor());
1817 let project = Project::test(fs, [], cx).await;
1818 let id = acp::ToolCallId("test".into());
1819
1820 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1821 let id = id.clone();
1822 move |_, thread, mut cx| {
1823 let id = id.clone();
1824 async move {
1825 thread
1826 .update(&mut cx, |thread, cx| {
1827 thread.handle_session_update(
1828 acp::SessionUpdate::ToolCall(acp::ToolCall {
1829 id: id.clone(),
1830 title: "Label".into(),
1831 kind: acp::ToolKind::Fetch,
1832 status: acp::ToolCallStatus::InProgress,
1833 content: vec![],
1834 locations: vec![],
1835 raw_input: None,
1836 raw_output: None,
1837 }),
1838 cx,
1839 )
1840 })
1841 .unwrap()
1842 .unwrap();
1843 Ok(acp::PromptResponse {
1844 stop_reason: acp::StopReason::EndTurn,
1845 })
1846 }
1847 .boxed_local()
1848 }
1849 }));
1850
1851 let thread = cx
1852 .spawn(async move |mut cx| {
1853 connection
1854 .new_thread(project, Path::new(path!("/test")), &mut cx)
1855 .await
1856 })
1857 .await
1858 .unwrap();
1859
1860 let request = thread.update(cx, |thread, cx| {
1861 thread.send_raw("Fetch https://example.com", cx)
1862 });
1863
1864 run_until_first_tool_call(&thread, cx).await;
1865
1866 thread.read_with(cx, |thread, _| {
1867 assert!(matches!(
1868 thread.entries[1],
1869 AgentThreadEntry::ToolCall(ToolCall {
1870 status: ToolCallStatus::Allowed {
1871 status: acp::ToolCallStatus::InProgress,
1872 ..
1873 },
1874 ..
1875 })
1876 ));
1877 });
1878
1879 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1880
1881 thread.read_with(cx, |thread, _| {
1882 assert!(matches!(
1883 &thread.entries[1],
1884 AgentThreadEntry::ToolCall(ToolCall {
1885 status: ToolCallStatus::Canceled,
1886 ..
1887 })
1888 ));
1889 });
1890
1891 thread
1892 .update(cx, |thread, cx| {
1893 thread.handle_session_update(
1894 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1895 id,
1896 fields: acp::ToolCallUpdateFields {
1897 status: Some(acp::ToolCallStatus::Completed),
1898 ..Default::default()
1899 },
1900 }),
1901 cx,
1902 )
1903 })
1904 .unwrap();
1905
1906 request.await.unwrap();
1907
1908 thread.read_with(cx, |thread, _| {
1909 assert!(matches!(
1910 thread.entries[1],
1911 AgentThreadEntry::ToolCall(ToolCall {
1912 status: ToolCallStatus::Allowed {
1913 status: acp::ToolCallStatus::Completed,
1914 ..
1915 },
1916 ..
1917 })
1918 ));
1919 });
1920 }
1921
1922 #[gpui::test]
1923 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1924 init_test(cx);
1925 let fs = FakeFs::new(cx.background_executor.clone());
1926 fs.insert_tree(path!("/test"), json!({})).await;
1927 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1928
1929 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1930 move |_, thread, mut cx| {
1931 async move {
1932 thread
1933 .update(&mut cx, |thread, cx| {
1934 thread.handle_session_update(
1935 acp::SessionUpdate::ToolCall(acp::ToolCall {
1936 id: acp::ToolCallId("test".into()),
1937 title: "Label".into(),
1938 kind: acp::ToolKind::Edit,
1939 status: acp::ToolCallStatus::Completed,
1940 content: vec![acp::ToolCallContent::Diff {
1941 diff: acp::Diff {
1942 path: "/test/test.txt".into(),
1943 old_text: None,
1944 new_text: "foo".into(),
1945 },
1946 }],
1947 locations: vec![],
1948 raw_input: None,
1949 raw_output: None,
1950 }),
1951 cx,
1952 )
1953 })
1954 .unwrap()
1955 .unwrap();
1956 Ok(acp::PromptResponse {
1957 stop_reason: acp::StopReason::EndTurn,
1958 })
1959 }
1960 .boxed_local()
1961 }
1962 }));
1963
1964 let thread = connection
1965 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1966 .await
1967 .unwrap();
1968 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1969 .await
1970 .unwrap();
1971
1972 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1973 }
1974
1975 #[gpui::test(iterations = 10)]
1976 async fn test_checkpoints(cx: &mut TestAppContext) {
1977 init_test(cx);
1978 let fs = FakeFs::new(cx.background_executor.clone());
1979 fs.insert_tree(
1980 path!("/test"),
1981 json!({
1982 ".git": {}
1983 }),
1984 )
1985 .await;
1986 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1987
1988 let simulate_changes = Arc::new(AtomicBool::new(true));
1989 let next_filename = Arc::new(AtomicUsize::new(0));
1990 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1991 let simulate_changes = simulate_changes.clone();
1992 let next_filename = next_filename.clone();
1993 let fs = fs.clone();
1994 move |request, thread, mut cx| {
1995 let fs = fs.clone();
1996 let simulate_changes = simulate_changes.clone();
1997 let next_filename = next_filename.clone();
1998 async move {
1999 if simulate_changes.load(SeqCst) {
2000 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2001 fs.write(Path::new(&filename), b"").await?;
2002 }
2003
2004 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2005 panic!("expected text content block");
2006 };
2007 thread.update(&mut cx, |thread, cx| {
2008 thread
2009 .handle_session_update(
2010 acp::SessionUpdate::AgentMessageChunk {
2011 content: content.text.to_uppercase().into(),
2012 },
2013 cx,
2014 )
2015 .unwrap();
2016 })?;
2017 Ok(acp::PromptResponse {
2018 stop_reason: acp::StopReason::EndTurn,
2019 })
2020 }
2021 .boxed_local()
2022 }
2023 }));
2024 let thread = connection
2025 .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
2026 .await
2027 .unwrap();
2028
2029 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2030 .await
2031 .unwrap();
2032 thread.read_with(cx, |thread, cx| {
2033 assert_eq!(
2034 thread.to_markdown(cx),
2035 indoc! {"
2036 ## User (checkpoint)
2037
2038 Lorem
2039
2040 ## Assistant
2041
2042 LOREM
2043
2044 "}
2045 );
2046 });
2047 assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
2048
2049 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2050 .await
2051 .unwrap();
2052 thread.read_with(cx, |thread, cx| {
2053 assert_eq!(
2054 thread.to_markdown(cx),
2055 indoc! {"
2056 ## User (checkpoint)
2057
2058 Lorem
2059
2060 ## Assistant
2061
2062 LOREM
2063
2064 ## User (checkpoint)
2065
2066 ipsum
2067
2068 ## Assistant
2069
2070 IPSUM
2071
2072 "}
2073 );
2074 });
2075 assert_eq!(
2076 fs.files(),
2077 vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
2078 );
2079
2080 // Checkpoint isn't stored when there are no changes.
2081 simulate_changes.store(false, SeqCst);
2082 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2083 .await
2084 .unwrap();
2085 thread.read_with(cx, |thread, cx| {
2086 assert_eq!(
2087 thread.to_markdown(cx),
2088 indoc! {"
2089 ## User (checkpoint)
2090
2091 Lorem
2092
2093 ## Assistant
2094
2095 LOREM
2096
2097 ## User (checkpoint)
2098
2099 ipsum
2100
2101 ## Assistant
2102
2103 IPSUM
2104
2105 ## User
2106
2107 dolor
2108
2109 ## Assistant
2110
2111 DOLOR
2112
2113 "}
2114 );
2115 });
2116 assert_eq!(
2117 fs.files(),
2118 vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
2119 );
2120
2121 // Rewinding the conversation truncates the history and restores the checkpoint.
2122 thread
2123 .update(cx, |thread, cx| {
2124 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2125 panic!("unexpected entries {:?}", thread.entries)
2126 };
2127 thread.rewind(message.id.clone().unwrap(), cx)
2128 })
2129 .await
2130 .unwrap();
2131 thread.read_with(cx, |thread, cx| {
2132 assert_eq!(
2133 thread.to_markdown(cx),
2134 indoc! {"
2135 ## User (checkpoint)
2136
2137 Lorem
2138
2139 ## Assistant
2140
2141 LOREM
2142
2143 "}
2144 );
2145 });
2146 assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
2147 }
2148
2149 async fn run_until_first_tool_call(
2150 thread: &Entity<AcpThread>,
2151 cx: &mut TestAppContext,
2152 ) -> usize {
2153 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2154
2155 let subscription = cx.update(|cx| {
2156 cx.subscribe(thread, move |thread, _, cx| {
2157 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2158 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2159 return tx.try_send(ix).unwrap();
2160 }
2161 }
2162 })
2163 });
2164
2165 select! {
2166 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2167 panic!("Timeout waiting for tool call")
2168 }
2169 ix = rx.next().fuse() => {
2170 drop(subscription);
2171 ix.unwrap()
2172 }
2173 }
2174 }
2175
2176 #[derive(Clone, Default)]
2177 struct FakeAgentConnection {
2178 auth_methods: Vec<acp::AuthMethod>,
2179 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2180 on_user_message: Option<
2181 Rc<
2182 dyn Fn(
2183 acp::PromptRequest,
2184 WeakEntity<AcpThread>,
2185 AsyncApp,
2186 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2187 + 'static,
2188 >,
2189 >,
2190 }
2191
2192 impl FakeAgentConnection {
2193 fn new() -> Self {
2194 Self {
2195 auth_methods: Vec::new(),
2196 on_user_message: None,
2197 sessions: Arc::default(),
2198 }
2199 }
2200
2201 #[expect(unused)]
2202 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2203 self.auth_methods = auth_methods;
2204 self
2205 }
2206
2207 fn on_user_message(
2208 mut self,
2209 handler: impl Fn(
2210 acp::PromptRequest,
2211 WeakEntity<AcpThread>,
2212 AsyncApp,
2213 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2214 + 'static,
2215 ) -> Self {
2216 self.on_user_message.replace(Rc::new(handler));
2217 self
2218 }
2219 }
2220
2221 impl AgentConnection for FakeAgentConnection {
2222 fn auth_methods(&self) -> &[acp::AuthMethod] {
2223 &self.auth_methods
2224 }
2225
2226 fn new_thread(
2227 self: Rc<Self>,
2228 project: Entity<Project>,
2229 _cwd: &Path,
2230 cx: &mut gpui::AsyncApp,
2231 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2232 let session_id = acp::SessionId(
2233 rand::thread_rng()
2234 .sample_iter(&rand::distributions::Alphanumeric)
2235 .take(7)
2236 .map(char::from)
2237 .collect::<String>()
2238 .into(),
2239 );
2240 let thread = cx
2241 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
2242 .unwrap();
2243 self.sessions.lock().insert(session_id, thread.downgrade());
2244 Task::ready(Ok(thread))
2245 }
2246
2247 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2248 if self.auth_methods().iter().any(|m| m.id == method) {
2249 Task::ready(Ok(()))
2250 } else {
2251 Task::ready(Err(anyhow!("Invalid Auth Method")))
2252 }
2253 }
2254
2255 fn prompt(
2256 &self,
2257 _id: Option<UserMessageId>,
2258 params: acp::PromptRequest,
2259 cx: &mut App,
2260 ) -> Task<gpui::Result<acp::PromptResponse>> {
2261 let sessions = self.sessions.lock();
2262 let thread = sessions.get(¶ms.session_id).unwrap();
2263 if let Some(handler) = &self.on_user_message {
2264 let handler = handler.clone();
2265 let thread = thread.clone();
2266 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2267 } else {
2268 Task::ready(Ok(acp::PromptResponse {
2269 stop_reason: acp::StopReason::EndTurn,
2270 }))
2271 }
2272 }
2273
2274 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2275 let sessions = self.sessions.lock();
2276 let thread = sessions.get(&session_id).unwrap().clone();
2277
2278 cx.spawn(async move |cx| {
2279 thread
2280 .update(cx, |thread, cx| thread.cancel(cx))
2281 .unwrap()
2282 .await
2283 })
2284 .detach();
2285 }
2286
2287 fn session_editor(
2288 &self,
2289 session_id: &acp::SessionId,
2290 _cx: &mut App,
2291 ) -> Option<Rc<dyn AgentSessionEditor>> {
2292 Some(Rc::new(FakeAgentSessionEditor {
2293 _session_id: session_id.clone(),
2294 }))
2295 }
2296 }
2297
2298 struct FakeAgentSessionEditor {
2299 _session_id: acp::SessionId,
2300 }
2301
2302 impl AgentSessionEditor for FakeAgentSessionEditor {
2303 fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2304 Task::ready(Ok(()))
2305 }
2306 }
2307}