1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
28use ui::App;
29use util::ResultExt;
30
/// A message authored by the user, as stored in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier assigned to the message; `send` only sets it when the connection
    /// returns a session editor, and streamed chunks may supply one later.
    pub id: Option<UserMessageId>,
    /// Rendered content built by appending every chunk.
    pub content: ContentBlock,
    /// The raw ACP blocks that make up this message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint taken before this message was sent; `send` only stores it when
    /// the working tree changed during the turn.
    pub checkpoint: Option<GitStoreCheckpoint>,
}
38
39impl UserMessage {
40 fn to_markdown(&self, cx: &App) -> String {
41 let mut markdown = String::new();
42 if let Some(_) = self.checkpoint {
43 writeln!(markdown, "## User (checkpoint)").unwrap();
44 } else {
45 writeln!(markdown, "## User").unwrap();
46 }
47 writeln!(markdown).unwrap();
48 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
49 writeln!(markdown).unwrap();
50 markdown
51 }
52}
53
/// A message produced by the agent, made of consecutive message/thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; consecutive chunks of the same kind are merged when
    /// they are pushed onto the thread.
    pub chunks: Vec<AssistantMessageChunk>,
}
58
59impl AssistantMessage {
60 pub fn to_markdown(&self, cx: &App) -> String {
61 format!(
62 "## Assistant\n\n{}\n\n",
63 self.chunks
64 .iter()
65 .map(|chunk| chunk.to_markdown(cx))
66 .join("\n\n")
67 )
68 }
69}
70
/// One segment of an assistant message: visible output or the agent's reasoning.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message text shown to the user.
    Message { block: ContentBlock },
    /// Reasoning ("thinking") text; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
76
77impl AssistantMessageChunk {
78 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
79 Self::Message {
80 block: ContentBlock::new(chunk.into(), language_registry, cx),
81 }
82 }
83
84 fn to_markdown(&self, cx: &App) -> String {
85 match self {
86 Self::Message { block } => block.to_markdown(cx).to_string(),
87 Self::Thought { block } => {
88 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
89 }
90 }
91 }
92}
93
/// A single entry in an agent thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// Content sent by the user.
    UserMessage(UserMessage),
    /// Content streamed back by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
100
101impl AgentThreadEntry {
102 fn to_markdown(&self, cx: &App) -> String {
103 match self {
104 Self::UserMessage(message) => message.to_markdown(cx),
105 Self::AssistantMessage(message) => message.to_markdown(cx),
106 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
107 }
108 }
109
110 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
111 if let AgentThreadEntry::ToolCall(call) = self {
112 itertools::Either::Left(call.diffs())
113 } else {
114 itertools::Either::Right(std::iter::empty())
115 }
116 }
117
118 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
119 if let AgentThreadEntry::ToolCall(call) = self {
120 itertools::Either::Left(call.terminals())
121 } else {
122 itertools::Either::Right(std::iter::empty())
123 }
124 }
125
126 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
127 if let AgentThreadEntry::ToolCall(ToolCall {
128 locations,
129 resolved_locations,
130 ..
131 }) = self
132 {
133 Some((
134 locations.get(ix)?.clone(),
135 resolved_locations.get(ix)?.clone()?,
136 ))
137 } else {
138 None
139 }
140 }
141}
142
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Agent-assigned identifier used to match later updates.
    pub id: acp::ToolCallId,
    /// Tool call title, rendered as Markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content produced by the tool (text blocks, diffs or terminals).
    pub content: Vec<ToolCallContent>,
    /// Authorization / execution state of the call.
    pub status: ToolCallStatus,
    /// Locations reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Index-aligned with `locations`; `None` for entries that could not be resolved.
    /// Filled in asynchronously, so it may be shorter than `locations` or stale.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
