1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9use serde::{Deserialize, Serialize};
10pub use terminal::*;
11
12use action_log::ActionLog;
13use agent_client_protocol as acp;
14use anyhow::{Context as _, Result, anyhow};
15use chrono::{DateTime, Utc};
16use editor::Bias;
17use futures::{FutureExt, channel::oneshot, future::BoxFuture};
18use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
19use itertools::Itertools;
20use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
21use markdown::Markdown;
22use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
23use std::collections::HashMap;
24use std::error::Error;
25use std::fmt::{Formatter, Write};
26use std::ops::Range;
27use std::process::ExitStatus;
28use std::rc::Rc;
29use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub id: Option<UserMessageId>,
36 pub content: ContentBlock,
37 pub chunks: Vec<acp::ContentBlock>,
38 pub checkpoint: Option<Checkpoint>,
39}
40
41#[derive(Debug)]
42pub struct Checkpoint {
43 git_checkpoint: GitStoreCheckpoint,
44 pub show: bool,
45}
46
47impl UserMessage {
48 fn to_markdown(&self, cx: &App) -> String {
49 let mut markdown = String::new();
50 if self
51 .checkpoint
52 .as_ref()
53 .map_or(false, |checkpoint| checkpoint.show)
54 {
55 writeln!(markdown, "## User (checkpoint)").unwrap();
56 } else {
57 writeln!(markdown, "## User").unwrap();
58 }
59 writeln!(markdown).unwrap();
60 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
61 writeln!(markdown).unwrap();
62 markdown
63 }
64}
65
66#[derive(Debug, PartialEq)]
67pub struct AssistantMessage {
68 pub chunks: Vec<AssistantMessageChunk>,
69}
70
71impl AssistantMessage {
72 pub fn to_markdown(&self, cx: &App) -> String {
73 format!(
74 "## Assistant\n\n{}\n\n",
75 self.chunks
76 .iter()
77 .map(|chunk| chunk.to_markdown(cx))
78 .join("\n\n")
79 )
80 }
81}
82
83#[derive(Debug, PartialEq)]
84pub enum AssistantMessageChunk {
85 Message { block: ContentBlock },
86 Thought { block: ContentBlock },
87}
88
89impl AssistantMessageChunk {
90 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
91 Self::Message {
92 block: ContentBlock::new(chunk.into(), language_registry, cx),
93 }
94 }
95
96 fn to_markdown(&self, cx: &App) -> String {
97 match self {
98 Self::Message { block } => block.to_markdown(cx).to_string(),
99 Self::Thought { block } => {
100 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
101 }
102 }
103 }
104}
105
106#[derive(Debug)]
107pub enum AgentThreadEntry {
108 UserMessage(UserMessage),
109 AssistantMessage(AssistantMessage),
110 ToolCall(ToolCall),
111}
112
113impl AgentThreadEntry {
114 pub fn to_markdown(&self, cx: &App) -> String {
115 match self {
116 Self::UserMessage(message) => message.to_markdown(cx),
117 Self::AssistantMessage(message) => message.to_markdown(cx),
118 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
119 }
120 }
121
122 pub fn user_message(&self) -> Option<&UserMessage> {
123 if let AgentThreadEntry::UserMessage(message) = self {
124 Some(message)
125 } else {
126 None
127 }
128 }
129
130 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
131 if let AgentThreadEntry::ToolCall(call) = self {
132 itertools::Either::Left(call.diffs())
133 } else {
134 itertools::Either::Right(std::iter::empty())
135 }
136 }
137
138 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.terminals())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
147 if let AgentThreadEntry::ToolCall(ToolCall {
148 locations,
149 resolved_locations,
150 ..
151 }) = self
152 {
153 Some((
154 locations.get(ix)?.clone(),
155 resolved_locations.get(ix)?.clone()?,
156 ))
157 } else {
158 None
159 }
160 }
161}
162
/// A tool invocation made by the agent, together with the content rendered for it.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// The tool call's title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer locations resolved from `locations`, index-aligned with it;
    /// `None` where a location could not be resolved. Initially empty.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
