1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
28use ui::App;
29use util::ResultExt;
30
31#[derive(Debug)]
32pub struct UserMessage {
33 pub id: Option<UserMessageId>,
34 pub content: ContentBlock,
35 pub checkpoint: Option<GitStoreCheckpoint>,
36}
37
38impl UserMessage {
39 fn to_markdown(&self, cx: &App) -> String {
40 let mut markdown = String::new();
41 if let Some(_) = self.checkpoint {
42 writeln!(markdown, "## User (checkpoint)").unwrap();
43 } else {
44 writeln!(markdown, "## User").unwrap();
45 }
46 writeln!(markdown).unwrap();
47 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
48 writeln!(markdown).unwrap();
49 markdown
50 }
51}
52
53#[derive(Debug, PartialEq)]
54pub struct AssistantMessage {
55 pub chunks: Vec<AssistantMessageChunk>,
56}
57
58impl AssistantMessage {
59 pub fn to_markdown(&self, cx: &App) -> String {
60 format!(
61 "## Assistant\n\n{}\n\n",
62 self.chunks
63 .iter()
64 .map(|chunk| chunk.to_markdown(cx))
65 .join("\n\n")
66 )
67 }
68}
69
70#[derive(Debug, PartialEq)]
71pub enum AssistantMessageChunk {
72 Message { block: ContentBlock },
73 Thought { block: ContentBlock },
74}
75
76impl AssistantMessageChunk {
77 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
78 Self::Message {
79 block: ContentBlock::new(chunk.into(), language_registry, cx),
80 }
81 }
82
83 fn to_markdown(&self, cx: &App) -> String {
84 match self {
85 Self::Message { block } => block.to_markdown(cx).to_string(),
86 Self::Thought { block } => {
87 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
94pub enum AgentThreadEntry {
95 UserMessage(UserMessage),
96 AssistantMessage(AssistantMessage),
97 ToolCall(ToolCall),
98}
99
100impl AgentThreadEntry {
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::UserMessage(message) => message.to_markdown(cx),
104 Self::AssistantMessage(message) => message.to_markdown(cx),
105 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
106 }
107 }
108
109 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
110 if let AgentThreadEntry::ToolCall(call) = self {
111 itertools::Either::Left(call.diffs())
112 } else {
113 itertools::Either::Right(std::iter::empty())
114 }
115 }
116
117 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
118 if let AgentThreadEntry::ToolCall(call) = self {
119 itertools::Either::Left(call.terminals())
120 } else {
121 itertools::Either::Right(std::iter::empty())
122 }
123 }
124
125 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
126 if let AgentThreadEntry::ToolCall(ToolCall {
127 locations,
128 resolved_locations,
129 ..
130 }) = self
131 {
132 Some((
133 locations.get(ix)?.clone(),
134 resolved_locations.get(ix)?.clone()?,
135 ))
136 } else {
137 None
138 }
139 }
140}
141
142#[derive(Debug)]
143pub struct ToolCall {
144 pub id: acp::ToolCallId,
145 pub label: Entity<Markdown>,
146 pub kind: acp::ToolKind,
147 pub content: Vec<ToolCallContent>,
148 pub status: ToolCallStatus,
149 pub locations: Vec<acp::ToolCallLocation>,
150 pub resolved_locations: Vec<Option<AgentLocation>>,
151 pub raw_input: Option<serde_json::Value>,
152 pub raw_output: Option<serde_json::Value>,
153}
154
155impl ToolCall {
156 fn from_acp(
157 tool_call: acp::ToolCall,
158 status: ToolCallStatus,
159 language_registry: Arc<LanguageRegistry>,
160 cx: &mut App,
161 ) -> Self {
162 Self {
163 id: tool_call.id,
164 label: cx.new(|cx| {
165 Markdown::new(
166 tool_call.title.into(),
167 Some(language_registry.clone()),
168 None,
169 cx,
170 )
171 }),
172 kind: tool_call.kind,
173 content: tool_call
174 .content
175 .into_iter()
176 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
177 .collect(),
178 locations: tool_call.locations,
179 resolved_locations: Vec::default(),
180 status,
181 raw_input: tool_call.raw_input,
182 raw_output: tool_call.raw_output,
183 }
184 }
185
186 fn update_fields(
187 &mut self,
188 fields: acp::ToolCallUpdateFields,
189 language_registry: Arc<LanguageRegistry>,
190 cx: &mut App,
191 ) {
192 let acp::ToolCallUpdateFields {
193 kind,
194 status,
195 title,
196 content,
197 locations,
198 raw_input,
199 raw_output,
200 } = fields;
201
202 if let Some(kind) = kind {
203 self.kind = kind;
204 }
205
206 if let Some(status) = status {
207 self.status = ToolCallStatus::Allowed { status };
208 }
209
210 if let Some(title) = title {
211 self.label.update(cx, |label, cx| {
212 label.replace(title, cx);
213 });
214 }
215
216 if let Some(content) = content {
217 self.content = content
218 .into_iter()
219 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
220 .collect();
221 }
222
223 if let Some(locations) = locations {
224 self.locations = locations;
225 }
226
227 if let Some(raw_input) = raw_input {
228 self.raw_input = Some(raw_input);
229 }
230
231 if let Some(raw_output) = raw_output {
232 if self.content.is_empty() {
233 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
234 {
235 self.content
236 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
237 markdown,
238 }));
239 }
240 }
241 self.raw_output = Some(raw_output);
242 }
243 }
244
245 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
246 self.content.iter().filter_map(|content| match content {
247 ToolCallContent::Diff(diff) => Some(diff),
248 ToolCallContent::ContentBlock(_) => None,
249 ToolCallContent::Terminal(_) => None,
250 })
251 }
252
253 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
254 self.content.iter().filter_map(|content| match content {
255 ToolCallContent::Terminal(terminal) => Some(terminal),
256 ToolCallContent::ContentBlock(_) => None,
257 ToolCallContent::Diff(_) => None,
258 })
259 }
260
261 fn to_markdown(&self, cx: &App) -> String {
262 let mut markdown = format!(
263 "**Tool Call: {}**\nStatus: {}\n\n",
264 self.label.read(cx).source(),
265 self.status
266 );
267 for content in &self.content {
268 markdown.push_str(content.to_markdown(cx).as_str());
269 markdown.push_str("\n\n");
270 }
271 markdown
272 }
273
274 async fn resolve_location(
275 location: acp::ToolCallLocation,
276 project: WeakEntity<Project>,
277 cx: &mut AsyncApp,
278 ) -> Option<AgentLocation> {
279 let buffer = project
280 .update(cx, |project, cx| {
281 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
282 Some(project.open_buffer(path, cx))
283 } else {
284 None
285 }
286 })
287 .ok()??;
288 let buffer = buffer.await.log_err()?;
289 let position = buffer
290 .update(cx, |buffer, _| {
291 if let Some(row) = location.line {
292 let snapshot = buffer.snapshot();
293 let column = snapshot.indent_size_for_line(row).len;
294 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
295 snapshot.anchor_before(point)
296 } else {
297 Anchor::MIN
298 }
299 })
300 .ok()?;
301
302 Some(AgentLocation {
303 buffer: buffer.downgrade(),
304 position,
305 })
306 }
307
308 fn resolve_locations(
309 &self,
310 project: Entity<Project>,
311 cx: &mut App,
312 ) -> Task<Vec<Option<AgentLocation>>> {
313 let locations = self.locations.clone();
314 project.update(cx, |_, cx| {
315 cx.spawn(async move |project, cx| {
316 let mut new_locations = Vec::new();
317 for location in locations {
318 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
319 }
320 new_locations
321 })
322 })
323 }
324}
325
326#[derive(Debug)]
327pub enum ToolCallStatus {
328 WaitingForConfirmation {
329 options: Vec<acp::PermissionOption>,
330 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
331 },
332 Allowed {
333 status: acp::ToolCallStatus,
334 },
335 Rejected,
336 Canceled,
337}
338
339impl Display for ToolCallStatus {
340 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
341 write!(
342 f,
343 "{}",
344 match self {
345 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
346 ToolCallStatus::Allowed { status } => match status {
347 acp::ToolCallStatus::Pending => "Pending",
348 acp::ToolCallStatus::InProgress => "In Progress",
349 acp::ToolCallStatus::Completed => "Completed",
350 acp::ToolCallStatus::Failed => "Failed",
351 },
352 ToolCallStatus::Rejected => "Rejected",
353 ToolCallStatus::Canceled => "Canceled",
354 }
355 )
356 }
357}
358
359#[derive(Debug, PartialEq, Clone)]
360pub enum ContentBlock {
361 Empty,
362 Markdown { markdown: Entity<Markdown> },
363 ResourceLink { resource_link: acp::ResourceLink },
364}
365
366impl ContentBlock {
367 pub fn new(
368 block: acp::ContentBlock,
369 language_registry: &Arc<LanguageRegistry>,
370 cx: &mut App,
371 ) -> Self {
372 let mut this = Self::Empty;
373 this.append(block, language_registry, cx);
374 this
375 }
376
377 pub fn new_combined(
378 blocks: impl IntoIterator<Item = acp::ContentBlock>,
379 language_registry: Arc<LanguageRegistry>,
380 cx: &mut App,
381 ) -> Self {
382 let mut this = Self::Empty;
383 for block in blocks {
384 this.append(block, &language_registry, cx);
385 }
386 this
387 }
388
389 pub fn append(
390 &mut self,
391 block: acp::ContentBlock,
392 language_registry: &Arc<LanguageRegistry>,
393 cx: &mut App,
394 ) {
395 if matches!(self, ContentBlock::Empty) {
396 if let acp::ContentBlock::ResourceLink(resource_link) = block {
397 *self = ContentBlock::ResourceLink { resource_link };
398 return;
399 }
400 }
401
402 let new_content = self.extract_content_from_block(block);
403
404 match self {
405 ContentBlock::Empty => {
406 *self = Self::create_markdown_block(new_content, language_registry, cx);
407 }
408 ContentBlock::Markdown { markdown } => {
409 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
410 }
411 ContentBlock::ResourceLink { resource_link } => {
412 let existing_content = Self::resource_link_to_content(&resource_link.uri);
413 let combined = format!("{}\n{}", existing_content, new_content);
414
415 *self = Self::create_markdown_block(combined, language_registry, cx);
416 }
417 }
418 }
419
420 fn resource_link_to_content(uri: &str) -> String {
421 if let Some(uri) = MentionUri::parse(&uri).log_err() {
422 uri.to_link()
423 } else {
424 uri.to_string().clone()
425 }
426 }
427
428 fn create_markdown_block(
429 content: String,
430 language_registry: &Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) -> ContentBlock {
433 ContentBlock::Markdown {
434 markdown: cx
435 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
436 }
437 }
438
439 fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
440 match block {
441 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
442 acp::ContentBlock::ResourceLink(resource_link) => {
443 Self::resource_link_to_content(&resource_link.uri)
444 }
445 acp::ContentBlock::Resource(acp::EmbeddedResource {
446 resource:
447 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
448 uri,
449 ..
450 }),
451 ..
452 }) => Self::resource_link_to_content(&uri),
453 acp::ContentBlock::Image(_)
454 | acp::ContentBlock::Audio(_)
455 | acp::ContentBlock::Resource(_) => String::new(),
456 }
457 }
458
459 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
460 match self {
461 ContentBlock::Empty => "",
462 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
463 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
464 }
465 }
466
467 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
468 match self {
469 ContentBlock::Empty => None,
470 ContentBlock::Markdown { markdown } => Some(markdown),
471 ContentBlock::ResourceLink { .. } => None,
472 }
473 }
474
475 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
476 match self {
477 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
478 _ => None,
479 }
480 }
481}
482
483#[derive(Debug)]
484pub enum ToolCallContent {
485 ContentBlock(ContentBlock),
486 Diff(Entity<Diff>),
487 Terminal(Entity<Terminal>),
488}
489
490impl ToolCallContent {
491 pub fn from_acp(
492 content: acp::ToolCallContent,
493 language_registry: Arc<LanguageRegistry>,
494 cx: &mut App,
495 ) -> Self {
496 match content {
497 acp::ToolCallContent::Content { content } => {
498 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
499 }
500 acp::ToolCallContent::Diff { diff } => {
501 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
502 }
503 }
504 }
505
506 pub fn to_markdown(&self, cx: &App) -> String {
507 match self {
508 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
509 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
510 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
511 }
512 }
513}
514
515#[derive(Debug, PartialEq)]
516pub enum ToolCallUpdate {
517 UpdateFields(acp::ToolCallUpdate),
518 UpdateDiff(ToolCallUpdateDiff),
519 UpdateTerminal(ToolCallUpdateTerminal),
520}
521
522impl ToolCallUpdate {
523 fn id(&self) -> &acp::ToolCallId {
524 match self {
525 Self::UpdateFields(update) => &update.id,
526 Self::UpdateDiff(diff) => &diff.id,
527 Self::UpdateTerminal(terminal) => &terminal.id,
528 }
529 }
530}
531
532impl From<acp::ToolCallUpdate> for ToolCallUpdate {
533 fn from(update: acp::ToolCallUpdate) -> Self {
534 Self::UpdateFields(update)
535 }
536}
537
538impl From<ToolCallUpdateDiff> for ToolCallUpdate {
539 fn from(diff: ToolCallUpdateDiff) -> Self {
540 Self::UpdateDiff(diff)
541 }
542}
543
544#[derive(Debug, PartialEq)]
545pub struct ToolCallUpdateDiff {
546 pub id: acp::ToolCallId,
547 pub diff: Entity<Diff>,
548}
549
550impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
551 fn from(terminal: ToolCallUpdateTerminal) -> Self {
552 Self::UpdateTerminal(terminal)
553 }
554}
555
556#[derive(Debug, PartialEq)]
557pub struct ToolCallUpdateTerminal {
558 pub id: acp::ToolCallId,
559 pub terminal: Entity<Terminal>,
560}
561
562#[derive(Debug, Default)]
563pub struct Plan {
564 pub entries: Vec<PlanEntry>,
565}
566
567#[derive(Debug)]
568pub struct PlanStats<'a> {
569 pub in_progress_entry: Option<&'a PlanEntry>,
570 pub pending: u32,
571 pub completed: u32,
572}
573
574impl Plan {
575 pub fn is_empty(&self) -> bool {
576 self.entries.is_empty()
577 }
578
579 pub fn stats(&self) -> PlanStats<'_> {
580 let mut stats = PlanStats {
581 in_progress_entry: None,
582 pending: 0,
583 completed: 0,
584 };
585
586 for entry in &self.entries {
587 match &entry.status {
588 acp::PlanEntryStatus::Pending => {
589 stats.pending += 1;
590 }
591 acp::PlanEntryStatus::InProgress => {
592 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
593 }
594 acp::PlanEntryStatus::Completed => {
595 stats.completed += 1;
596 }
597 }
598 }
599
600 stats
601 }
602}
603
604#[derive(Debug)]
605pub struct PlanEntry {
606 pub content: Entity<Markdown>,
607 pub priority: acp::PlanEntryPriority,
608 pub status: acp::PlanEntryStatus,
609}
610
611impl PlanEntry {
612 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
613 Self {
614 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
615 priority: entry.priority,
616 status: entry.status,
617 }
618 }
619}
620
621pub struct AcpThread {
622 title: SharedString,
623 entries: Vec<AgentThreadEntry>,
624 plan: Plan,
625 project: Entity<Project>,
626 action_log: Entity<ActionLog>,
627 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
628 send_task: Option<Task<()>>,
629 connection: Rc<dyn AgentConnection>,
630 session_id: acp::SessionId,
631}
632
633pub enum AcpThreadEvent {
634 NewEntry,
635 EntryUpdated(usize),
636 EntriesRemoved(Range<usize>),
637 ToolAuthorizationRequired,
638 Stopped,
639 Error,
640 ServerExited(ExitStatus),
641}
642
643impl EventEmitter<AcpThreadEvent> for AcpThread {}
644
645#[derive(PartialEq, Eq)]
646pub enum ThreadStatus {
647 Idle,
648 WaitingForToolConfirmation,
649 Generating,
650}
651
652#[derive(Debug, Clone)]
653pub enum LoadError {
654 Unsupported {
655 error_message: SharedString,
656 upgrade_message: SharedString,
657 upgrade_command: String,
658 },
659 Exited(i32),
660 Other(SharedString),
661}
662
663impl Display for LoadError {
664 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
665 match self {
666 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
667 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
668 LoadError::Other(msg) => write!(f, "{}", msg),
669 }
670 }
671}
672
673impl Error for LoadError {}
674
675impl AcpThread {
676 pub fn new(
677 title: impl Into<SharedString>,
678 connection: Rc<dyn AgentConnection>,
679 project: Entity<Project>,
680 session_id: acp::SessionId,
681 cx: &mut Context<Self>,
682 ) -> Self {
683 let action_log = cx.new(|_| ActionLog::new(project.clone()));
684
685 Self {
686 action_log,
687 shared_buffers: Default::default(),
688 entries: Default::default(),
689 plan: Default::default(),
690 title: title.into(),
691 project,
692 send_task: None,
693 connection,
694 session_id,
695 }
696 }
697
698 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
699 &self.connection
700 }
701
702 pub fn action_log(&self) -> &Entity<ActionLog> {
703 &self.action_log
704 }
705
706 pub fn project(&self) -> &Entity<Project> {
707 &self.project
708 }
709
710 pub fn title(&self) -> SharedString {
711 self.title.clone()
712 }
713
714 pub fn entries(&self) -> &[AgentThreadEntry] {
715 &self.entries
716 }
717
718 pub fn session_id(&self) -> &acp::SessionId {
719 &self.session_id
720 }
721
722 pub fn status(&self) -> ThreadStatus {
723 if self.send_task.is_some() {
724 if self.waiting_for_tool_confirmation() {
725 ThreadStatus::WaitingForToolConfirmation
726 } else {
727 ThreadStatus::Generating
728 }
729 } else {
730 ThreadStatus::Idle
731 }
732 }
733
734 pub fn has_pending_edit_tool_calls(&self) -> bool {
735 for entry in self.entries.iter().rev() {
736 match entry {
737 AgentThreadEntry::UserMessage(_) => return false,
738 AgentThreadEntry::ToolCall(
739 call @ ToolCall {
740 status:
741 ToolCallStatus::Allowed {
742 status:
743 acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
744 },
745 ..
746 },
747 ) if call.diffs().next().is_some() => {
748 return true;
749 }
750 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
751 }
752 }
753
754 false
755 }
756
757 pub fn used_tools_since_last_user_message(&self) -> bool {
758 for entry in self.entries.iter().rev() {
759 match entry {
760 AgentThreadEntry::UserMessage(..) => return false,
761 AgentThreadEntry::AssistantMessage(..) => continue,
762 AgentThreadEntry::ToolCall(..) => return true,
763 }
764 }
765
766 false
767 }
768
769 pub fn handle_session_update(
770 &mut self,
771 update: acp::SessionUpdate,
772 cx: &mut Context<Self>,
773 ) -> Result<()> {
774 match update {
775 acp::SessionUpdate::UserMessageChunk { content } => {
776 self.push_user_content_block(None, content, cx);
777 }
778 acp::SessionUpdate::AgentMessageChunk { content } => {
779 self.push_assistant_content_block(content, false, cx);
780 }
781 acp::SessionUpdate::AgentThoughtChunk { content } => {
782 self.push_assistant_content_block(content, true, cx);
783 }
784 acp::SessionUpdate::ToolCall(tool_call) => {
785 self.upsert_tool_call(tool_call, cx);
786 }
787 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
788 self.update_tool_call(tool_call_update, cx)?;
789 }
790 acp::SessionUpdate::Plan(plan) => {
791 self.update_plan(plan, cx);
792 }
793 }
794 Ok(())
795 }
796
797 pub fn push_user_content_block(
798 &mut self,
799 message_id: Option<UserMessageId>,
800 chunk: acp::ContentBlock,
801 cx: &mut Context<Self>,
802 ) {
803 let language_registry = self.project.read(cx).languages().clone();
804 let entries_len = self.entries.len();
805
806 if let Some(last_entry) = self.entries.last_mut()
807 && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
808 {
809 *id = message_id.or(id.take());
810 content.append(chunk, &language_registry, cx);
811 let idx = entries_len - 1;
812 cx.emit(AcpThreadEvent::EntryUpdated(idx));
813 } else {
814 let content = ContentBlock::new(chunk, &language_registry, cx);
815 self.push_entry(
816 AgentThreadEntry::UserMessage(UserMessage {
817 id: message_id,
818 content,
819 checkpoint: None,
820 }),
821 cx,
822 );
823 }
824 }
825
826 pub fn push_assistant_content_block(
827 &mut self,
828 chunk: acp::ContentBlock,
829 is_thought: bool,
830 cx: &mut Context<Self>,
831 ) {
832 let language_registry = self.project.read(cx).languages().clone();
833 let entries_len = self.entries.len();
834 if let Some(last_entry) = self.entries.last_mut()
835 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
836 {
837 let idx = entries_len - 1;
838 cx.emit(AcpThreadEvent::EntryUpdated(idx));
839 match (chunks.last_mut(), is_thought) {
840 (Some(AssistantMessageChunk::Message { block }), false)
841 | (Some(AssistantMessageChunk::Thought { block }), true) => {
842 block.append(chunk, &language_registry, cx)
843 }
844 _ => {
845 let block = ContentBlock::new(chunk, &language_registry, cx);
846 if is_thought {
847 chunks.push(AssistantMessageChunk::Thought { block })
848 } else {
849 chunks.push(AssistantMessageChunk::Message { block })
850 }
851 }
852 }
853 } else {
854 let block = ContentBlock::new(chunk, &language_registry, cx);
855 let chunk = if is_thought {
856 AssistantMessageChunk::Thought { block }
857 } else {
858 AssistantMessageChunk::Message { block }
859 };
860
861 self.push_entry(
862 AgentThreadEntry::AssistantMessage(AssistantMessage {
863 chunks: vec![chunk],
864 }),
865 cx,
866 );
867 }
868 }
869
870 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
871 self.entries.push(entry);
872 cx.emit(AcpThreadEvent::NewEntry);
873 }
874
875 pub fn update_tool_call(
876 &mut self,
877 update: impl Into<ToolCallUpdate>,
878 cx: &mut Context<Self>,
879 ) -> Result<()> {
880 let update = update.into();
881 let languages = self.project.read(cx).languages().clone();
882
883 let (ix, current_call) = self
884 .tool_call_mut(update.id())
885 .context("Tool call not found")?;
886 match update {
887 ToolCallUpdate::UpdateFields(update) => {
888 let location_updated = update.fields.locations.is_some();
889 current_call.update_fields(update.fields, languages, cx);
890 if location_updated {
891 self.resolve_locations(update.id.clone(), cx);
892 }
893 }
894 ToolCallUpdate::UpdateDiff(update) => {
895 current_call.content.clear();
896 current_call
897 .content
898 .push(ToolCallContent::Diff(update.diff));
899 }
900 ToolCallUpdate::UpdateTerminal(update) => {
901 current_call.content.clear();
902 current_call
903 .content
904 .push(ToolCallContent::Terminal(update.terminal));
905 }
906 }
907
908 cx.emit(AcpThreadEvent::EntryUpdated(ix));
909
910 Ok(())
911 }
912
913 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
914 pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
915 let status = ToolCallStatus::Allowed {
916 status: tool_call.status,
917 };
918 self.upsert_tool_call_inner(tool_call, status, cx)
919 }
920
921 pub fn upsert_tool_call_inner(
922 &mut self,
923 tool_call: acp::ToolCall,
924 status: ToolCallStatus,
925 cx: &mut Context<Self>,
926 ) {
927 let language_registry = self.project.read(cx).languages().clone();
928 let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
929 let id = call.id.clone();
930
931 if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
932 *current_call = call;
933
934 cx.emit(AcpThreadEvent::EntryUpdated(ix));
935 } else {
936 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
937 };
938
939 self.resolve_locations(id, cx);
940 }
941
942 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
943 // The tool call we are looking for is typically the last one, or very close to the end.
944 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
945 self.entries
946 .iter_mut()
947 .enumerate()
948 .rev()
949 .find_map(|(index, tool_call)| {
950 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
951 && &tool_call.id == id
952 {
953 Some((index, tool_call))
954 } else {
955 None
956 }
957 })
958 }
959
960 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
961 let project = self.project.clone();
962 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
963 return;
964 };
965 let task = tool_call.resolve_locations(project, cx);
966 cx.spawn(async move |this, cx| {
967 let resolved_locations = task.await;
968 this.update(cx, |this, cx| {
969 let project = this.project.clone();
970 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
971 return;
972 };
973 if let Some(Some(location)) = resolved_locations.last() {
974 project.update(cx, |project, cx| {
975 if let Some(agent_location) = project.agent_location() {
976 let should_ignore = agent_location.buffer == location.buffer
977 && location
978 .buffer
979 .update(cx, |buffer, _| {
980 let snapshot = buffer.snapshot();
981 let old_position =
982 agent_location.position.to_point(&snapshot);
983 let new_position = location.position.to_point(&snapshot);
984 // ignore this so that when we get updates from the edit tool
985 // the position doesn't reset to the startof line
986 old_position.row == new_position.row
987 && old_position.column > new_position.column
988 })
989 .ok()
990 .unwrap_or_default();
991 if !should_ignore {
992 project.set_agent_location(Some(location.clone()), cx);
993 }
994 }
995 });
996 }
997 if tool_call.resolved_locations != resolved_locations {
998 tool_call.resolved_locations = resolved_locations;
999 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1000 }
1001 })
1002 })
1003 .detach();
1004 }
1005
1006 pub fn request_tool_call_authorization(
1007 &mut self,
1008 tool_call: acp::ToolCall,
1009 options: Vec<acp::PermissionOption>,
1010 cx: &mut Context<Self>,
1011 ) -> oneshot::Receiver<acp::PermissionOptionId> {
1012 let (tx, rx) = oneshot::channel();
1013
1014 let status = ToolCallStatus::WaitingForConfirmation {
1015 options,
1016 respond_tx: tx,
1017 };
1018
1019 self.upsert_tool_call_inner(tool_call, status, cx);
1020 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1021 rx
1022 }
1023
1024 pub fn authorize_tool_call(
1025 &mut self,
1026 id: acp::ToolCallId,
1027 option_id: acp::PermissionOptionId,
1028 option_kind: acp::PermissionOptionKind,
1029 cx: &mut Context<Self>,
1030 ) {
1031 let Some((ix, call)) = self.tool_call_mut(&id) else {
1032 return;
1033 };
1034
1035 let new_status = match option_kind {
1036 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1037 ToolCallStatus::Rejected
1038 }
1039 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1040 ToolCallStatus::Allowed {
1041 status: acp::ToolCallStatus::InProgress,
1042 }
1043 }
1044 };
1045
1046 let curr_status = mem::replace(&mut call.status, new_status);
1047
1048 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1049 respond_tx.send(option_id).log_err();
1050 } else if cfg!(debug_assertions) {
1051 panic!("tried to authorize an already authorized tool call");
1052 }
1053
1054 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1055 }
1056
1057 /// Returns true if the last turn is awaiting tool authorization
1058 pub fn waiting_for_tool_confirmation(&self) -> bool {
1059 for entry in self.entries.iter().rev() {
1060 match &entry {
1061 AgentThreadEntry::ToolCall(call) => match call.status {
1062 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1063 ToolCallStatus::Allowed { .. }
1064 | ToolCallStatus::Rejected
1065 | ToolCallStatus::Canceled => continue,
1066 },
1067 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1068 // Reached the beginning of the turn
1069 return false;
1070 }
1071 }
1072 }
1073 false
1074 }
1075
1076 pub fn plan(&self) -> &Plan {
1077 &self.plan
1078 }
1079
1080 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1081 let new_entries_len = request.entries.len();
1082 let mut new_entries = request.entries.into_iter();
1083
1084 // Reuse existing markdown to prevent flickering
1085 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1086 let PlanEntry {
1087 content,
1088 priority,
1089 status,
1090 } = old;
1091 content.update(cx, |old, cx| {
1092 old.replace(new.content, cx);
1093 });
1094 *priority = new.priority;
1095 *status = new.status;
1096 }
1097 for new in new_entries {
1098 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1099 }
1100 self.plan.entries.truncate(new_entries_len);
1101
1102 cx.notify();
1103 }
1104
1105 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1106 self.plan
1107 .entries
1108 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1109 cx.notify();
1110 }
1111
/// Test helper that sends `message` as a single text content block.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text = acp::TextContent {
        text: message.to_string(),
        annotations: None,
    };
    self.send(vec![acp::ContentBlock::Text(text)], cx)
}
1126
/// Sends a user prompt to the agent.
///
/// Immediately does the following:
/// - pushes the prompt as a new user message entry;
/// - clears completed plan entries;
/// - cancels any in-flight prompt (`self.cancel`).
///
/// The new prompt is then run as `send_task` once that cancellation completes.
/// The returned future resolves when the prompt finishes:
/// - with the agent's error, after emitting `AcpThreadEvent::Error`;
/// - otherwise with `Ok(())`, after emitting `AcpThreadEvent::Stopped`.
///
/// The message gets a `UserMessageId` only when the connection exposes a
/// session editor. In that case, the git checkpoint taken before the prompt is
/// attached to the message if the repository differs once the prompt finishes.
pub fn send(
    &mut self,
    message: Vec<acp::ContentBlock>,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let block = ContentBlock::new_combined(
        message.clone(),
        self.project.read(cx).languages().clone(),
        cx,
    );
    let git_store = self.project.read(cx).git_store().clone();

    // Start capturing the pre-prompt checkpoint now; it is awaited inside the send task.
    let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
    // Message ids (and thus checkpoints) are only used when the session can be edited.
    let message_id = if self
        .connection
        .session_editor(&self.session_id, cx)
        .is_some()
    {
        Some(UserMessageId::new())
    } else {
        None
    };
    self.push_entry(
        AgentThreadEntry::UserMessage(UserMessage {
            id: message_id.clone(),
            content: block,
            checkpoint: None,
        }),
        cx,
    );
    self.clear_completed_plan_entries(cx);

    // `old_checkpoint_*` hands the pre-prompt checkpoint to the follow-up task;
    // `tx`/`rx` carry the prompt's result.
    let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
    let (tx, rx) = oneshot::channel();
    let cancel_task = self.cancel(cx);
    let request = acp::PromptRequest {
        prompt: message,
        session_id: self.session_id.clone(),
    };

    self.send_task = Some(cx.spawn({
        let message_id = message_id.clone();
        async move |this, cx| {
            // Let the previous prompt finish cancelling before starting this one.
            cancel_task.await;

            old_checkpoint_tx.send(old_checkpoint.await).ok();
            // If the thread is gone, `tx` is dropped and `rx` resolves to an error.
            if let Ok(result) = this.update(cx, |this, cx| {
                this.connection.prompt(message_id, request, cx)
            }) {
                tx.send(result.await).log_err();
            }
        }
    }));

    cx.spawn(async move |this, cx| {
        // A failed or canceled checkpoint is logged and simply disables checkpointing.
        let old_checkpoint = old_checkpoint_rx
            .await
            .map_err(|_| anyhow!("send canceled"))
            .flatten()
            .context("failed to get old checkpoint")
            .log_err();

        let response = rx.await;

        if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
            let new_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get new checkpoint")
                .log_err();
            if let Some(new_checkpoint) = new_checkpoint {
                // A failed comparison counts as "equal", so no checkpoint is attached.
                let equal = git_store
                    .update(cx, |git, cx| {
                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                    })?
                    .await
                    .unwrap_or(true);
                if !equal {
                    this.update(cx, |this, cx| {
                        if let Some((ix, message)) = this.user_message_mut(&message_id) {
                            message.checkpoint = Some(old_checkpoint);
                            cx.emit(AcpThreadEvent::EntryUpdated(ix));
                        }
                    })?;
                }
            }
        }

        this.update(cx, |this, cx| {
            match response {
                Ok(Err(e)) => {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error);
                    Err(e)
                }
                // NOTE(review): this arm also receives `Err(Canceled)` when `tx` was
                // dropped without sending (e.g. the send task itself was dropped).
                // That case is not treated as cancelled and takes `send_task`, which
                // may by then belong to a newer prompt. Confirm `cancel` never drops
                // a running send task.
                result => {
                    let cancelled = matches!(
                        result,
                        Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Cancelled
                        }))
                    );

                    // We only take the task if the current prompt wasn't cancelled.
                    //
                    // This prompt may have been cancelled because another one was sent
                    // while it was still generating. In these cases, dropping `send_task`
                    // would cause the next generation to be cancelled.
                    if !cancelled {
                        this.send_task.take();
                    }

                    cx.emit(AcpThreadEvent::Stopped);
                    Ok(())
                }
            }
        })?
    })
    .boxed()
}
1247
1248 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1249 let Some(send_task) = self.send_task.take() else {
1250 return Task::ready(());
1251 };
1252
1253 for entry in self.entries.iter_mut() {
1254 if let AgentThreadEntry::ToolCall(call) = entry {
1255 let cancel = matches!(
1256 call.status,
1257 ToolCallStatus::WaitingForConfirmation { .. }
1258 | ToolCallStatus::Allowed {
1259 status: acp::ToolCallStatus::InProgress
1260 }
1261 );
1262
1263 if cancel {
1264 call.status = ToolCallStatus::Canceled;
1265 }
1266 }
1267 }
1268
1269 self.connection.cancel(&self.session_id, cx);
1270
1271 // Wait for the send task to complete
1272 cx.foreground_executor().spawn(send_task)
1273 }
1274
1275 /// Rewinds this thread to before the entry at `index`, removing it and all
1276 /// subsequent entries while reverting any changes made from that point.
1277 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1278 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1279 return Task::ready(Err(anyhow!("not supported")));
1280 };
1281 let Some(message) = self.user_message(&id) else {
1282 return Task::ready(Err(anyhow!("message not found")));
1283 };
1284
1285 let checkpoint = message.checkpoint.clone();
1286
1287 let git_store = self.project.read(cx).git_store().clone();
1288 cx.spawn(async move |this, cx| {
1289 if let Some(checkpoint) = checkpoint {
1290 git_store
1291 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1292 .await?;
1293 }
1294
1295 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1296 .await?;
1297 this.update(cx, |this, cx| {
1298 if let Some((ix, _)) = this.user_message_mut(&id) {
1299 let range = ix..this.entries.len();
1300 this.entries.truncate(ix);
1301 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1302 }
1303 })
1304 })
1305 }
1306
1307 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1308 self.entries.iter().find_map(|entry| {
1309 if let AgentThreadEntry::UserMessage(message) = entry {
1310 if message.id.as_ref() == Some(&id) {
1311 Some(message)
1312 } else {
1313 None
1314 }
1315 } else {
1316 None
1317 }
1318 })
1319 }
1320
1321 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1322 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1323 if let AgentThreadEntry::UserMessage(message) = entry {
1324 if message.id.as_ref() == Some(&id) {
1325 Some((ix, message))
1326 } else {
1327 None
1328 }
1329 } else {
1330 None
1331 }
1332 })
1333 }
1334
1335 pub fn read_text_file(
1336 &self,
1337 path: PathBuf,
1338 line: Option<u32>,
1339 limit: Option<u32>,
1340 reuse_shared_snapshot: bool,
1341 cx: &mut Context<Self>,
1342 ) -> Task<Result<String>> {
1343 let project = self.project.clone();
1344 let action_log = self.action_log.clone();
1345 cx.spawn(async move |this, cx| {
1346 let load = project.update(cx, |project, cx| {
1347 let path = project
1348 .project_path_for_absolute_path(&path, cx)
1349 .context("invalid path")?;
1350 anyhow::Ok(project.open_buffer(path, cx))
1351 });
1352 let buffer = load??.await?;
1353
1354 let snapshot = if reuse_shared_snapshot {
1355 this.read_with(cx, |this, _| {
1356 this.shared_buffers.get(&buffer.clone()).cloned()
1357 })
1358 .log_err()
1359 .flatten()
1360 } else {
1361 None
1362 };
1363
1364 let snapshot = if let Some(snapshot) = snapshot {
1365 snapshot
1366 } else {
1367 action_log.update(cx, |action_log, cx| {
1368 action_log.buffer_read(buffer.clone(), cx);
1369 })?;
1370 project.update(cx, |project, cx| {
1371 let position = buffer
1372 .read(cx)
1373 .snapshot()
1374 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1375 project.set_agent_location(
1376 Some(AgentLocation {
1377 buffer: buffer.downgrade(),
1378 position,
1379 }),
1380 cx,
1381 );
1382 })?;
1383
1384 buffer.update(cx, |buffer, _| buffer.snapshot())?
1385 };
1386
1387 this.update(cx, |this, _| {
1388 let text = snapshot.text();
1389 this.shared_buffers.insert(buffer.clone(), snapshot);
1390 if line.is_none() && limit.is_none() {
1391 return Ok(text);
1392 }
1393 let limit = limit.unwrap_or(u32::MAX) as usize;
1394 let Some(line) = line else {
1395 return Ok(text.lines().take(limit).collect::<String>());
1396 };
1397
1398 let count = text.lines().count();
1399 if count < line as usize {
1400 anyhow::bail!("There are only {} lines", count);
1401 }
1402 Ok(text
1403 .lines()
1404 .skip(line as usize + 1)
1405 .take(limit)
1406 .collect::<String>())
1407 })?
1408 })
1409 }
1410
1411 pub fn write_text_file(
1412 &self,
1413 path: PathBuf,
1414 content: String,
1415 cx: &mut Context<Self>,
1416 ) -> Task<Result<()>> {
1417 let project = self.project.clone();
1418 let action_log = self.action_log.clone();
1419 cx.spawn(async move |this, cx| {
1420 let load = project.update(cx, |project, cx| {
1421 let path = project
1422 .project_path_for_absolute_path(&path, cx)
1423 .context("invalid path")?;
1424 anyhow::Ok(project.open_buffer(path, cx))
1425 });
1426 let buffer = load??.await?;
1427 let snapshot = this.update(cx, |this, cx| {
1428 this.shared_buffers
1429 .get(&buffer)
1430 .cloned()
1431 .unwrap_or_else(|| buffer.read(cx).snapshot())
1432 })?;
1433 let edits = cx
1434 .background_executor()
1435 .spawn(async move {
1436 let old_text = snapshot.text();
1437 text_diff(old_text.as_str(), &content)
1438 .into_iter()
1439 .map(|(range, replacement)| {
1440 (
1441 snapshot.anchor_after(range.start)
1442 ..snapshot.anchor_before(range.end),
1443 replacement,
1444 )
1445 })
1446 .collect::<Vec<_>>()
1447 })
1448 .await;
1449 cx.update(|cx| {
1450 project.update(cx, |project, cx| {
1451 project.set_agent_location(
1452 Some(AgentLocation {
1453 buffer: buffer.downgrade(),
1454 position: edits
1455 .last()
1456 .map(|(range, _)| range.end)
1457 .unwrap_or(Anchor::MIN),
1458 }),
1459 cx,
1460 );
1461 });
1462
1463 action_log.update(cx, |action_log, cx| {
1464 action_log.buffer_read(buffer.clone(), cx);
1465 });
1466 buffer.update(cx, |buffer, cx| {
1467 buffer.edit(edits, None, cx);
1468 });
1469 action_log.update(cx, |action_log, cx| {
1470 action_log.buffer_edited(buffer.clone(), cx);
1471 });
1472 })?;
1473 project
1474 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1475 .await
1476 })
1477 }
1478
1479 pub fn to_markdown(&self, cx: &App) -> String {
1480 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1481 }
1482
1483 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1484 cx.emit(AcpThreadEvent::ServerExited(status));
1485 }
1486}
1487
1488fn markdown_for_raw_output(
1489 raw_output: &serde_json::Value,
1490 language_registry: &Arc<LanguageRegistry>,
1491 cx: &mut App,
1492) -> Option<Entity<Markdown>> {
1493 match raw_output {
1494 serde_json::Value::Null => None,
1495 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1496 Markdown::new(
1497 value.to_string().into(),
1498 Some(language_registry.clone()),
1499 None,
1500 cx,
1501 )
1502 })),
1503 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1504 Markdown::new(
1505 value.to_string().into(),
1506 Some(language_registry.clone()),
1507 None,
1508 cx,
1509 )
1510 })),
1511 serde_json::Value::String(value) => Some(cx.new(|cx| {
1512 Markdown::new(
1513 value.clone().into(),
1514 Some(language_registry.clone()),
1515 None,
1516 cx,
1517 )
1518 })),
1519 value => Some(cx.new(|cx| {
1520 Markdown::new(
1521 format!("```json\n{}\n```", value).into(),
1522 Some(language_registry.clone()),
1523 None,
1524 cx,
1525 )
1526 })),
1527 }
1528}
1529
1530#[cfg(test)]
1531mod tests {
1532 use super::*;
1533 use anyhow::anyhow;
1534 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1535 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1536 use indoc::indoc;
1537 use project::{FakeFs, Fs};
1538 use rand::Rng as _;
1539 use serde_json::json;
1540 use settings::SettingsStore;
1541 use smol::stream::StreamExt as _;
1542 use std::{
1543 cell::RefCell,
1544 path::Path,
1545 rc::Rc,
1546 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1547 time::Duration,
1548 };
1549 use util::path;
1550
1551 fn init_test(cx: &mut TestAppContext) {
1552 env_logger::try_init().ok();
1553 cx.update(|cx| {
1554 let settings_store = SettingsStore::test(cx);
1555 cx.set_global(settings_store);
1556 Project::init_settings(cx);
1557 language::init(cx);
1558 });
1559 }
1560
    /// User content blocks are merged into the trailing user message, including
    /// adopting an id supplied later. A block that follows an assistant message
    /// starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still one entry: the text was appended and the id recorded on it.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1656
    /// Consecutive agent thought chunks are merged into a single `<thinking>`
    /// section of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1723
    /// An agent write that races with a user edit keeps both changes.
    ///
    /// The agent reads the file, the user then inserts a line at the top, and
    /// afterwards the agent writes its extended version. The write is expected
    /// to preserve the user's insertion both in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test that the agent has finished reading, so the user edit
        // lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulate the user editing the buffer after the agent's read.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1802
    /// A tool call marked `Canceled` by `cancel` still accepts a later status
    /// update from the agent: a `Completed` update moves it back to
    /// `Allowed { status: Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress tool call for every prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the user already canceled.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1912
    /// An edit tool call that arrives already `Completed` (with a diff) does not
    /// leave the thread reporting pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
1965
    /// Checkpoints are attached only to turns that changed the working tree, and
    /// rewinding to a checkpointed message truncates history and restores files.
    ///
    /// The fake agent optionally creates a new file per prompt (controlled by
    /// `simulate_changes`) and echoes the prompt in upper case.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Creating a file makes the working tree differ from the
                    // checkpoint taken before this turn.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2139
2140 async fn run_until_first_tool_call(
2141 thread: &Entity<AcpThread>,
2142 cx: &mut TestAppContext,
2143 ) -> usize {
2144 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2145
2146 let subscription = cx.update(|cx| {
2147 cx.subscribe(thread, move |thread, _, cx| {
2148 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2149 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2150 return tx.try_send(ix).unwrap();
2151 }
2152 }
2153 })
2154 });
2155
2156 select! {
2157 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2158 panic!("Timeout waiting for tool call")
2159 }
2160 ix = rx.next().fuse() => {
2161 drop(subscription);
2162 ix.unwrap()
2163 }
2164 }
2165 }
2166
2167 #[derive(Clone, Default)]
2168 struct FakeAgentConnection {
2169 auth_methods: Vec<acp::AuthMethod>,
2170 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2171 on_user_message: Option<
2172 Rc<
2173 dyn Fn(
2174 acp::PromptRequest,
2175 WeakEntity<AcpThread>,
2176 AsyncApp,
2177 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2178 + 'static,
2179 >,
2180 >,
2181 }
2182
2183 impl FakeAgentConnection {
2184 fn new() -> Self {
2185 Self {
2186 auth_methods: Vec::new(),
2187 on_user_message: None,
2188 sessions: Arc::default(),
2189 }
2190 }
2191
2192 #[expect(unused)]
2193 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2194 self.auth_methods = auth_methods;
2195 self
2196 }
2197
2198 fn on_user_message(
2199 mut self,
2200 handler: impl Fn(
2201 acp::PromptRequest,
2202 WeakEntity<AcpThread>,
2203 AsyncApp,
2204 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2205 + 'static,
2206 ) -> Self {
2207 self.on_user_message.replace(Rc::new(handler));
2208 self
2209 }
2210 }
2211
2212 impl AgentConnection for FakeAgentConnection {
2213 fn auth_methods(&self) -> &[acp::AuthMethod] {
2214 &self.auth_methods
2215 }
2216
2217 fn new_thread(
2218 self: Rc<Self>,
2219 project: Entity<Project>,
2220 _cwd: &Path,
2221 cx: &mut gpui::AsyncApp,
2222 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2223 let session_id = acp::SessionId(
2224 rand::thread_rng()
2225 .sample_iter(&rand::distributions::Alphanumeric)
2226 .take(7)
2227 .map(char::from)
2228 .collect::<String>()
2229 .into(),
2230 );
2231 let thread = cx
2232 .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
2233 .unwrap();
2234 self.sessions.lock().insert(session_id, thread.downgrade());
2235 Task::ready(Ok(thread))
2236 }
2237
2238 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2239 if self.auth_methods().iter().any(|m| m.id == method) {
2240 Task::ready(Ok(()))
2241 } else {
2242 Task::ready(Err(anyhow!("Invalid Auth Method")))
2243 }
2244 }
2245
2246 fn prompt(
2247 &self,
2248 _id: Option<UserMessageId>,
2249 params: acp::PromptRequest,
2250 cx: &mut App,
2251 ) -> Task<gpui::Result<acp::PromptResponse>> {
2252 let sessions = self.sessions.lock();
2253 let thread = sessions.get(¶ms.session_id).unwrap();
2254 if let Some(handler) = &self.on_user_message {
2255 let handler = handler.clone();
2256 let thread = thread.clone();
2257 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2258 } else {
2259 Task::ready(Ok(acp::PromptResponse {
2260 stop_reason: acp::StopReason::EndTurn,
2261 }))
2262 }
2263 }
2264
2265 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2266 let sessions = self.sessions.lock();
2267 let thread = sessions.get(&session_id).unwrap().clone();
2268
2269 cx.spawn(async move |cx| {
2270 thread
2271 .update(cx, |thread, cx| thread.cancel(cx))
2272 .unwrap()
2273 .await
2274 })
2275 .detach();
2276 }
2277
2278 fn session_editor(
2279 &self,
2280 session_id: &acp::SessionId,
2281 _cx: &mut App,
2282 ) -> Option<Rc<dyn AgentSessionEditor>> {
2283 Some(Rc::new(FakeAgentSessionEditor {
2284 _session_id: session_id.clone(),
2285 }))
2286 }
2287 }
2288
    /// Session editor handed out by the fake connection; it stores only the
    /// session id and accepts every truncation.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// Succeeds immediately without doing anything.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2298}