155
156impl ToolCall {
157 fn from_acp(
158 tool_call: acp::ToolCall,
159 status: ToolCallStatus,
160 language_registry: Arc<LanguageRegistry>,
161 cx: &mut App,
162 ) -> Self {
163 Self {
164 id: tool_call.id,
165 label: cx.new(|cx| {
166 Markdown::new(
167 tool_call.title.into(),
168 Some(language_registry.clone()),
169 None,
170 cx,
171 )
172 }),
173 kind: tool_call.kind,
174 content: tool_call
175 .content
176 .into_iter()
177 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
178 .collect(),
179 locations: tool_call.locations,
180 resolved_locations: Vec::default(),
181 status,
182 raw_input: tool_call.raw_input,
183 raw_output: tool_call.raw_output,
184 }
185 }
186
187 fn update_fields(
188 &mut self,
189 fields: acp::ToolCallUpdateFields,
190 language_registry: Arc<LanguageRegistry>,
191 cx: &mut App,
192 ) {
193 let acp::ToolCallUpdateFields {
194 kind,
195 status,
196 title,
197 content,
198 locations,
199 raw_input,
200 raw_output,
201 } = fields;
202
203 if let Some(kind) = kind {
204 self.kind = kind;
205 }
206
207 if let Some(status) = status {
208 self.status = ToolCallStatus::Allowed { status };
209 }
210
211 if let Some(title) = title {
212 self.label.update(cx, |label, cx| {
213 label.replace(title, cx);
214 });
215 }
216
217 if let Some(content) = content {
218 self.content = content
219 .into_iter()
220 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
221 .collect();
222 }
223
224 if let Some(locations) = locations {
225 self.locations = locations;
226 }
227
228 if let Some(raw_input) = raw_input {
229 self.raw_input = Some(raw_input);
230 }
231
232 if let Some(raw_output) = raw_output {
233 if self.content.is_empty() {
234 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
235 {
236 self.content
237 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
238 markdown,
239 }));
240 }
241 }
242 self.raw_output = Some(raw_output);
243 }
244 }
245
246 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
247 self.content.iter().filter_map(|content| match content {
248 ToolCallContent::Diff(diff) => Some(diff),
249 ToolCallContent::ContentBlock(_) => None,
250 ToolCallContent::Terminal(_) => None,
251 })
252 }
253
254 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
255 self.content.iter().filter_map(|content| match content {
256 ToolCallContent::Terminal(terminal) => Some(terminal),
257 ToolCallContent::ContentBlock(_) => None,
258 ToolCallContent::Diff(_) => None,
259 })
260 }
261
262 fn to_markdown(&self, cx: &App) -> String {
263 let mut markdown = format!(
264 "**Tool Call: {}**\nStatus: {}\n\n",
265 self.label.read(cx).source(),
266 self.status
267 );
268 for content in &self.content {
269 markdown.push_str(content.to_markdown(cx).as_str());
270 markdown.push_str("\n\n");
271 }
272 markdown
273 }
274
275 async fn resolve_location(
276 location: acp::ToolCallLocation,
277 project: WeakEntity<Project>,
278 cx: &mut AsyncApp,
279 ) -> Option<AgentLocation> {
280 let buffer = project
281 .update(cx, |project, cx| {
282 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
283 Some(project.open_buffer(path, cx))
284 } else {
285 None
286 }
287 })
288 .ok()??;
289 let buffer = buffer.await.log_err()?;
290 let position = buffer
291 .update(cx, |buffer, _| {
292 if let Some(row) = location.line {
293 let snapshot = buffer.snapshot();
294 let column = snapshot.indent_size_for_line(row).len;
295 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
296 snapshot.anchor_before(point)
297 } else {
298 Anchor::MIN
299 }
300 })
301 .ok()?;
302
303 Some(AgentLocation {
304 buffer: buffer.downgrade(),
305 position,
306 })
307 }
308
309 fn resolve_locations(
310 &self,
311 project: Entity<Project>,
312 cx: &mut App,
313 ) -> Task<Vec<Option<AgentLocation>>> {
314 let locations = self.locations.clone();
315 project.update(cx, |_, cx| {
316 cx.spawn(async move |project, cx| {
317 let mut new_locations = Vec::new();
318 for location in locations {
319 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
320 }
321 new_locations
322 })
323 })
324 }
325}
326
/// Authorization and execution state of a tool call.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The user must choose one of `options`; the chosen option id is sent through
    /// `respond_tx` by `AcpThread::authorize_tool_call`.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The call is allowed to run; `status` is the execution status reported over ACP.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user rejected the call.
    Rejected,
    /// The call was canceled.
    Canceled,
}
339
340impl Display for ToolCallStatus {
341 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
342 write!(
343 f,
344 "{}",
345 match self {
346 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
347 ToolCallStatus::Allowed { status } => match status {
348 acp::ToolCallStatus::Pending => "Pending",
349 acp::ToolCallStatus::InProgress => "In Progress",
350 acp::ToolCallStatus::Completed => "Completed",
351 acp::ToolCallStatus::Failed => "Failed",
352 },
353 ToolCallStatus::Rejected => "Rejected",
354 ToolCallStatus::Canceled => "Canceled",
355 }
356 )
357 }
358}
359
/// Renderable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text (and flattened non-text blocks) rendered as Markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link that arrived into an empty block, kept structured.
    ResourceLink { resource_link: acp::ResourceLink },
}
366
367impl ContentBlock {
368 pub fn new(
369 block: acp::ContentBlock,
370 language_registry: &Arc<LanguageRegistry>,
371 cx: &mut App,
372 ) -> Self {
373 let mut this = Self::Empty;
374 this.append(block, language_registry, cx);
375 this
376 }
377
378 pub fn new_combined(
379 blocks: impl IntoIterator<Item = acp::ContentBlock>,
380 language_registry: Arc<LanguageRegistry>,
381 cx: &mut App,
382 ) -> Self {
383 let mut this = Self::Empty;
384 for block in blocks {
385 this.append(block, &language_registry, cx);
386 }
387 this
388 }
389
390 pub fn append(
391 &mut self,
392 block: acp::ContentBlock,
393 language_registry: &Arc<LanguageRegistry>,
394 cx: &mut App,
395 ) {
396 if matches!(self, ContentBlock::Empty) {
397 if let acp::ContentBlock::ResourceLink(resource_link) = block {
398 *self = ContentBlock::ResourceLink { resource_link };
399 return;
400 }
401 }
402
403 let new_content = self.block_string_contents(block);
404
405 match self {
406 ContentBlock::Empty => {
407 *self = Self::create_markdown_block(new_content, language_registry, cx);
408 }
409 ContentBlock::Markdown { markdown } => {
410 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
411 }
412 ContentBlock::ResourceLink { resource_link } => {
413 let existing_content = Self::resource_link_md(&resource_link.uri);
414 let combined = format!("{}\n{}", existing_content, new_content);
415
416 *self = Self::create_markdown_block(combined, language_registry, cx);
417 }
418 }
419 }
420
421 fn create_markdown_block(
422 content: String,
423 language_registry: &Arc<LanguageRegistry>,
424 cx: &mut App,
425 ) -> ContentBlock {
426 ContentBlock::Markdown {
427 markdown: cx
428 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
429 }
430 }
431
432 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
433 match block {
434 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
435 acp::ContentBlock::ResourceLink(resource_link) => {
436 Self::resource_link_md(&resource_link.uri)
437 }
438 acp::ContentBlock::Resource(acp::EmbeddedResource {
439 resource:
440 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
441 uri,
442 ..
443 }),
444 ..
445 }) => Self::resource_link_md(&uri),
446 acp::ContentBlock::Image(image) => Self::image_md(&image),
447 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
448 }
449 }
450
451 fn resource_link_md(uri: &str) -> String {
452 if let Some(uri) = MentionUri::parse(&uri).log_err() {
453 uri.as_link().to_string()
454 } else {
455 uri.to_string()
456 }
457 }
458
459 fn image_md(_image: &acp::ImageContent) -> String {
460 "`Image`".into()
461 }
462
463 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
464 match self {
465 ContentBlock::Empty => "",
466 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
467 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
468 }
469 }
470
471 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
472 match self {
473 ContentBlock::Empty => None,
474 ContentBlock::Markdown { markdown } => Some(markdown),
475 ContentBlock::ResourceLink { .. } => None,
476 }
477 }
478
479 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
480 match self {
481 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
482 _ => None,
483 }
484 }
485}
486
/// One item of a tool call's output.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Textual/Markdown content.
    ContentBlock(ContentBlock),
    /// A file diff produced by the tool.
    Diff(Entity<Diff>),
    /// A terminal associated with the tool call.
    Terminal(Entity<Terminal>),
}
493
494impl ToolCallContent {
495 pub fn from_acp(
496 content: acp::ToolCallContent,
497 language_registry: Arc<LanguageRegistry>,
498 cx: &mut App,
499 ) -> Self {
500 match content {
501 acp::ToolCallContent::Content { content } => {
502 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
503 }
504 acp::ToolCallContent::Diff { diff } => {
505 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
506 }
507 }
508 }
509
510 pub fn to_markdown(&self, cx: &App) -> String {
511 match self {
512 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
513 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
514 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
515 }
516 }
517}
518
/// An update to apply to an existing tool call (see `AcpThread::update_tool_call`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Partial field update received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
525
526impl ToolCallUpdate {
527 fn id(&self) -> &acp::ToolCallId {
528 match self {
529 Self::UpdateFields(update) => &update.id,
530 Self::UpdateDiff(diff) => &diff.id,
531 Self::UpdateTerminal(terminal) => &terminal.id,
532 }
533 }
534}
535
536impl From<acp::ToolCallUpdate> for ToolCallUpdate {
537 fn from(update: acp::ToolCallUpdate) -> Self {
538 Self::UpdateFields(update)
539 }
540}
541
542impl From<ToolCallUpdateDiff> for ToolCallUpdate {
543 fn from(diff: ToolCallUpdateDiff) -> Self {
544 Self::UpdateDiff(diff)
545 }
546}
547
/// Replaces the content of tool call `id` with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}
553
554impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
555 fn from(terminal: ToolCallUpdateTerminal) -> Self {
556 Self::UpdateTerminal(terminal)
557 }
558}
559
/// Replaces the content of tool call `id` with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
565
/// The agent's current execution plan.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
570
/// Summary counts for a [`Plan`], produced by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
577
578impl Plan {
579 pub fn is_empty(&self) -> bool {
580 self.entries.is_empty()
581 }
582
583 pub fn stats(&self) -> PlanStats<'_> {
584 let mut stats = PlanStats {
585 in_progress_entry: None,
586 pending: 0,
587 completed: 0,
588 };
589
590 for entry in &self.entries {
591 match &entry.status {
592 acp::PlanEntryStatus::Pending => {
593 stats.pending += 1;
594 }
595 acp::PlanEntryStatus::InProgress => {
596 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
597 }
598 acp::PlanEntryStatus::Completed => {
599 stats.completed += 1;
600 }
601 }
602 }
603
604 stats
605 }
606}
607
/// One step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// Step description, rendered as Markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
614
615impl PlanEntry {
616 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
617 Self {
618 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
619 priority: entry.priority,
620 status: entry.status,
621 }
622 }
623}
624
/// A conversation with an agent over the Agent Client Protocol.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// Latest plan reported by the agent.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably buffers
    // shared with the agent and their snapshots — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while a prompt is in flight (see `status`).
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
636
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A prompt turn finished without an error.
    Stopped,
    /// A prompt turn failed.
    Error,
    /// The agent server process exited with this status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