175
176impl ToolCall {
177 fn from_acp(
178 tool_call: acp::ToolCall,
179 status: ToolCallStatus,
180 language_registry: Arc<LanguageRegistry>,
181 cx: &mut App,
182 ) -> Self {
183 Self {
184 id: tool_call.id,
185 label: cx.new(|cx| {
186 Markdown::new(
187 tool_call.title.into(),
188 Some(language_registry.clone()),
189 None,
190 cx,
191 )
192 }),
193 kind: tool_call.kind,
194 content: tool_call
195 .content
196 .into_iter()
197 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
198 .collect(),
199 locations: tool_call.locations,
200 resolved_locations: Vec::default(),
201 status,
202 raw_input: tool_call.raw_input,
203 raw_output: tool_call.raw_output,
204 }
205 }
206
207 fn update_fields(
208 &mut self,
209 fields: acp::ToolCallUpdateFields,
210 language_registry: Arc<LanguageRegistry>,
211 cx: &mut App,
212 ) {
213 let acp::ToolCallUpdateFields {
214 kind,
215 status,
216 title,
217 content,
218 locations,
219 raw_input,
220 raw_output,
221 } = fields;
222
223 if let Some(kind) = kind {
224 self.kind = kind;
225 }
226
227 if let Some(status) = status {
228 self.status = status.into();
229 }
230
231 if let Some(title) = title {
232 self.label.update(cx, |label, cx| {
233 label.replace(title, cx);
234 });
235 }
236
237 if let Some(content) = content {
238 self.content = content
239 .into_iter()
240 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
241 .collect();
242 }
243
244 if let Some(locations) = locations {
245 self.locations = locations;
246 }
247
248 if let Some(raw_input) = raw_input {
249 self.raw_input = Some(raw_input);
250 }
251
252 if let Some(raw_output) = raw_output {
253 if self.content.is_empty() {
254 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
255 {
256 self.content
257 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
258 markdown,
259 }));
260 }
261 }
262 self.raw_output = Some(raw_output);
263 }
264 }
265
266 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
267 self.content.iter().filter_map(|content| match content {
268 ToolCallContent::Diff(diff) => Some(diff),
269 ToolCallContent::ContentBlock(_) => None,
270 ToolCallContent::Terminal(_) => None,
271 })
272 }
273
274 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
275 self.content.iter().filter_map(|content| match content {
276 ToolCallContent::Terminal(terminal) => Some(terminal),
277 ToolCallContent::ContentBlock(_) => None,
278 ToolCallContent::Diff(_) => None,
279 })
280 }
281
282 fn to_markdown(&self, cx: &App) -> String {
283 let mut markdown = format!(
284 "**Tool Call: {}**\nStatus: {}\n\n",
285 self.label.read(cx).source(),
286 self.status
287 );
288 for content in &self.content {
289 markdown.push_str(content.to_markdown(cx).as_str());
290 markdown.push_str("\n\n");
291 }
292 markdown
293 }
294
295 async fn resolve_location(
296 location: acp::ToolCallLocation,
297 project: WeakEntity<Project>,
298 cx: &mut AsyncApp,
299 ) -> Option<AgentLocation> {
300 let buffer = project
301 .update(cx, |project, cx| {
302 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
303 Some(project.open_buffer(path, cx))
304 } else {
305 None
306 }
307 })
308 .ok()??;
309 let buffer = buffer.await.log_err()?;
310 let position = buffer
311 .update(cx, |buffer, _| {
312 if let Some(row) = location.line {
313 let snapshot = buffer.snapshot();
314 let column = snapshot.indent_size_for_line(row).len;
315 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
316 snapshot.anchor_before(point)
317 } else {
318 Anchor::MIN
319 }
320 })
321 .ok()?;
322
323 Some(AgentLocation {
324 buffer: buffer.downgrade(),
325 position,
326 })
327 }
328
329 fn resolve_locations(
330 &self,
331 project: Entity<Project>,
332 cx: &mut App,
333 ) -> Task<Vec<Option<AgentLocation>>> {
334 let locations = self.locations.clone();
335 project.update(cx, |_, cx| {
336 cx.spawn(async move |project, cx| {
337 let mut new_locations = Vec::new();
338 for location in locations {
339 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
340 }
341 new_locations
342 })
343 })
344 }
345}
346
347#[derive(Debug)]
348pub enum ToolCallStatus {
349 /// The tool call hasn't started running yet, but we start showing it to
350 /// the user.
351 Pending,
352 /// The tool call is waiting for confirmation from the user.
353 WaitingForConfirmation {
354 options: Vec<acp::PermissionOption>,
355 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
356 },
357 /// The tool call is currently running.
358 InProgress,
359 /// The tool call completed successfully.
360 Completed,
361 /// The tool call failed.
362 Failed,
363 /// The user rejected the tool call.
364 Rejected,
365 /// The user canceled generation so the tool call was canceled.
366 Canceled,
367}
368
369impl From<acp::ToolCallStatus> for ToolCallStatus {
370 fn from(status: acp::ToolCallStatus) -> Self {
371 match status {
372 acp::ToolCallStatus::Pending => Self::Pending,
373 acp::ToolCallStatus::InProgress => Self::InProgress,
374 acp::ToolCallStatus::Completed => Self::Completed,
375 acp::ToolCallStatus::Failed => Self::Failed,
376 }
377 }
378}
379
380impl Display for ToolCallStatus {
381 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
382 write!(
383 f,
384 "{}",
385 match self {
386 ToolCallStatus::Pending => "Pending",
387 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
388 ToolCallStatus::InProgress => "In Progress",
389 ToolCallStatus::Completed => "Completed",
390 ToolCallStatus::Failed => "Failed",
391 ToolCallStatus::Rejected => "Rejected",
392 ToolCallStatus::Canceled => "Canceled",
393 }
394 )
395 }
396}
397
398#[derive(Debug, PartialEq, Clone)]
399pub enum ContentBlock {
400 Empty,
401 Markdown { markdown: Entity<Markdown> },
402 ResourceLink { resource_link: acp::ResourceLink },
403}
404
405impl ContentBlock {
406 pub fn new(
407 block: acp::ContentBlock,
408 language_registry: &Arc<LanguageRegistry>,
409 cx: &mut App,
410 ) -> Self {
411 let mut this = Self::Empty;
412 this.append(block, language_registry, cx);
413 this
414 }
415
416 pub fn new_combined(
417 blocks: impl IntoIterator<Item = acp::ContentBlock>,
418 language_registry: Arc<LanguageRegistry>,
419 cx: &mut App,
420 ) -> Self {
421 let mut this = Self::Empty;
422 for block in blocks {
423 this.append(block, &language_registry, cx);
424 }
425 this
426 }
427
428 pub fn append(
429 &mut self,
430 block: acp::ContentBlock,
431 language_registry: &Arc<LanguageRegistry>,
432 cx: &mut App,
433 ) {
434 if matches!(self, ContentBlock::Empty) {
435 if let acp::ContentBlock::ResourceLink(resource_link) = block {
436 *self = ContentBlock::ResourceLink { resource_link };
437 return;
438 }
439 }
440
441 let new_content = self.block_string_contents(block);
442
443 match self {
444 ContentBlock::Empty => {
445 *self = Self::create_markdown_block(new_content, language_registry, cx);
446 }
447 ContentBlock::Markdown { markdown } => {
448 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
449 }
450 ContentBlock::ResourceLink { resource_link } => {
451 let existing_content = Self::resource_link_md(&resource_link.uri);
452 let combined = format!("{}\n{}", existing_content, new_content);
453
454 *self = Self::create_markdown_block(combined, language_registry, cx);
455 }
456 }
457 }
458
459 fn create_markdown_block(
460 content: String,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) -> ContentBlock {
464 ContentBlock::Markdown {
465 markdown: cx
466 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
467 }
468 }
469
470 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
471 match block {
472 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
473 acp::ContentBlock::ResourceLink(resource_link) => {
474 Self::resource_link_md(&resource_link.uri)
475 }
476 acp::ContentBlock::Resource(acp::EmbeddedResource {
477 resource:
478 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
479 uri,
480 ..
481 }),
482 ..
483 }) => Self::resource_link_md(&uri),
484 acp::ContentBlock::Image(image) => Self::image_md(&image),
485 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
486 }
487 }
488
489 fn resource_link_md(uri: &str) -> String {
490 if let Some(uri) = MentionUri::parse(&uri).log_err() {
491 uri.as_link().to_string()
492 } else {
493 uri.to_string()
494 }
495 }
496
497 fn image_md(_image: &acp::ImageContent) -> String {
498 "`Image`".into()
499 }
500
501 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
502 match self {
503 ContentBlock::Empty => "",
504 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
505 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
506 }
507 }
508
509 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
510 match self {
511 ContentBlock::Empty => None,
512 ContentBlock::Markdown { markdown } => Some(markdown),
513 ContentBlock::ResourceLink { .. } => None,
514 }
515 }
516
517 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
518 match self {
519 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
520 _ => None,
521 }
522 }
523}
524
525#[derive(Debug)]
526pub enum ToolCallContent {
527 ContentBlock(ContentBlock),
528 Diff(Entity<Diff>),
529 Terminal(Entity<Terminal>),
530}
531
532impl ToolCallContent {
533 pub fn from_acp(
534 content: acp::ToolCallContent,
535 language_registry: Arc<LanguageRegistry>,
536 cx: &mut App,
537 ) -> Self {
538 match content {
539 acp::ToolCallContent::Content { content } => {
540 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
541 }
542 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
543 Diff::finalized(
544 diff.path,
545 diff.old_text,
546 diff.new_text,
547 language_registry,
548 cx,
549 )
550 })),
551 }
552 }
553
554 pub fn to_markdown(&self, cx: &App) -> String {
555 match self {
556 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
557 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
558 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
559 }
560 }
561}
562
563#[derive(Debug, PartialEq)]
564pub enum ToolCallUpdate {
565 UpdateFields(acp::ToolCallUpdate),
566 UpdateDiff(ToolCallUpdateDiff),
567 UpdateTerminal(ToolCallUpdateTerminal),
568}
569
570impl ToolCallUpdate {
571 fn id(&self) -> &acp::ToolCallId {
572 match self {
573 Self::UpdateFields(update) => &update.id,
574 Self::UpdateDiff(diff) => &diff.id,
575 Self::UpdateTerminal(terminal) => &terminal.id,
576 }
577 }
578}
579
580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
581 fn from(update: acp::ToolCallUpdate) -> Self {
582 Self::UpdateFields(update)
583 }
584}
585
586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
587 fn from(diff: ToolCallUpdateDiff) -> Self {
588 Self::UpdateDiff(diff)
589 }
590}
591
592#[derive(Debug, PartialEq)]
593pub struct ToolCallUpdateDiff {
594 pub id: acp::ToolCallId,
595 pub diff: Entity<Diff>,
596}
597
598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
599 fn from(terminal: ToolCallUpdateTerminal) -> Self {
600 Self::UpdateTerminal(terminal)
601 }
602}
603
604#[derive(Debug, PartialEq)]
605pub struct ToolCallUpdateTerminal {
606 pub id: acp::ToolCallId,
607 pub terminal: Entity<Terminal>,
608}
609
610#[derive(Debug, Default)]
611pub struct Plan {
612 pub entries: Vec<PlanEntry>,
613}
614
615#[derive(Debug)]
616pub struct PlanStats<'a> {
617 pub in_progress_entry: Option<&'a PlanEntry>,
618 pub pending: u32,
619 pub completed: u32,
620}
621
622impl Plan {
623 pub fn is_empty(&self) -> bool {
624 self.entries.is_empty()
625 }
626
627 pub fn stats(&self) -> PlanStats<'_> {
628 let mut stats = PlanStats {
629 in_progress_entry: None,
630 pending: 0,
631 completed: 0,
632 };
633
634 for entry in &self.entries {
635 match &entry.status {
636 acp::PlanEntryStatus::Pending => {
637 stats.pending += 1;
638 }
639 acp::PlanEntryStatus::InProgress => {
640 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
641 }
642 acp::PlanEntryStatus::Completed => {
643 stats.completed += 1;
644 }
645 }
646 }
647
648 stats
649 }
650}
651
652#[derive(Debug)]
653pub struct PlanEntry {
654 pub content: Entity<Markdown>,
655 pub priority: acp::PlanEntryPriority,
656 pub status: acp::PlanEntryStatus,
657}
658
659impl PlanEntry {
660 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
661 Self {
662 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
663 priority: entry.priority,
664 status: entry.status,
665 }
666 }
667}
668
/// Newtype around a [`SharedString`] naming an agent server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AgentServerName(pub SharedString);

/// Serializable summary of a thread: the agent it belongs to, its session id,
/// its title, and an update timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcpThreadMetadata {
    pub agent: AgentServerName,
    pub id: acp::SessionId,
    pub title: SharedString,
    // NOTE(review): presumably the time of the thread's last change — confirm
    // against whoever populates this field.
    pub updated_at: DateTime<Utc>,
}
679
/// State of one ACP session: its entries and plan, plus handles to the project,
/// the action log, and the agent connection.
pub struct AcpThread {
    title: SharedString,
    /// Conversation history; new entries are pushed at the end.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably
    // tracks buffer snapshots shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The running turn, if any; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
691
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was pushed to the end of the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The entry at this index changed in place.
    EntryUpdated(usize),
    // NOTE(review): not emitted in the visible part of this file; presumably
    // the entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is now waiting for the user's permission.
    ToolAuthorizationRequired,
    // NOTE(review): the following are not emitted in the visible part of this file.
    Stopped,
    Error,
    // Presumably the agent server process exited with this status — confirm.
    ServerExited(ExitStatus),
}
703
/// Allows `AcpThread` to emit [`AcpThreadEvent`]s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
705
/// Coarse run state of an [`AcpThread`], as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running and a trailing tool call awaits the user's confirmation.
    WaitingForToolConfirmation,
    /// A turn is running and nothing awaits confirmation.
    Generating,
}
712
713#[derive(Debug, Clone)]
714pub enum LoadError {
715 Unsupported {
716 error_message: SharedString,
717 upgrade_message: SharedString,
718 upgrade_command: String,
719 },
720 Exited(i32),
721 Other(SharedString),
722}
723
724impl Display for LoadError {
725 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
726 match self {
727 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
728 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
729 LoadError::Other(msg) => write!(f, "{}", msg),
730 }
731 }
732}
733
734impl Error for LoadError {}
735
736impl AcpThread {
737 pub fn new(
738 title: impl Into<SharedString>,
739 connection: Rc<dyn AgentConnection>,
740 project: Entity<Project>,
741 session_id: acp::SessionId,
742 cx: &mut Context<Self>,
743 ) -> Self {
744 let action_log = cx.new(|_| ActionLog::new(project.clone()));
745
746 Self {
747 action_log,
748 shared_buffers: Default::default(),
749 entries: Default::default(),
750 plan: Default::default(),
751 title: title.into(),
752 project,
753 send_task: None,
754 connection,
755 session_id,
756 }
757 }
758
    /// Returns the connection to the agent backing this thread.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }

    /// Returns the action log created for this thread's project.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Returns a clone of the current title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }

    /// Returns all entries, oldest first.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }

    /// Returns the ACP session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
782
783 pub fn status(&self) -> ThreadStatus {
784 if self.send_task.is_some() {
785 if self.waiting_for_tool_confirmation() {
786 ThreadStatus::WaitingForToolConfirmation
787 } else {
788 ThreadStatus::Generating
789 }
790 } else {
791 ThreadStatus::Idle
792 }
793 }
794
795 pub fn has_pending_edit_tool_calls(&self) -> bool {
796 for entry in self.entries.iter().rev() {
797 match entry {
798 AgentThreadEntry::UserMessage(_) => return false,
799 AgentThreadEntry::ToolCall(
800 call @ ToolCall {
801 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
802 ..
803 },
804 ) if call.diffs().next().is_some() => {
805 return true;
806 }
807 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
808 }
809 }
810
811 false
812 }
813
814 pub fn used_tools_since_last_user_message(&self) -> bool {
815 for entry in self.entries.iter().rev() {
816 match entry {
817 AgentThreadEntry::UserMessage(..) => return false,
818 AgentThreadEntry::AssistantMessage(..) => continue,
819 AgentThreadEntry::ToolCall(..) => return true,
820 }
821 }
822
823 false
824 }
825
    /// Applies one session update from the agent to this thread.
    ///
    /// User, message, and thought chunks are merged into the trailing entry or
    /// start a new one; tool calls are upserted or updated; plan updates are
    /// applied to the current plan.
    ///
    /// # Errors
    ///
    /// Propagates failures from upserting or updating a tool call, e.g. a
    /// `ToolCallUpdate` whose id matches no existing tool call.
    pub fn handle_session_update(
        &mut self,
        update: acp::SessionUpdate,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        match update {
            acp::SessionUpdate::UserMessageChunk { content } => {
                self.push_user_content_block(None, content, cx);
            }
            acp::SessionUpdate::AgentMessageChunk { content } => {
                self.push_assistant_content_block(content, false, cx);
            }
            acp::SessionUpdate::AgentThoughtChunk { content } => {
                self.push_assistant_content_block(content, true, cx);
            }
            acp::SessionUpdate::ToolCall(tool_call) => {
                self.upsert_tool_call(tool_call, cx)?;
            }
            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
                self.update_tool_call(tool_call_update, cx)?;
            }
            acp::SessionUpdate::Plan(plan) => {
                self.update_plan(plan, cx);
            }
        }
        Ok(())
    }