648
/// Coarse state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt is in flight.
    Idle,
    /// A prompt is in flight and a tool call in the current turn awaits permission.
    WaitingForToolConfirmation,
    /// A prompt is in flight.
    Generating,
}
655
/// Errors that can occur while loading an agent.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent is not supported as installed; carries an upgrade hint and command.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The server exited; the value is shown as its exit status.
    Exited(i32),
    /// Any other error, with a displayable message.
    Other(SharedString),
}
666
667impl Display for LoadError {
668 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
669 match self {
670 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
671 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
672 LoadError::Other(msg) => write!(f, "{}", msg),
673 }
674 }
675}
676
677impl Error for LoadError {}
678
679impl AcpThread {
680 pub fn new(
681 title: impl Into<SharedString>,
682 connection: Rc<dyn AgentConnection>,
683 project: Entity<Project>,
684 session_id: acp::SessionId,
685 cx: &mut Context<Self>,
686 ) -> Self {
687 let action_log = cx.new(|_| ActionLog::new(project.clone()));
688
689 Self {
690 action_log,
691 shared_buffers: Default::default(),
692 entries: Default::default(),
693 plan: Default::default(),
694 title: title.into(),
695 project,
696 send_task: None,
697 connection,
698 session_id,
699 }
700 }
701
702 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
703 &self.connection
704 }
705
706 pub fn action_log(&self) -> &Entity<ActionLog> {
707 &self.action_log
708 }
709
710 pub fn project(&self) -> &Entity<Project> {
711 &self.project
712 }
713
714 pub fn title(&self) -> SharedString {
715 self.title.clone()
716 }
717
718 pub fn entries(&self) -> &[AgentThreadEntry] {
719 &self.entries
720 }
721
722 pub fn session_id(&self) -> &acp::SessionId {
723 &self.session_id
724 }
725
726 pub fn status(&self) -> ThreadStatus {
727 if self.send_task.is_some() {
728 if self.waiting_for_tool_confirmation() {
729 ThreadStatus::WaitingForToolConfirmation
730 } else {
731 ThreadStatus::Generating
732 }
733 } else {
734 ThreadStatus::Idle
735 }
736 }
737
738 pub fn has_pending_edit_tool_calls(&self) -> bool {
739 for entry in self.entries.iter().rev() {
740 match entry {
741 AgentThreadEntry::UserMessage(_) => return false,
742 AgentThreadEntry::ToolCall(
743 call @ ToolCall {
744 status:
745 ToolCallStatus::Allowed {
746 status:
747 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
748 },
749 ..
750 },
751 ) if call.diffs().next().is_some() => {
752 return true;
753 }
754 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
755 }
756 }
757
758 false
759 }
760
761 pub fn used_tools_since_last_user_message(&self) -> bool {
762 for entry in self.entries.iter().rev() {
763 match entry {
764 AgentThreadEntry::UserMessage(..) => return false,
765 AgentThreadEntry::AssistantMessage(..) => continue,
766 AgentThreadEntry::ToolCall(..) => return true,
767 }
768 }
769
770 false
771 }
772
773 pub fn handle_session_update(
774 &mut self,
775 update: acp::SessionUpdate,
776 cx: &mut Context<Self>,
777 ) -> Result<()> {
778 match update {
779 acp::SessionUpdate::UserMessageChunk { content } => {
780 self.push_user_content_block(None, content, cx);
781 }
782 acp::SessionUpdate::AgentMessageChunk { content } => {
783 self.push_assistant_content_block(content, false, cx);
784 }
785 acp::SessionUpdate::AgentThoughtChunk { content } => {
786 self.push_assistant_content_block(content, true, cx);
787 }
788 acp::SessionUpdate::ToolCall(tool_call) => {
789 self.upsert_tool_call(tool_call, cx);
790 }
791 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
792 self.update_tool_call(tool_call_update, cx)?;
793 }
794 acp::SessionUpdate::Plan(plan) => {
795 self.update_plan(plan, cx);
796 }
797 }
798 Ok(())
799 }
800
801 pub fn push_user_content_block(
802 &mut self,
803 message_id: Option<UserMessageId>,
804 chunk: acp::ContentBlock,
805 cx: &mut Context<Self>,
806 ) {
807 let language_registry = self.project.read(cx).languages().clone();
808 let entries_len = self.entries.len();
809
810 if let Some(last_entry) = self.entries.last_mut()
811 && let AgentThreadEntry::UserMessage(UserMessage {
812 id,
813 content,
814 chunks,
815 ..
816 }) = last_entry
817 {
818 *id = message_id.or(id.take());
819 content.append(chunk.clone(), &language_registry, cx);
820 chunks.push(chunk);
821 let idx = entries_len - 1;
822 cx.emit(AcpThreadEvent::EntryUpdated(idx));
823 } else {
824 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
825 self.push_entry(
826 AgentThreadEntry::UserMessage(UserMessage {
827 id: message_id,
828 content,
829 chunks: vec![chunk],
830 checkpoint: None,
831 }),
832 cx,
833 );
834 }
835 }
836
837 pub fn push_assistant_content_block(
838 &mut self,
839 chunk: acp::ContentBlock,
840 is_thought: bool,
841 cx: &mut Context<Self>,
842 ) {
843 let language_registry = self.project.read(cx).languages().clone();
844 let entries_len = self.entries.len();
845 if let Some(last_entry) = self.entries.last_mut()
846 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
847 {
848 let idx = entries_len - 1;
849 cx.emit(AcpThreadEvent::EntryUpdated(idx));
850 match (chunks.last_mut(), is_thought) {
851 (Some(AssistantMessageChunk::Message { block }), false)
852 | (Some(AssistantMessageChunk::Thought { block }), true) => {
853 block.append(chunk, &language_registry, cx)
854 }
855 _ => {
856 let block = ContentBlock::new(chunk, &language_registry, cx);
857 if is_thought {
858 chunks.push(AssistantMessageChunk::Thought { block })
859 } else {
860 chunks.push(AssistantMessageChunk::Message { block })
861 }
862 }
863 }
864 } else {
865 let block = ContentBlock::new(chunk, &language_registry, cx);
866 let chunk = if is_thought {
867 AssistantMessageChunk::Thought { block }
868 } else {
869 AssistantMessageChunk::Message { block }
870 };
871
872 self.push_entry(
873 AgentThreadEntry::AssistantMessage(AssistantMessage {
874 chunks: vec![chunk],
875 }),
876 cx,
877 );
878 }
879 }
880
881 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
882 self.entries.push(entry);
883 cx.emit(AcpThreadEvent::NewEntry);
884 }
885
886 pub fn update_tool_call(
887 &mut self,
888 update: impl Into<ToolCallUpdate>,
889 cx: &mut Context<Self>,
890 ) -> Result<()> {
891 let update = update.into();
892 let languages = self.project.read(cx).languages().clone();
893
894 let (ix, current_call) = self
895 .tool_call_mut(update.id())
896 .context("Tool call not found")?;
897 match update {
898 ToolCallUpdate::UpdateFields(update) => {
899 let location_updated = update.fields.locations.is_some();
900 current_call.update_fields(update.fields, languages, cx);
901 if location_updated {
902 self.resolve_locations(update.id.clone(), cx);
903 }
904 }
905 ToolCallUpdate::UpdateDiff(update) => {
906 current_call.content.clear();
907 current_call
908 .content
909 .push(ToolCallContent::Diff(update.diff));
910 }
911 ToolCallUpdate::UpdateTerminal(update) => {
912 current_call.content.clear();
913 current_call
914 .content
915 .push(ToolCallContent::Terminal(update.terminal));
916 }
917 }
918
919 cx.emit(AcpThreadEvent::EntryUpdated(ix));
920
921 Ok(())
922 }
923
924 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
925 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
926 let status = ToolCallStatus::Allowed {
927 status: tool_call.status,
928 };
929 self.upsert_tool_call_inner(tool_call, status, cx)
930 }
931
932 pub fn upsert_tool_call_inner(
933 &mut self,
934 tool_call: acp::ToolCall,
935 status: ToolCallStatus,
936 cx: &mut Context<Self>,
937 ) {
938 let language_registry = self.project.read(cx).languages().clone();
939 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
940 let id = call.id.clone();
941
942 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
943 *current_call = call;
944
945 cx.emit(AcpThreadEvent::EntryUpdated(ix));
946 } else {
947 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
948 };
949
950 self.resolve_locations(id, cx);
951 }
952
953 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
954 // The tool call we are looking for is typically the last one, or very close to the end.
955 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
956 self.entries
957 .iter_mut()
958 .enumerate()
959 .rev()
960 .find_map(|(index, tool_call)| {
961 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
962 && &tool_call.id == id
963 {
964 Some((index, tool_call))
965 } else {
966 None
967 }
968 })
969 }
970
971 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
972 let project = self.project.clone();
973 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
974 return;
975 };
976 let task = tool_call.resolve_locations(project, cx);
977 cx.spawn(async move |this, cx| {
978 let resolved_locations = task.await;
979 this.update(cx, |this, cx| {
980 let project = this.project.clone();
981 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
982 return;
983 };
984 if let Some(Some(location)) = resolved_locations.last() {
985 project.update(cx, |project, cx| {
986 if let Some(agent_location) = project.agent_location() {
987 let should_ignore = agent_location.buffer == location.buffer
988 && location
989 .buffer
990 .update(cx, |buffer, _| {
991 let snapshot = buffer.snapshot();
992 let old_position =
993 agent_location.position.to_point(&snapshot);
994 let new_position = location.position.to_point(&snapshot);
995 // ignore this so that when we get updates from the edit tool
996 // the position doesn't reset to the startof line
997 old_position.row == new_position.row
998 && old_position.column > new_position.column
999 })
1000 .ok()
1001 .unwrap_or_default();
1002 if !should_ignore {
1003 project.set_agent_location(Some(location.clone()), cx);
1004 }
1005 }
1006 });
1007 }
1008 if tool_call.resolved_locations != resolved_locations {
1009 tool_call.resolved_locations = resolved_locations;
1010 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1011 }
1012 })
1013 })
1014 .detach();
1015 }
1016
1017 pub fn request_tool_call_authorization(
1018 &mut self,
1019 tool_call: acp::ToolCall,
1020 options: Vec<acp::PermissionOption>,
1021 cx: &mut Context<Self>,
1022 ) -> oneshot::Receiver<acp::PermissionOptionId> {
1023 let (tx, rx) = oneshot::channel();
1024
1025 let status = ToolCallStatus::WaitingForConfirmation {
1026 options,
1027 respond_tx: tx,
1028 };
1029
1030 self.upsert_tool_call_inner(tool_call, status, cx);
1031 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1032 rx
1033 }
1034
1035 pub fn authorize_tool_call(
1036 &mut self,
1037 id: acp::ToolCallId,
1038 option_id: acp::PermissionOptionId,
1039 option_kind: acp::PermissionOptionKind,
1040 cx: &mut Context<Self>,
1041 ) {
1042 let Some((ix, call)) = self.tool_call_mut(&id) else {
1043 return;
1044 };
1045
1046 let new_status = match option_kind {
1047 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1048 ToolCallStatus::Rejected
1049 }
1050 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1051 ToolCallStatus::Allowed {
1052 status: acp::ToolCallStatus::InProgress,
1053 }
1054 }
1055 };
1056
1057 let curr_status = mem::replace(&mut call.status, new_status);
1058
1059 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1060 respond_tx.send(option_id).log_err();
1061 } else if cfg!(debug_assertions) {
1062 panic!("tried to authorize an already authorized tool call");
1063 }
1064
1065 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1066 }
1067
1068 /// Returns true if the last turn is awaiting tool authorization
1069 pub fn waiting_for_tool_confirmation(&self) -> bool {
1070 for entry in self.entries.iter().rev() {
1071 match &entry {
1072 AgentThreadEntry::ToolCall(call) => match call.status {
1073 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1074 ToolCallStatus::Allowed { .. }
1075 | ToolCallStatus::Rejected
1076 | ToolCallStatus::Canceled => continue,
1077 },
1078 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1079 // Reached the beginning of the turn
1080 return false;
1081 }
1082 }
1083 }
1084 false
1085 }
1086
1087 pub fn plan(&self) -> &Plan {
1088 &self.plan
1089 }
1090
1091 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1092 let new_entries_len = request.entries.len();
1093 let mut new_entries = request.entries.into_iter();
1094
1095 // Reuse existing markdown to prevent flickering
1096 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1097 let PlanEntry {
1098 content,
1099 priority,
1100 status,
1101 } = old;
1102 content.update(cx, |old, cx| {
1103 old.replace(new.content, cx);
1104 });
1105 *priority = new.priority;
1106 *status = new.status;
1107 }
1108 for new in new_entries {
1109 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1110 }
1111 self.plan.entries.truncate(new_entries_len);
1112
1113 cx.notify();
1114 }
1115
1116 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1117 self.plan
1118 .entries
1119 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1120 cx.notify();
1121 }
1122
1123 #[cfg(any(test, feature = "test-support"))]
1124 pub fn send_raw(
1125 &mut self,
1126 message: &str,
1127 cx: &mut Context<Self>,
1128 ) -> BoxFuture<'static, Result<()>> {
1129 self.send(
1130 vec![acp::ContentBlock::Text(acp::TextContent {
1131 text: message.to_string(),
1132 annotations: None,
1133 })],
1134 cx,
1135 )
1136 }
1137
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Synchronously, this:
    /// - pushes a [`UserMessage`] entry;
    /// - clears completed plan entries;
    /// - requests a git checkpoint;
    /// - cancels any in-flight generation.
    ///
    /// The new send task waits for that cancellation, then issues the prompt. When the
    /// connection exposes a session editor, the message gets a [`UserMessageId`]. After
    /// the turn, a second checkpoint is compared with the first, and the first is
    /// attached to the message if the working tree changed.
    ///
    /// The returned future resolves when the prompt finishes. It emits
    /// [`AcpThreadEvent::Error`] and returns the error if the prompt failed. In every
    /// other case it emits [`AcpThreadEvent::Stopped`] and returns `Ok(())`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let git_store = self.project.read(cx).git_store().clone();

        // Checkpoint of the working tree before this turn; awaited inside the send task.
        let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        // Only assign an id when the connection exposes a session editor.
        let message_id = if self
            .connection
            .session_editor(&self.session_id, cx)
            .is_some()
        {
            Some(UserMessageId::new())
        } else {
            None
        };
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage {
                id: message_id.clone(),
                content: block,
                chunks: message.clone(),
                checkpoint: None,
            }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `old_checkpoint_*` hands the awaited checkpoint to the completion future below;
        // `tx`/`rx` carry the prompt result.
        let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);
        let request = acp::PromptRequest {
            prompt: message,
            session_id: self.session_id.clone(),
        };

        self.send_task = Some(cx.spawn({
            let message_id = message_id.clone();
            async move |this, cx| {
                // Wait for any previous generation to be canceled before prompting.
                cancel_task.await;

                old_checkpoint_tx.send(old_checkpoint.await).ok();
                if let Ok(result) = this.update(cx, |this, cx| {
                    this.connection.prompt(message_id, request, cx)
                }) {
                    tx.send(result.await).log_err();
                }
            }
        }));

        cx.spawn(async move |this, cx| {
            let old_checkpoint = old_checkpoint_rx
                .await
                .map_err(|_| anyhow!("send canceled"))
                .flatten()
                .context("failed to get old checkpoint")
                .log_err();

            let response = rx.await;

            // Only messages with an id can carry a checkpoint.
            if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
                let new_checkpoint = git_store
                    .update(cx, |git, cx| git.checkpoint(cx))?
                    .await
                    .context("failed to get new checkpoint")
                    .log_err();
                if let Some(new_checkpoint) = new_checkpoint {
                    // A failed comparison is treated as "unchanged", so no checkpoint
                    // is attached.
                    let equal = git_store
                        .update(cx, |git, cx| {
                            git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                        })?
                        .await
                        .unwrap_or(true);
                    if !equal {
                        this.update(cx, |this, cx| {
                            if let Some((ix, message)) = this.user_message_mut(&message_id) {
                                message.checkpoint = Some(old_checkpoint);
                                cx.emit(AcpThreadEvent::EntryUpdated(ix));
                            }
                        })?;
                    }
                }
            }

            this.update(cx, |this, cx| {
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Successful responses and a dropped result channel both land here.
                    result => {
                        let cancelled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't cancelled.
                        //
                        // This prompt may have been cancelled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be cancelled.
                        if !cancelled {
                            this.send_task.take();
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1259
1260 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1261 let Some(send_task) = self.send_task.take() else {
1262 return Task::ready(());
1263 };
1264
1265 for entry in self.entries.iter_mut() {
1266 if let AgentThreadEntry::ToolCall(call) = entry {
1267 let cancel = matches!(
1268 call.status,
1269 ToolCallStatus::WaitingForConfirmation { .. }
1270 | ToolCallStatus::Allowed {
1271 status: acp::ToolCallStatus::InProgress
1272 }
1273 );
1274
1275 if cancel {
1276 call.status = ToolCallStatus::Canceled;
1277 }
1278 }
1279 }
1280
1281 self.connection.cancel(&self.session_id, cx);
1282
1283 // Wait for the send task to complete
1284 cx.foreground_executor().spawn(send_task)
1285 }
1286
1287 /// Rewinds this thread to before the entry at `index`, removing it and all
1288 /// subsequent entries while reverting any changes made from that point.
1289 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1290 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1291 return Task::ready(Err(anyhow!("not supported")));
1292 };
1293 let Some(message) = self.user_message(&id) else {
1294 return Task::ready(Err(anyhow!("message not found")));
1295 };
1296
1297 let checkpoint = message.checkpoint.clone();
1298
1299 let git_store = self.project.read(cx).git_store().clone();
1300 cx.spawn(async move |this, cx| {
1301 if let Some(checkpoint) = checkpoint {
1302 git_store
1303 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1304 .await?;
1305 }
1306
1307 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1308 .await?;
1309 this.update(cx, |this, cx| {
1310 if let Some((ix, _)) = this.user_message_mut(&id) {
1311 let range = ix..this.entries.len();
1312 this.entries.truncate(ix);
1313 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1314 }
1315 })
1316 })
1317 }
1318
1319 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1320 self.entries.iter().find_map(|entry| {
1321 if let AgentThreadEntry::UserMessage(message) = entry {
1322 if message.id.as_ref() == Some(&id) {
1323 Some(message)
1324 } else {
1325 None
1326 }
1327 } else {
1328 None
1329 }
1330 })
1331 }
1332
1333 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1334 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1335 if let AgentThreadEntry::UserMessage(message) = entry {
1336 if message.id.as_ref() == Some(&id) {
1337 Some((ix, message))
1338 } else {
1339 None
1340 }
1341 } else {
1342 None
1343 }
1344 })
1345 }
1346
1347 pub fn read_text_file(
1348 &self,
1349 path: PathBuf,
1350 line: Option<u32>,
1351 limit: Option<u32>,
1352 reuse_shared_snapshot: bool,
1353 cx: &mut Context<Self>,
1354 ) -> Task<Result<String>> {
1355 let project = self.project.clone();
1356 let action_log = self.action_log.clone();
1357 cx.spawn(async move |this, cx| {
1358 let load = project.update(cx, |project, cx| {
1359 let path = project
1360 .project_path_for_absolute_path(&path, cx)
1361 .context("invalid path")?;
1362 anyhow::Ok(project.open_buffer(path, cx))
1363 });
1364 let buffer = load??.await?;
1365
1366 let snapshot = if reuse_shared_snapshot {
1367 this.read_with(cx, |this, _| {
1368 this.shared_buffers.get(&buffer.clone()).cloned()
1369 })
1370 .log_err()
1371 .flatten()
1372 } else {
1373 None
1374 };
1375
1376 let snapshot = if let Some(snapshot) = snapshot {
1377 snapshot
1378 } else {
1379 action_log.update(cx, |action_log, cx| {
1380 action_log.buffer_read(buffer.clone(), cx);
1381 })?;
1382 project.update(cx, |project, cx| {
1383 let position = buffer
1384 .read(cx)
1385 .snapshot()
1386 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1387 project.set_agent_location(
1388 Some(AgentLocation {
1389 buffer: buffer.downgrade(),
1390 position,
1391 }),
1392 cx,
1393 );
1394 })?;
1395
1396 buffer.update(cx, |buffer, _| buffer.snapshot())?
1397 };
1398
1399 this.update(cx, |this, _| {
1400 let text = snapshot.text();
1401 this.shared_buffers.insert(buffer.clone(), snapshot);
1402 if line.is_none() && limit.is_none() {
1403 return Ok(text);
1404 }
1405 let limit = limit.unwrap_or(u32::MAX) as usize;
1406 let Some(line) = line else {
1407 return Ok(text.lines().take(limit).collect::<String>());
1408 };
1409
1410 let count = text.lines().count();
1411 if count < line as usize {
1412 anyhow::bail!("There are only {} lines", count);
1413 }
1414 Ok(text
1415 .lines()
1416 .skip(line as usize + 1)
1417 .take(limit)
1418 .collect::<String>())
1419 })?
1420 })
1421 }
1422
    /// Replaces the contents of the file at the absolute `path` with `content`,
    /// then saves it.
    ///
    /// The new text is diffed against the snapshot previously shared with the
    /// agent for this buffer, or the buffer's current snapshot if none was
    /// shared. Only the differing ranges are edited, as anchors into that
    /// snapshot, so edits made to the buffer since the snapshot was taken are
    /// kept. The agent location is moved to the end of the last edit (or the
    /// start of the buffer when nothing changed). The read and the edit are both
    /// recorded in the action log.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not inside the project, the buffer cannot be opened,
    /// the thread or app was dropped, or saving the buffer fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last saw so the diff matches its view.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread and convert byte ranges into
            // anchors into `snapshot`.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // The action log is told about the read before the edit is
                // applied, and about the edit afterwards.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1490