853
    /// Appends a user content chunk to the thread.
    ///
    /// If the last entry is a user message, the chunk is merged into it (and a
    /// provided `message_id` replaces the existing id), and `EntryUpdated` is
    /// emitted. Otherwise a new user message entry is pushed.
    pub fn push_user_content_block(
        &mut self,
        message_id: Option<UserMessageId>,
        chunk: acp::ContentBlock,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        // Captured before `last_mut` borrows `entries` mutably.
        let entries_len = self.entries.len();

        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::UserMessage(UserMessage {
                id,
                content,
                chunks,
                ..
            }) = last_entry
        {
            // Keep the existing id unless a new one was supplied.
            *id = message_id.or(id.take());
            content.append(chunk.clone(), &language_registry, cx);
            chunks.push(chunk);
            let idx = entries_len - 1;
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
        } else {
            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
            self.push_entry(
                AgentThreadEntry::UserMessage(UserMessage {
                    id: message_id,
                    content,
                    chunks: vec![chunk],
                    checkpoint: None,
                }),
                cx,
            );
        }
    }
889
    /// Appends an assistant message chunk (`is_thought == false`) or thought
    /// chunk (`is_thought == true`).
    ///
    /// If the last entry is an assistant message, the content is appended to its
    /// last chunk when that chunk has the same kind; otherwise a new chunk is
    /// added. If the last entry is not an assistant message, a new assistant
    /// entry is pushed.
    pub fn push_assistant_content_block(
        &mut self,
        chunk: acp::ContentBlock,
        is_thought: bool,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        // Captured before `last_mut` borrows `entries` mutably.
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
        {
            let idx = entries_len - 1;
            // NOTE(review): the event is emitted before the chunk is modified below;
            // this assumes observers don't read the entry synchronously — confirm.
            cx.emit(AcpThreadEvent::EntryUpdated(idx));
            match (chunks.last_mut(), is_thought) {
                // Same kind as the trailing chunk: extend it in place.
                (Some(AssistantMessageChunk::Message { block }), false)
                | (Some(AssistantMessageChunk::Thought { block }), true) => {
                    block.append(chunk, &language_registry, cx)
                }
                _ => {
                    let block = ContentBlock::new(chunk, &language_registry, cx);
                    if is_thought {
                        chunks.push(AssistantMessageChunk::Thought { block })
                    } else {
                        chunks.push(AssistantMessageChunk::Message { block })
                    }
                }
            }
        } else {
            let block = ContentBlock::new(chunk, &language_registry, cx);
            let chunk = if is_thought {
                AssistantMessageChunk::Thought { block }
            } else {
                AssistantMessageChunk::Message { block }
            };

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                }),
                cx,
            );
        }
    }
933
    /// Pushes `entry` to the end of the thread and emits `NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }

    /// Replaces the title and emits `TitleUpdated`. Currently always returns `Ok`.
    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
        self.title = title;
        cx.emit(AcpThreadEvent::TitleUpdated);
        Ok(())
    }
944
    /// Applies an update to an existing tool call and emits `EntryUpdated`.
    ///
    /// Field updates are merged, and locations are re-resolved if the update
    /// carried new ones. Diff and terminal updates replace the tool call's
    /// content entirely with that single item.
    ///
    /// # Errors
    ///
    /// Fails with "Tool call not found" when no tool call has the update's id.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();

        let (ix, current_call) = self
            .tool_call_mut(update.id())
            .context("Tool call not found")?;
        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Checked before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                current_call.update_fields(update.fields, languages, cx);
                if location_updated {
                    self.resolve_locations(update.id.clone(), cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                current_call.content.clear();
                current_call
                    .content
                    .push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                current_call.content.clear();
                current_call
                    .content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
982
983 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
984 pub fn upsert_tool_call(
985 &mut self,
986 tool_call: acp::ToolCall,
987 cx: &mut Context<Self>,
988 ) -> Result<(), acp::Error> {
989 let status = tool_call.status.into();
990 self.upsert_tool_call_inner(tool_call.into(), status, cx)
991 }
992
993 /// Fails if id does not match an existing entry.
994 pub fn upsert_tool_call_inner(
995 &mut self,
996 tool_call_update: acp::ToolCallUpdate,
997 status: ToolCallStatus,
998 cx: &mut Context<Self>,
999 ) -> Result<(), acp::Error> {
1000 let language_registry = self.project.read(cx).languages().clone();
1001 let id = tool_call_update.id.clone();
1002
1003 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1004 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1005 current_call.status = status;
1006
1007 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1008 } else {
1009 let call =
1010 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1011 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1012 };
1013
1014 self.resolve_locations(id, cx);
1015 Ok(())
1016 }
1017
1018 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1019 // The tool call we are looking for is typically the last one, or very close to the end.
1020 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1021 self.entries
1022 .iter_mut()
1023 .enumerate()
1024 .rev()
1025 .find_map(|(index, tool_call)| {
1026 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1027 && &tool_call.id == id
1028 {
1029 Some((index, tool_call))
1030 } else {
1031 None
1032 }
1033 })
1034 }
1035
1036 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1037 let project = self.project.clone();
1038 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1039 return;
1040 };
1041 let task = tool_call.resolve_locations(project, cx);
1042 cx.spawn(async move |this, cx| {
1043 let resolved_locations = task.await;
1044 this.update(cx, |this, cx| {
1045 let project = this.project.clone();
1046 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1047 return;
1048 };
1049 if let Some(Some(location)) = resolved_locations.last() {
1050 project.update(cx, |project, cx| {
1051 if let Some(agent_location) = project.agent_location() {
1052 let should_ignore = agent_location.buffer == location.buffer
1053 && location
1054 .buffer
1055 .update(cx, |buffer, _| {
1056 let snapshot = buffer.snapshot();
1057 let old_position =
1058 agent_location.position.to_point(&snapshot);
1059 let new_position = location.position.to_point(&snapshot);
1060 // ignore this so that when we get updates from the edit tool
1061 // the position doesn't reset to the startof line
1062 old_position.row == new_position.row
1063 && old_position.column > new_position.column
1064 })
1065 .ok()
1066 .unwrap_or_default();
1067 if !should_ignore {
1068 project.set_agent_location(Some(location.clone()), cx);
1069 }
1070 }
1071 });
1072 }
1073 if tool_call.resolved_locations != resolved_locations {
1074 tool_call.resolved_locations = resolved_locations;
1075 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1076 }
1077 })
1078 })
1079 .detach();
1080 }
1081
1082 pub fn request_tool_call_authorization(
1083 &mut self,
1084 tool_call: acp::ToolCallUpdate,
1085 options: Vec<acp::PermissionOption>,
1086 cx: &mut Context<Self>,
1087 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1088 let (tx, rx) = oneshot::channel();
1089
1090 let status = ToolCallStatus::WaitingForConfirmation {
1091 options,
1092 respond_tx: tx,
1093 };
1094
1095 self.upsert_tool_call_inner(tool_call, status, cx)?;
1096 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1097 Ok(rx)
1098 }
1099
1100 pub fn authorize_tool_call(
1101 &mut self,
1102 id: acp::ToolCallId,
1103 option_id: acp::PermissionOptionId,
1104 option_kind: acp::PermissionOptionKind,
1105 cx: &mut Context<Self>,
1106 ) {
1107 let Some((ix, call)) = self.tool_call_mut(&id) else {
1108 return;
1109 };
1110
1111 let new_status = match option_kind {
1112 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1113 ToolCallStatus::Rejected
1114 }
1115 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1116 ToolCallStatus::InProgress
1117 }
1118 };
1119
1120 let curr_status = mem::replace(&mut call.status, new_status);
1121
1122 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1123 respond_tx.send(option_id).log_err();
1124 } else if cfg!(debug_assertions) {
1125 panic!("tried to authorize an already authorized tool call");
1126 }
1127
1128 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1129 }
1130
1131 /// Returns true if the last turn is awaiting tool authorization
1132 pub fn waiting_for_tool_confirmation(&self) -> bool {
1133 for entry in self.entries.iter().rev() {
1134 match &entry {
1135 AgentThreadEntry::ToolCall(call) => match call.status {
1136 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1137 ToolCallStatus::Pending
1138 | ToolCallStatus::InProgress
1139 | ToolCallStatus::Completed
1140 | ToolCallStatus::Failed
1141 | ToolCallStatus::Rejected
1142 | ToolCallStatus::Canceled => continue,
1143 },
1144 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1145 // Reached the beginning of the turn
1146 return false;
1147 }
1148 }
1149 }
1150 false
1151 }
1152
1153 pub fn plan(&self) -> &Plan {
1154 &self.plan
1155 }
1156
1157 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1158 let new_entries_len = request.entries.len();
1159 let mut new_entries = request.entries.into_iter();
1160
1161 // Reuse existing markdown to prevent flickering
1162 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1163 let PlanEntry {
1164 content,
1165 priority,
1166 status,
1167 } = old;
1168 content.update(cx, |old, cx| {
1169 old.replace(new.content, cx);
1170 });
1171 *priority = new.priority;
1172 *status = new.status;
1173 }
1174 for new in new_entries {
1175 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1176 }
1177 self.plan.entries.truncate(new_entries_len);
1178
1179 cx.notify();
1180 }
1181
1182 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1183 self.plan
1184 .entries
1185 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1186 cx.notify();
1187 }
1188
1189 #[cfg(any(test, feature = "test-support"))]
1190 pub fn send_raw(
1191 &mut self,
1192 message: &str,
1193 cx: &mut Context<Self>,
1194 ) -> BoxFuture<'static, Result<()>> {
1195 self.send(
1196 vec![acp::ContentBlock::Text(acp::TextContent {
1197 text: message.to_string(),
1198 annotations: None,
1199 })],
1200 cx,
1201 )
1202 }
1203
1204 pub fn send(
1205 &mut self,
1206 message: Vec<acp::ContentBlock>,
1207 cx: &mut Context<Self>,
1208 ) -> BoxFuture<'static, Result<()>> {
1209 let block = ContentBlock::new_combined(
1210 message.clone(),
1211 self.project.read(cx).languages().clone(),
1212 cx,
1213 );
1214 let request = acp::PromptRequest {
1215 prompt: message.clone(),
1216 session_id: self.session_id.clone(),
1217 };
1218 let git_store = self.project.read(cx).git_store().clone();
1219
1220 let message_id = if self
1221 .connection
1222 .session_editor(&self.session_id, cx)
1223 .is_some()
1224 {
1225 Some(UserMessageId::new())
1226 } else {
1227 None
1228 };
1229
1230 self.run_turn(cx, async move |this, cx| {
1231 this.update(cx, |this, cx| {
1232 this.push_entry(
1233 AgentThreadEntry::UserMessage(UserMessage {
1234 id: message_id.clone(),
1235 content: block,
1236 chunks: message,
1237 checkpoint: None,
1238 }),
1239 cx,
1240 );
1241 })
1242 .ok();
1243
1244 let old_checkpoint = git_store
1245 .update(cx, |git, cx| git.checkpoint(cx))?
1246 .await
1247 .context("failed to get old checkpoint")
1248 .log_err();
1249 this.update(cx, |this, cx| {
1250 if let Some((_ix, message)) = this.last_user_message() {
1251 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1252 git_checkpoint,
1253 show: false,
1254 });
1255 }
1256 this.connection.prompt(message_id, request, cx)
1257 })?
1258 .await
1259 })
1260 }
1261
1262 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1263 self.run_turn(cx, async move |this, cx| {
1264 this.update(cx, |this, cx| {
1265 this.connection
1266 .resume(&this.session_id, cx)
1267 .map(|resume| resume.run(cx))
1268 })?
1269 .context("resuming a session is not supported")?
1270 .await
1271 })
1272 }
1273
1274 fn run_turn(
1275 &mut self,
1276 cx: &mut Context<Self>,
1277 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1278 ) -> BoxFuture<'static, Result<()>> {
1279 self.clear_completed_plan_entries(cx);
1280
1281 let (tx, rx) = oneshot::channel();
1282 let cancel_task = self.cancel(cx);
1283
1284 self.send_task = Some(cx.spawn(async move |this, cx| {
1285 cancel_task.await;
1286 tx.send(f(this, cx).await).ok();
1287 }));
1288
1289 cx.spawn(async move |this, cx| {
1290 let response = rx.await;
1291
1292 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1293 .await?;
1294
1295 this.update(cx, |this, cx| {
1296 match response {
1297 Ok(Err(e)) => {
1298 this.send_task.take();
1299 cx.emit(AcpThreadEvent::Error);
1300 Err(e)
1301 }
1302 result => {
1303 let canceled = matches!(
1304 result,
1305 Ok(Ok(acp::PromptResponse {
1306 stop_reason: acp::StopReason::Canceled
1307 }))
1308 );
1309
1310 // We only take the task if the current prompt wasn't canceled.
1311 //
1312 // This prompt may have been canceled because another one was sent
1313 // while it was still generating. In these cases, dropping `send_task`
1314 // would cause the next generation to be canceled.
1315 if !canceled {
1316 this.send_task.take();
1317 }
1318
1319 cx.emit(AcpThreadEvent::Stopped);
1320 Ok(())
1321 }
1322 }
1323 })?
1324 })
1325 .boxed()
1326 }
1327
1328 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1329 let Some(send_task) = self.send_task.take() else {
1330 return Task::ready(());
1331 };
1332
1333 for entry in self.entries.iter_mut() {
1334 if let AgentThreadEntry::ToolCall(call) = entry {
1335 let cancel = matches!(
1336 call.status,
1337 ToolCallStatus::Pending
1338 | ToolCallStatus::WaitingForConfirmation { .. }
1339 | ToolCallStatus::InProgress
1340 );
1341
1342 if cancel {
1343 call.status = ToolCallStatus::Canceled;
1344 }
1345 }
1346 }
1347
1348 self.connection.cancel(&self.session_id, cx);
1349
1350 // Wait for the send task to complete
1351 cx.foreground_executor().spawn(send_task)
1352 }
1353
1354 /// Rewinds this thread to before the entry at `index`, removing it and all
1355 /// subsequent entries while reverting any changes made from that point.
1356 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1357 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1358 return Task::ready(Err(anyhow!("not supported")));
1359 };
1360 let Some(message) = self.user_message(&id) else {
1361 return Task::ready(Err(anyhow!("message not found")));
1362 };
1363
1364 let checkpoint = message
1365 .checkpoint
1366 .as_ref()
1367 .map(|c| c.git_checkpoint.clone());
1368
1369 let git_store = self.project.read(cx).git_store().clone();
1370 cx.spawn(async move |this, cx| {
1371 if let Some(checkpoint) = checkpoint {
1372 git_store
1373 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1374 .await?;
1375 }
1376
1377 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1378 .await?;
1379 this.update(cx, |this, cx| {
1380 if let Some((ix, _)) = this.user_message_mut(&id) {
1381 let range = ix..this.entries.len();
1382 this.entries.truncate(ix);
1383 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1384 }
1385 })
1386 })
1387 }
1388
1389 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1390 let git_store = self.project.read(cx).git_store().clone();
1391
1392 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1393 if let Some(checkpoint) = message.checkpoint.as_ref() {
1394 checkpoint.git_checkpoint.clone()
1395 } else {
1396 return Task::ready(Ok(()));
1397 }
1398 } else {
1399 return Task::ready(Ok(()));
1400 };
1401
1402 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1403 cx.spawn(async move |this, cx| {
1404 let new_checkpoint = new_checkpoint
1405 .await
1406 .context("failed to get new checkpoint")
1407 .log_err();
1408 if let Some(new_checkpoint) = new_checkpoint {
1409 let equal = git_store
1410 .update(cx, |git, cx| {
1411 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1412 })?
1413 .await
1414 .unwrap_or(true);
1415 this.update(cx, |this, cx| {
1416 let (ix, message) = this.last_user_message().context("no user message")?;
1417 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1418 checkpoint.show = !equal;
1419 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1420 anyhow::Ok(())
1421 })??;
1422 }
1423
1424 Ok(())
1425 })
1426 }
1427
1428 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1429 self.entries
1430 .iter_mut()
1431 .enumerate()
1432 .rev()
1433 .find_map(|(ix, entry)| {
1434 if let AgentThreadEntry::UserMessage(message) = entry {
1435 Some((ix, message))
1436 } else {
1437 None
1438 }
1439 })
1440 }
1441
1442 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1443 self.entries.iter().find_map(|entry| {
1444 if let AgentThreadEntry::UserMessage(message) = entry {
1445 if message.id.as_ref() == Some(&id) {
1446 Some(message)
1447 } else {
1448 None
1449 }
1450 } else {
1451 None
1452 }
1453 })
1454 }
1455
1456 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1457 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1458 if let AgentThreadEntry::UserMessage(message) = entry {
1459 if message.id.as_ref() == Some(&id) {
1460 Some((ix, message))
1461 } else {
1462 None
1463 }
1464 } else {
1465 None
1466 }
1467 })
1468 }
1469
1470 pub fn read_text_file(
1471 &self,
1472 path: PathBuf,
1473 line: Option<u32>,
1474 limit: Option<u32>,
1475 reuse_shared_snapshot: bool,
1476 cx: &mut Context<Self>,
1477 ) -> Task<Result<String>> {
1478 let project = self.project.clone();
1479 let action_log = self.action_log.clone();
1480 cx.spawn(async move |this, cx| {
1481 let load = project.update(cx, |project, cx| {
1482 let path = project
1483 .project_path_for_absolute_path(&path, cx)
1484 .context("invalid path")?;
1485 anyhow::Ok(project.open_buffer(path, cx))
1486 });
1487 let buffer = load??.await?;
1488
1489 let snapshot = if reuse_shared_snapshot {
1490 this.read_with(cx, |this, _| {
1491 this.shared_buffers.get(&buffer.clone()).cloned()
1492 })
1493 .log_err()
1494 .flatten()
1495 } else {
1496 None
1497 };
1498
1499 let snapshot = if let Some(snapshot) = snapshot {
1500 snapshot
1501 } else {
1502 action_log.update(cx, |action_log, cx| {
1503 action_log.buffer_read(buffer.clone(), cx);
1504 })?;
1505 project.update(cx, |project, cx| {
1506 let position = buffer
1507 .read(cx)
1508 .snapshot()
1509 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1510 project.set_agent_location(
1511 Some(AgentLocation {
1512 buffer: buffer.downgrade(),
1513 position,
1514 }),
1515 cx,
1516 );
1517 })?;
1518
1519 buffer.update(cx, |buffer, _| buffer.snapshot())?
1520 };
1521
1522 this.update(cx, |this, _| {
1523 let text = snapshot.text();
1524 this.shared_buffers.insert(buffer.clone(), snapshot);
1525 if line.is_none() && limit.is_none() {
1526 return Ok(text);
1527 }
1528 let limit = limit.unwrap_or(u32::MAX) as usize;
1529 let Some(line) = line else {
1530 return Ok(text.lines().take(limit).collect::<String>());
1531 };
1532
1533 let count = text.lines().count();
1534 if count < line as usize {
1535 anyhow::bail!("There are only {} lines", count);
1536 }
1537 Ok(text
1538 .lines()
1539 .skip(line as usize + 1)
1540 .take(limit)
1541 .collect::<String>())
1542 })?
1543 })
1544 }
1545
1546 pub fn write_text_file(
1547 &self,
1548 path: PathBuf,
1549 content: String,
1550 cx: &mut Context<Self>,
1551 ) -> Task<Result<()>> {
1552 let project = self.project.clone();
1553 let action_log = self.action_log.clone();
1554 cx.spawn(async move |this, cx| {
1555 let load = project.update(cx, |project, cx| {
1556 let path = project
1557 .project_path_for_absolute_path(&path, cx)
1558 .context("invalid path")?;
1559 anyhow::Ok(project.open_buffer(path, cx))
1560 });
1561 let buffer = load??.await?;
1562 let snapshot = this.update(cx, |this, cx| {
1563 this.shared_buffers
1564 .get(&buffer)
1565 .cloned()
1566 .unwrap_or_else(|| buffer.read(cx).snapshot())
1567 })?;
1568 let edits = cx
1569 .background_executor()
1570 .spawn(async move {
1571 let old_text = snapshot.text();
1572 text_diff(old_text.as_str(), &content)
1573 .into_iter()
1574 .map(|(range, replacement)| {
1575 (
1576 snapshot.anchor_after(range.start)
1577 ..snapshot.anchor_before(range.end),
1578 replacement,
1579 )
1580 })
1581 .collect::<Vec<_>>()
1582 })
1583 .await;
1584 cx.update(|cx| {
1585 project.update(cx, |project, cx| {
1586 project.set_agent_location(
1587 Some(AgentLocation {
1588 buffer: buffer.downgrade(),
1589 position: edits
1590 .last()
1591 .map(|(range, _)| range.end)
1592 .unwrap_or(Anchor::MIN),
1593 }),
1594 cx,
1595 );
1596 });
1597
1598 action_log.update(cx, |action_log, cx| {
1599 action_log.buffer_read(buffer.clone(), cx);
1600 });
1601 buffer.update(cx, |buffer, cx| {
1602 buffer.edit(edits, None, cx);
1603 });
1604 action_log.update(cx, |action_log, cx| {
1605 action_log.buffer_edited(buffer.clone(), cx);
1606 });
1607 })?;
1608 project
1609 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1610 .await
1611 })
1612 }
1613
1614 pub fn to_markdown(&self, cx: &App) -> String {
1615 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1616 }
1617
1618 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1619 cx.emit(AcpThreadEvent::ServerExited(status));
1620 }
1621}
1622
1623fn markdown_for_raw_output(
1624 raw_output: &serde_json::Value,
1625 language_registry: &Arc<LanguageRegistry>,
1626 cx: &mut App,
1627) -> Option<Entity<Markdown>> {
1628 match raw_output {
1629 serde_json::Value::Null => None,
1630 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1631 Markdown::new(
1632 value.to_string().into(),
1633 Some(language_registry.clone()),
1634 None,
1635 cx,
1636 )
1637 })),
1638 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1639 Markdown::new(
1640 value.to_string().into(),
1641 Some(language_registry.clone()),
1642 None,
1643 cx,
1644 )
1645 })),
1646 serde_json::Value::String(value) => Some(cx.new(|cx| {
1647 Markdown::new(
1648 value.clone().into(),
1649 Some(language_registry.clone()),
1650 None,
1651 cx,
1652 )
1653 })),
1654 value => Some(cx.new(|cx| {
1655 Markdown::new(
1656 format!("```json\n{}\n```", value).into(),
1657 Some(language_registry.clone()),
1658 None,
1659 cx,
1660 )
1661 })),
1662 }
1663}
1664
1665#[cfg(test)]
1666mod tests {
1667 use super::*;
1668 use anyhow::anyhow;
1669 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1670 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1671 use indoc::indoc;
1672 use project::{FakeFs, Fs};
1673 use rand::Rng as _;
1674 use serde_json::json;
1675 use settings::SettingsStore;
1676 use smol::stream::StreamExt as _;
1677 use std::{
1678 any::Any,
1679 cell::RefCell,
1680 path::Path,
1681 rc::Rc,
1682 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1683 time::Duration,
1684 };
1685 use util::path;
1686
    /// Shared setup for these tests: installs a test `SettingsStore` global and
    /// initializes project settings and the language subsystem. A logger that is
    /// already initialized is not an error (`try_init().ok()`).
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1696
    /// Checks how user content blocks are grouped into `UserMessage` entries:
    /// - consecutive blocks extend the same message (a later id replaces `None`);
    /// - a user block pushed after an assistant block starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still one entry: the id supplied with the second block is now attached.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user (0), assistant (1), user (2).
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1788
    /// Consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` section of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1851
    /// The agent reads a file, the user then edits the open buffer, and the agent
    /// writes its new version. The user's edit must survive, because the write is
    /// diffed against what the agent read, and the result must be saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished its read.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Simulate the user typing between the agent's read and its write.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1930
    /// An in-progress tool call is marked `Canceled` by `cancel`. A later
    /// `ToolCallUpdate` from the agent reporting `Completed` still takes effect.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress fetch tool call for every prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent overrides the local cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2030
    /// An edit tool call that arrives already `Completed` (with a diff) must not
    /// leave the thread reporting pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2084
    /// Checks git checkpoints attached to user messages:
    /// - a message is marked "(checkpoint)" only when its turn changed the
    ///   repository;
    /// - rewinding to a message restores its checkpoint and truncates the
    ///   conversation there.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each turn creates a new file
        // (`/test/file-N`) before echoing the prompt back in upper case.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                // Entry 2 is the "ipsum" user message.
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2258
2259 async fn run_until_first_tool_call(
2260 thread: &Entity<AcpThread>,
2261 cx: &mut TestAppContext,
2262 ) -> usize {
2263 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2264
2265 let subscription = cx.update(|cx| {
2266 cx.subscribe(thread, move |thread, _, cx| {
2267 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2268 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2269 return tx.try_send(ix).unwrap();
2270 }
2271 }
2272 })
2273 });
2274
2275 select! {
2276 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2277 panic!("Timeout waiting for tool call")
2278 }
2279 ix = rx.next().fuse() => {
2280 drop(subscription);
2281 ix.unwrap()
2282 }
2283 }
2284 }
2285
    /// In-memory `AgentConnection` used by these tests.
    ///
    /// `sessions` maps each created session id to its (weak) thread.
    /// `on_user_message`, when set, produces the agent's response to each
    /// prompt; without it, prompts end the turn immediately.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2301
    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions and no prompt
        /// handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Builder: sets the auth methods accepted by `authenticate`.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Builder: installs the handler that answers each prompt, replacing any
        /// previous one.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }
2330
2331 impl AgentConnection for FakeAgentConnection {
2332 fn auth_methods(&self) -> &[acp::AuthMethod] {
2333 &self.auth_methods
2334 }
2335
2336 fn new_thread(
2337 self: Rc<Self>,
2338 project: Entity<Project>,
2339 _cwd: &Path,
2340 cx: &mut App,
2341 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2342 let session_id = acp::SessionId(
2343 rand::thread_rng()
2344 .sample_iter(&rand::distributions::Alphanumeric)
2345 .take(7)
2346 .map(char::from)
2347 .collect::<String>()
2348 .into(),
2349 );
2350 let thread =
2351 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2352 self.sessions.lock().insert(session_id, thread.downgrade());
2353 Task::ready(Ok(thread))
2354 }
2355
2356 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2357 if self.auth_methods().iter().any(|m| m.id == method) {
2358 Task::ready(Ok(()))
2359 } else {
2360 Task::ready(Err(anyhow!("Invalid Auth Method")))
2361 }
2362 }
2363
2364 fn prompt(
2365 &self,
2366 _id: Option<UserMessageId>,
2367 params: acp::PromptRequest,
2368 cx: &mut App,
2369 ) -> Task<gpui::Result<acp::PromptResponse>> {
2370 let sessions = self.sessions.lock();
2371 let thread = sessions.get(¶ms.session_id).unwrap();
2372 if let Some(handler) = &self.on_user_message {
2373 let handler = handler.clone();
2374 let thread = thread.clone();
2375 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2376 } else {
2377 Task::ready(Ok(acp::PromptResponse {
2378 stop_reason: acp::StopReason::EndTurn,
2379 }))
2380 }
2381 }
2382
2383 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2384 let sessions = self.sessions.lock();
2385 let thread = sessions.get(&session_id).unwrap().clone();
2386
2387 cx.spawn(async move |cx| {
2388 thread
2389 .update(cx, |thread, cx| thread.cancel(cx))
2390 .unwrap()
2391 .await
2392 })
2393 .detach();
2394 }
2395
2396 fn session_editor(
2397 &self,
2398 session_id: &acp::SessionId,
2399 _cx: &mut App,
2400 ) -> Option<Rc<dyn AgentSessionEditor>> {
2401 Some(Rc::new(FakeAgentSessionEditor {
2402 _session_id: session_id.clone(),
2403 }))
2404 }
2405
2406 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2407 self
2408 }
2409 }
2410
    /// Session editor handed out by `FakeAgentConnection`; it records the
    /// session id but performs no real edits.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }
2414
    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// No-op truncation that always succeeds.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2420}