1491 pub fn to_markdown(&self, cx: &App) -> String {
1492 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1493 }
1494
1495 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1496 cx.emit(AcpThreadEvent::ServerExited(status));
1497 }
1498}
1499
1500fn markdown_for_raw_output(
1501 raw_output: &serde_json::Value,
1502 language_registry: &Arc<LanguageRegistry>,
1503 cx: &mut App,
1504) -> Option<Entity<Markdown>> {
1505 match raw_output {
1506 serde_json::Value::Null => None,
1507 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1508 Markdown::new(
1509 value.to_string().into(),
1510 Some(language_registry.clone()),
1511 None,
1512 cx,
1513 )
1514 })),
1515 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1516 Markdown::new(
1517 value.to_string().into(),
1518 Some(language_registry.clone()),
1519 None,
1520 cx,
1521 )
1522 })),
1523 serde_json::Value::String(value) => Some(cx.new(|cx| {
1524 Markdown::new(
1525 value.clone().into(),
1526 Some(language_registry.clone()),
1527 None,
1528 cx,
1529 )
1530 })),
1531 value => Some(cx.new(|cx| {
1532 Markdown::new(
1533 format!("```json\n{}\n```", value).into(),
1534 Some(language_registry.clone()),
1535 None,
1536 cx,
1537 )
1538 })),
1539 }
1540}
1541
// Tests drive `AcpThread` through `FakeAgentConnection` (defined at the bottom
// of this module). Its `on_user_message` handler scripts the agent's side of
// each prompt.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::{FakeFs, Fs};
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{
        cell::RefCell,
        path::Path,
        rc::Rc,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
        time::Duration,
    };
    use util::path;

    /// Installs the test settings store and the project/language settings every
    /// test in this module relies on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Consecutive user content blocks merge into one message, and the later
    /// block's id is adopted. A user block that follows an assistant message
    /// starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive agent thought chunks are concatenated into a single
    /// `<thinking>` block in the rendered Markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// An agent write is diffed against the snapshot the agent read. A user edit
    /// made between the read and the write is therefore preserved, both in the
    /// buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has read the file, so the user
        // edit lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// `cancel` marks an in-progress tool call as canceled, but a later update
    /// from the agent can still move it to completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call reported as already completed leaves no pending edit
    /// tool calls once the turn ends.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Checkpoints are attached only to user messages whose turn changed the
    /// working tree. Rewinding to a message restores its checkpoint and
    /// truncates the history from that message onward.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt writes a fresh
        // `/test/file-N`, so the working tree differs after the turn.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The fake agent replies with the prompt text upper-cased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }

    /// Waits, for up to 10 seconds, until the thread emits an event while it
    /// contains a tool call entry, and returns that entry's index.
    ///
    /// Panics on timeout.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        return tx.try_send(ix).unwrap();
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory `AgentConnection` used by these tests.
    ///
    /// Sessions are mapped to weak thread handles. Prompts are answered by the
    /// optional `on_user_message` handler, or end the turn immediately when no
    /// handler is set.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions, and no prompt
        /// handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Sets the handler that plays the agent's side of every prompt.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            // Random 7-character alphanumeric session id.
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                // Without a handler, every prompt ends its turn immediately.
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&session_id).unwrap().clone();

            // NOTE(review): this re-enters `AcpThread::cancel` from a spawned
            // task. That is a no-op when the thread's send task has already been
            // taken. However, if `AcpThread::send` stored a new send task before
            // this runs, it would cancel that newer prompt. Confirm this re-entry
            // is intended.
            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }

        fn session_editor(
            &self,
            session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(FakeAgentSessionEditor {
                _session_id: session_id.clone(),
            }))
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}