1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub id: Option<UserMessageId>,
35 pub content: ContentBlock,
36 pub chunks: Vec<acp::ContentBlock>,
37 pub checkpoint: Option<Checkpoint>,
38}
39
40#[derive(Debug)]
41pub struct Checkpoint {
42 git_checkpoint: GitStoreCheckpoint,
43 pub show: bool,
44}
45
46impl UserMessage {
47 fn to_markdown(&self, cx: &App) -> String {
48 let mut markdown = String::new();
49 if self
50 .checkpoint
51 .as_ref()
52 .map_or(false, |checkpoint| checkpoint.show)
53 {
54 writeln!(markdown, "## User (checkpoint)").unwrap();
55 } else {
56 writeln!(markdown, "## User").unwrap();
57 }
58 writeln!(markdown).unwrap();
59 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
60 writeln!(markdown).unwrap();
61 markdown
62 }
63}
64
65#[derive(Debug, PartialEq)]
66pub struct AssistantMessage {
67 pub chunks: Vec<AssistantMessageChunk>,
68}
69
70impl AssistantMessage {
71 pub fn to_markdown(&self, cx: &App) -> String {
72 format!(
73 "## Assistant\n\n{}\n\n",
74 self.chunks
75 .iter()
76 .map(|chunk| chunk.to_markdown(cx))
77 .join("\n\n")
78 )
79 }
80}
81
82#[derive(Debug, PartialEq)]
83pub enum AssistantMessageChunk {
84 Message { block: ContentBlock },
85 Thought { block: ContentBlock },
86}
87
88impl AssistantMessageChunk {
89 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
90 Self::Message {
91 block: ContentBlock::new(chunk.into(), language_registry, cx),
92 }
93 }
94
95 fn to_markdown(&self, cx: &App) -> String {
96 match self {
97 Self::Message { block } => block.to_markdown(cx).to_string(),
98 Self::Thought { block } => {
99 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
100 }
101 }
102 }
103}
104
105#[derive(Debug)]
106pub enum AgentThreadEntry {
107 UserMessage(UserMessage),
108 AssistantMessage(AssistantMessage),
109 ToolCall(ToolCall),
110}
111
112impl AgentThreadEntry {
113 pub fn to_markdown(&self, cx: &App) -> String {
114 match self {
115 Self::UserMessage(message) => message.to_markdown(cx),
116 Self::AssistantMessage(message) => message.to_markdown(cx),
117 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
118 }
119 }
120
121 pub fn user_message(&self) -> Option<&UserMessage> {
122 if let AgentThreadEntry::UserMessage(message) = self {
123 Some(message)
124 } else {
125 None
126 }
127 }
128
129 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
130 if let AgentThreadEntry::ToolCall(call) = self {
131 itertools::Either::Left(call.diffs())
132 } else {
133 itertools::Either::Right(std::iter::empty())
134 }
135 }
136
137 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.terminals())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
146 if let AgentThreadEntry::ToolCall(ToolCall {
147 locations,
148 resolved_locations,
149 ..
150 }) = self
151 {
152 Some((
153 locations.get(ix)?.clone(),
154 resolved_locations.get(ix)?.clone()?,
155 ))
156 } else {
157 None
158 }
159 }
160}
161
162#[derive(Debug)]
163pub struct ToolCall {
164 pub id: acp::ToolCallId,
165 pub label: Entity<Markdown>,
166 pub kind: acp::ToolKind,
167 pub content: Vec<ToolCallContent>,
168 pub status: ToolCallStatus,
169 pub locations: Vec<acp::ToolCallLocation>,
170 pub resolved_locations: Vec<Option<AgentLocation>>,
171 pub raw_input: Option<serde_json::Value>,
172 pub raw_output: Option<serde_json::Value>,
173}
174
175impl ToolCall {
176 fn from_acp(
177 tool_call: acp::ToolCall,
178 status: ToolCallStatus,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) -> Self {
182 Self {
183 id: tool_call.id,
184 label: cx.new(|cx| {
185 Markdown::new(
186 tool_call.title.into(),
187 Some(language_registry.clone()),
188 None,
189 cx,
190 )
191 }),
192 kind: tool_call.kind,
193 content: tool_call
194 .content
195 .into_iter()
196 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
197 .collect(),
198 locations: tool_call.locations,
199 resolved_locations: Vec::default(),
200 status,
201 raw_input: tool_call.raw_input,
202 raw_output: tool_call.raw_output,
203 }
204 }
205
206 fn update_fields(
207 &mut self,
208 fields: acp::ToolCallUpdateFields,
209 language_registry: Arc<LanguageRegistry>,
210 cx: &mut App,
211 ) {
212 let acp::ToolCallUpdateFields {
213 kind,
214 status,
215 title,
216 content,
217 locations,
218 raw_input,
219 raw_output,
220 } = fields;
221
222 if let Some(kind) = kind {
223 self.kind = kind;
224 }
225
226 if let Some(status) = status {
227 self.status = status.into();
228 }
229
230 if let Some(title) = title {
231 self.label.update(cx, |label, cx| {
232 label.replace(title, cx);
233 });
234 }
235
236 if let Some(content) = content {
237 self.content = content
238 .into_iter()
239 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
240 .collect();
241 }
242
243 if let Some(locations) = locations {
244 self.locations = locations;
245 }
246
247 if let Some(raw_input) = raw_input {
248 self.raw_input = Some(raw_input);
249 }
250
251 if let Some(raw_output) = raw_output {
252 if self.content.is_empty() {
253 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
254 {
255 self.content
256 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
257 markdown,
258 }));
259 }
260 }
261 self.raw_output = Some(raw_output);
262 }
263 }
264
265 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
266 self.content.iter().filter_map(|content| match content {
267 ToolCallContent::Diff(diff) => Some(diff),
268 ToolCallContent::ContentBlock(_) => None,
269 ToolCallContent::Terminal(_) => None,
270 })
271 }
272
273 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
274 self.content.iter().filter_map(|content| match content {
275 ToolCallContent::Terminal(terminal) => Some(terminal),
276 ToolCallContent::ContentBlock(_) => None,
277 ToolCallContent::Diff(_) => None,
278 })
279 }
280
281 fn to_markdown(&self, cx: &App) -> String {
282 let mut markdown = format!(
283 "**Tool Call: {}**\nStatus: {}\n\n",
284 self.label.read(cx).source(),
285 self.status
286 );
287 for content in &self.content {
288 markdown.push_str(content.to_markdown(cx).as_str());
289 markdown.push_str("\n\n");
290 }
291 markdown
292 }
293
294 async fn resolve_location(
295 location: acp::ToolCallLocation,
296 project: WeakEntity<Project>,
297 cx: &mut AsyncApp,
298 ) -> Option<AgentLocation> {
299 let buffer = project
300 .update(cx, |project, cx| {
301 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
302 Some(project.open_buffer(path, cx))
303 } else {
304 None
305 }
306 })
307 .ok()??;
308 let buffer = buffer.await.log_err()?;
309 let position = buffer
310 .update(cx, |buffer, _| {
311 if let Some(row) = location.line {
312 let snapshot = buffer.snapshot();
313 let column = snapshot.indent_size_for_line(row).len;
314 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
315 snapshot.anchor_before(point)
316 } else {
317 Anchor::MIN
318 }
319 })
320 .ok()?;
321
322 Some(AgentLocation {
323 buffer: buffer.downgrade(),
324 position,
325 })
326 }
327
328 fn resolve_locations(
329 &self,
330 project: Entity<Project>,
331 cx: &mut App,
332 ) -> Task<Vec<Option<AgentLocation>>> {
333 let locations = self.locations.clone();
334 project.update(cx, |_, cx| {
335 cx.spawn(async move |project, cx| {
336 let mut new_locations = Vec::new();
337 for location in locations {
338 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
339 }
340 new_locations
341 })
342 })
343 }
344}
345
346#[derive(Debug)]
347pub enum ToolCallStatus {
348 /// The tool call hasn't started running yet, but we start showing it to
349 /// the user.
350 Pending,
351 /// The tool call is waiting for confirmation from the user.
352 WaitingForConfirmation {
353 options: Vec<acp::PermissionOption>,
354 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
355 },
356 /// The tool call is currently running.
357 InProgress,
358 /// The tool call completed successfully.
359 Completed,
360 /// The tool call failed.
361 Failed,
362 /// The user rejected the tool call.
363 Rejected,
364 /// The user canceled generation so the tool call was canceled.
365 Canceled,
366}
367
368impl From<acp::ToolCallStatus> for ToolCallStatus {
369 fn from(status: acp::ToolCallStatus) -> Self {
370 match status {
371 acp::ToolCallStatus::Pending => Self::Pending,
372 acp::ToolCallStatus::InProgress => Self::InProgress,
373 acp::ToolCallStatus::Completed => Self::Completed,
374 acp::ToolCallStatus::Failed => Self::Failed,
375 }
376 }
377}
378
379impl Display for ToolCallStatus {
380 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
381 write!(
382 f,
383 "{}",
384 match self {
385 ToolCallStatus::Pending => "Pending",
386 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
387 ToolCallStatus::InProgress => "In Progress",
388 ToolCallStatus::Completed => "Completed",
389 ToolCallStatus::Failed => "Failed",
390 ToolCallStatus::Rejected => "Rejected",
391 ToolCallStatus::Canceled => "Canceled",
392 }
393 )
394 }
395}
396
397#[derive(Debug, PartialEq, Clone)]
398pub enum ContentBlock {
399 Empty,
400 Markdown { markdown: Entity<Markdown> },
401 ResourceLink { resource_link: acp::ResourceLink },
402}
403
404impl ContentBlock {
405 pub fn new(
406 block: acp::ContentBlock,
407 language_registry: &Arc<LanguageRegistry>,
408 cx: &mut App,
409 ) -> Self {
410 let mut this = Self::Empty;
411 this.append(block, language_registry, cx);
412 this
413 }
414
415 pub fn new_combined(
416 blocks: impl IntoIterator<Item = acp::ContentBlock>,
417 language_registry: Arc<LanguageRegistry>,
418 cx: &mut App,
419 ) -> Self {
420 let mut this = Self::Empty;
421 for block in blocks {
422 this.append(block, &language_registry, cx);
423 }
424 this
425 }
426
427 pub fn append(
428 &mut self,
429 block: acp::ContentBlock,
430 language_registry: &Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) {
433 if matches!(self, ContentBlock::Empty) {
434 if let acp::ContentBlock::ResourceLink(resource_link) = block {
435 *self = ContentBlock::ResourceLink { resource_link };
436 return;
437 }
438 }
439
440 let new_content = self.block_string_contents(block);
441
442 match self {
443 ContentBlock::Empty => {
444 *self = Self::create_markdown_block(new_content, language_registry, cx);
445 }
446 ContentBlock::Markdown { markdown } => {
447 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
448 }
449 ContentBlock::ResourceLink { resource_link } => {
450 let existing_content = Self::resource_link_md(&resource_link.uri);
451 let combined = format!("{}\n{}", existing_content, new_content);
452
453 *self = Self::create_markdown_block(combined, language_registry, cx);
454 }
455 }
456 }
457
458 fn create_markdown_block(
459 content: String,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) -> ContentBlock {
463 ContentBlock::Markdown {
464 markdown: cx
465 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
466 }
467 }
468
469 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
470 match block {
471 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
472 acp::ContentBlock::ResourceLink(resource_link) => {
473 Self::resource_link_md(&resource_link.uri)
474 }
475 acp::ContentBlock::Resource(acp::EmbeddedResource {
476 resource:
477 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
478 uri,
479 ..
480 }),
481 ..
482 }) => Self::resource_link_md(&uri),
483 acp::ContentBlock::Image(image) => Self::image_md(&image),
484 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
485 }
486 }
487
488 fn resource_link_md(uri: &str) -> String {
489 if let Some(uri) = MentionUri::parse(uri).log_err() {
490 uri.as_link().to_string()
491 } else {
492 uri.to_string()
493 }
494 }
495
496 fn image_md(_image: &acp::ImageContent) -> String {
497 "`Image`".into()
498 }
499
500 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
501 match self {
502 ContentBlock::Empty => "",
503 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
504 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
505 }
506 }
507
508 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
509 match self {
510 ContentBlock::Empty => None,
511 ContentBlock::Markdown { markdown } => Some(markdown),
512 ContentBlock::ResourceLink { .. } => None,
513 }
514 }
515
516 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
517 match self {
518 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
519 _ => None,
520 }
521 }
522}
523
524#[derive(Debug)]
525pub enum ToolCallContent {
526 ContentBlock(ContentBlock),
527 Diff(Entity<Diff>),
528 Terminal(Entity<Terminal>),
529}
530
531impl ToolCallContent {
532 pub fn from_acp(
533 content: acp::ToolCallContent,
534 language_registry: Arc<LanguageRegistry>,
535 cx: &mut App,
536 ) -> Self {
537 match content {
538 acp::ToolCallContent::Content { content } => {
539 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
540 }
541 acp::ToolCallContent::Diff { diff } => {
542 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
543 }
544 }
545 }
546
547 pub fn to_markdown(&self, cx: &App) -> String {
548 match self {
549 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
550 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
551 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
552 }
553 }
554}
555
556#[derive(Debug, PartialEq)]
557pub enum ToolCallUpdate {
558 UpdateFields(acp::ToolCallUpdate),
559 UpdateDiff(ToolCallUpdateDiff),
560 UpdateTerminal(ToolCallUpdateTerminal),
561}
562
563impl ToolCallUpdate {
564 fn id(&self) -> &acp::ToolCallId {
565 match self {
566 Self::UpdateFields(update) => &update.id,
567 Self::UpdateDiff(diff) => &diff.id,
568 Self::UpdateTerminal(terminal) => &terminal.id,
569 }
570 }
571}
572
573impl From<acp::ToolCallUpdate> for ToolCallUpdate {
574 fn from(update: acp::ToolCallUpdate) -> Self {
575 Self::UpdateFields(update)
576 }
577}
578
579impl From<ToolCallUpdateDiff> for ToolCallUpdate {
580 fn from(diff: ToolCallUpdateDiff) -> Self {
581 Self::UpdateDiff(diff)
582 }
583}
584
585#[derive(Debug, PartialEq)]
586pub struct ToolCallUpdateDiff {
587 pub id: acp::ToolCallId,
588 pub diff: Entity<Diff>,
589}
590
591impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
592 fn from(terminal: ToolCallUpdateTerminal) -> Self {
593 Self::UpdateTerminal(terminal)
594 }
595}
596
597#[derive(Debug, PartialEq)]
598pub struct ToolCallUpdateTerminal {
599 pub id: acp::ToolCallId,
600 pub terminal: Entity<Terminal>,
601}
602
603#[derive(Debug, Default)]
604pub struct Plan {
605 pub entries: Vec<PlanEntry>,
606}
607
608#[derive(Debug)]
609pub struct PlanStats<'a> {
610 pub in_progress_entry: Option<&'a PlanEntry>,
611 pub pending: u32,
612 pub completed: u32,
613}
614
615impl Plan {
616 pub fn is_empty(&self) -> bool {
617 self.entries.is_empty()
618 }
619
620 pub fn stats(&self) -> PlanStats<'_> {
621 let mut stats = PlanStats {
622 in_progress_entry: None,
623 pending: 0,
624 completed: 0,
625 };
626
627 for entry in &self.entries {
628 match &entry.status {
629 acp::PlanEntryStatus::Pending => {
630 stats.pending += 1;
631 }
632 acp::PlanEntryStatus::InProgress => {
633 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
634 }
635 acp::PlanEntryStatus::Completed => {
636 stats.completed += 1;
637 }
638 }
639 }
640
641 stats
642 }
643}
644
645#[derive(Debug)]
646pub struct PlanEntry {
647 pub content: Entity<Markdown>,
648 pub priority: acp::PlanEntryPriority,
649 pub status: acp::PlanEntryStatus,
650}
651
652impl PlanEntry {
653 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
654 Self {
655 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
656 priority: entry.priority,
657 status: entry.status,
658 }
659 }
660}
661
662#[derive(Debug, Clone)]
663pub struct RetryStatus {
664 pub last_error: SharedString,
665 pub attempt: usize,
666 pub max_attempts: usize,
667 pub started_at: Instant,
668 pub duration: Duration,
669}
670
671pub struct AcpThread {
672 title: SharedString,
673 entries: Vec<AgentThreadEntry>,
674 plan: Plan,
675 project: Entity<Project>,
676 action_log: Entity<ActionLog>,
677 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
678 send_task: Option<Task<()>>,
679 connection: Rc<dyn AgentConnection>,
680 session_id: acp::SessionId,
681}
682
683#[derive(Debug)]
684pub enum AcpThreadEvent {
685 NewEntry,
686 EntryUpdated(usize),
687 EntriesRemoved(Range<usize>),
688 ToolAuthorizationRequired,
689 Retry(RetryStatus),
690 Stopped,
691 Error,
692 ServerExited(ExitStatus),
693}
694
695impl EventEmitter<AcpThreadEvent> for AcpThread {}
696
/// Coarse state of a thread, as reported by `AcpThread::status`.
///
/// This is a fieldless public enum, so it also derives `Debug` and is `Copy`.
/// Callers can then log it and pass it around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running but is blocked on a tool authorization.
    WaitingForToolConfirmation,
    /// A turn is running.
    Generating,
}
703
704#[derive(Debug, Clone)]
705pub enum LoadError {
706 Unsupported {
707 error_message: SharedString,
708 upgrade_message: SharedString,
709 upgrade_command: String,
710 },
711 Exited(i32),
712 Other(SharedString),
713}
714
715impl Display for LoadError {
716 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
717 match self {
718 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
719 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
720 LoadError::Other(msg) => write!(f, "{}", msg),
721 }
722 }
723}
724
725impl Error for LoadError {}
726
727impl AcpThread {
728 pub fn new(
729 title: impl Into<SharedString>,
730 connection: Rc<dyn AgentConnection>,
731 project: Entity<Project>,
732 session_id: acp::SessionId,
733 cx: &mut Context<Self>,
734 ) -> Self {
735 let action_log = cx.new(|_| ActionLog::new(project.clone()));
736
737 Self {
738 action_log,
739 shared_buffers: Default::default(),
740 entries: Default::default(),
741 plan: Default::default(),
742 title: title.into(),
743 project,
744 send_task: None,
745 connection,
746 session_id,
747 }
748 }
749
750 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
751 &self.connection
752 }
753
754 pub fn action_log(&self) -> &Entity<ActionLog> {
755 &self.action_log
756 }
757
758 pub fn project(&self) -> &Entity<Project> {
759 &self.project
760 }
761
762 pub fn title(&self) -> SharedString {
763 self.title.clone()
764 }
765
766 pub fn entries(&self) -> &[AgentThreadEntry] {
767 &self.entries
768 }
769
770 pub fn session_id(&self) -> &acp::SessionId {
771 &self.session_id
772 }
773
774 pub fn status(&self) -> ThreadStatus {
775 if self.send_task.is_some() {
776 if self.waiting_for_tool_confirmation() {
777 ThreadStatus::WaitingForToolConfirmation
778 } else {
779 ThreadStatus::Generating
780 }
781 } else {
782 ThreadStatus::Idle
783 }
784 }
785
786 pub fn has_pending_edit_tool_calls(&self) -> bool {
787 for entry in self.entries.iter().rev() {
788 match entry {
789 AgentThreadEntry::UserMessage(_) => return false,
790 AgentThreadEntry::ToolCall(
791 call @ ToolCall {
792 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
793 ..
794 },
795 ) if call.diffs().next().is_some() => {
796 return true;
797 }
798 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
799 }
800 }
801
802 false
803 }
804
805 pub fn used_tools_since_last_user_message(&self) -> bool {
806 for entry in self.entries.iter().rev() {
807 match entry {
808 AgentThreadEntry::UserMessage(..) => return false,
809 AgentThreadEntry::AssistantMessage(..) => continue,
810 AgentThreadEntry::ToolCall(..) => return true,
811 }
812 }
813
814 false
815 }
816
817 pub fn handle_session_update(
818 &mut self,
819 update: acp::SessionUpdate,
820 cx: &mut Context<Self>,
821 ) -> Result<(), acp::Error> {
822 match update {
823 acp::SessionUpdate::UserMessageChunk { content } => {
824 self.push_user_content_block(None, content, cx);
825 }
826 acp::SessionUpdate::AgentMessageChunk { content } => {
827 self.push_assistant_content_block(content, false, cx);
828 }
829 acp::SessionUpdate::AgentThoughtChunk { content } => {
830 self.push_assistant_content_block(content, true, cx);
831 }
832 acp::SessionUpdate::ToolCall(tool_call) => {
833 self.upsert_tool_call(tool_call, cx)?;
834 }
835 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
836 self.update_tool_call(tool_call_update, cx)?;
837 }
838 acp::SessionUpdate::Plan(plan) => {
839 self.update_plan(plan, cx);
840 }
841 }
842 Ok(())
843 }
844
845 pub fn push_user_content_block(
846 &mut self,
847 message_id: Option<UserMessageId>,
848 chunk: acp::ContentBlock,
849 cx: &mut Context<Self>,
850 ) {
851 let language_registry = self.project.read(cx).languages().clone();
852 let entries_len = self.entries.len();
853
854 if let Some(last_entry) = self.entries.last_mut()
855 && let AgentThreadEntry::UserMessage(UserMessage {
856 id,
857 content,
858 chunks,
859 ..
860 }) = last_entry
861 {
862 *id = message_id.or(id.take());
863 content.append(chunk.clone(), &language_registry, cx);
864 chunks.push(chunk);
865 let idx = entries_len - 1;
866 cx.emit(AcpThreadEvent::EntryUpdated(idx));
867 } else {
868 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
869 self.push_entry(
870 AgentThreadEntry::UserMessage(UserMessage {
871 id: message_id,
872 content,
873 chunks: vec![chunk],
874 checkpoint: None,
875 }),
876 cx,
877 );
878 }
879 }
880
881 pub fn push_assistant_content_block(
882 &mut self,
883 chunk: acp::ContentBlock,
884 is_thought: bool,
885 cx: &mut Context<Self>,
886 ) {
887 let language_registry = self.project.read(cx).languages().clone();
888 let entries_len = self.entries.len();
889 if let Some(last_entry) = self.entries.last_mut()
890 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
891 {
892 let idx = entries_len - 1;
893 cx.emit(AcpThreadEvent::EntryUpdated(idx));
894 match (chunks.last_mut(), is_thought) {
895 (Some(AssistantMessageChunk::Message { block }), false)
896 | (Some(AssistantMessageChunk::Thought { block }), true) => {
897 block.append(chunk, &language_registry, cx)
898 }
899 _ => {
900 let block = ContentBlock::new(chunk, &language_registry, cx);
901 if is_thought {
902 chunks.push(AssistantMessageChunk::Thought { block })
903 } else {
904 chunks.push(AssistantMessageChunk::Message { block })
905 }
906 }
907 }
908 } else {
909 let block = ContentBlock::new(chunk, &language_registry, cx);
910 let chunk = if is_thought {
911 AssistantMessageChunk::Thought { block }
912 } else {
913 AssistantMessageChunk::Message { block }
914 };
915
916 self.push_entry(
917 AgentThreadEntry::AssistantMessage(AssistantMessage {
918 chunks: vec![chunk],
919 }),
920 cx,
921 );
922 }
923 }
924
925 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
926 self.entries.push(entry);
927 cx.emit(AcpThreadEvent::NewEntry);
928 }
929
930 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
931 cx.emit(AcpThreadEvent::Retry(status));
932 }
933
934 pub fn update_tool_call(
935 &mut self,
936 update: impl Into<ToolCallUpdate>,
937 cx: &mut Context<Self>,
938 ) -> Result<()> {
939 let update = update.into();
940 let languages = self.project.read(cx).languages().clone();
941
942 let (ix, current_call) = self
943 .tool_call_mut(update.id())
944 .context("Tool call not found")?;
945 match update {
946 ToolCallUpdate::UpdateFields(update) => {
947 let location_updated = update.fields.locations.is_some();
948 current_call.update_fields(update.fields, languages, cx);
949 if location_updated {
950 self.resolve_locations(update.id.clone(), cx);
951 }
952 }
953 ToolCallUpdate::UpdateDiff(update) => {
954 current_call.content.clear();
955 current_call
956 .content
957 .push(ToolCallContent::Diff(update.diff));
958 }
959 ToolCallUpdate::UpdateTerminal(update) => {
960 current_call.content.clear();
961 current_call
962 .content
963 .push(ToolCallContent::Terminal(update.terminal));
964 }
965 }
966
967 cx.emit(AcpThreadEvent::EntryUpdated(ix));
968
969 Ok(())
970 }
971
972 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
973 pub fn upsert_tool_call(
974 &mut self,
975 tool_call: acp::ToolCall,
976 cx: &mut Context<Self>,
977 ) -> Result<(), acp::Error> {
978 let status = tool_call.status.into();
979 self.upsert_tool_call_inner(tool_call.into(), status, cx)
980 }
981
982 /// Fails if id does not match an existing entry.
983 pub fn upsert_tool_call_inner(
984 &mut self,
985 tool_call_update: acp::ToolCallUpdate,
986 status: ToolCallStatus,
987 cx: &mut Context<Self>,
988 ) -> Result<(), acp::Error> {
989 let language_registry = self.project.read(cx).languages().clone();
990 let id = tool_call_update.id.clone();
991
992 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
993 current_call.update_fields(tool_call_update.fields, language_registry, cx);
994 current_call.status = status;
995
996 cx.emit(AcpThreadEvent::EntryUpdated(ix));
997 } else {
998 let call =
999 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1000 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1001 };
1002
1003 self.resolve_locations(id, cx);
1004 Ok(())
1005 }
1006
1007 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1008 // The tool call we are looking for is typically the last one, or very close to the end.
1009 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1010 self.entries
1011 .iter_mut()
1012 .enumerate()
1013 .rev()
1014 .find_map(|(index, tool_call)| {
1015 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1016 && &tool_call.id == id
1017 {
1018 Some((index, tool_call))
1019 } else {
1020 None
1021 }
1022 })
1023 }
1024
1025 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1026 let project = self.project.clone();
1027 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1028 return;
1029 };
1030 let task = tool_call.resolve_locations(project, cx);
1031 cx.spawn(async move |this, cx| {
1032 let resolved_locations = task.await;
1033 this.update(cx, |this, cx| {
1034 let project = this.project.clone();
1035 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1036 return;
1037 };
1038 if let Some(Some(location)) = resolved_locations.last() {
1039 project.update(cx, |project, cx| {
1040 if let Some(agent_location) = project.agent_location() {
1041 let should_ignore = agent_location.buffer == location.buffer
1042 && location
1043 .buffer
1044 .update(cx, |buffer, _| {
1045 let snapshot = buffer.snapshot();
1046 let old_position =
1047 agent_location.position.to_point(&snapshot);
1048 let new_position = location.position.to_point(&snapshot);
1049 // ignore this so that when we get updates from the edit tool
1050 // the position doesn't reset to the startof line
1051 old_position.row == new_position.row
1052 && old_position.column > new_position.column
1053 })
1054 .ok()
1055 .unwrap_or_default();
1056 if !should_ignore {
1057 project.set_agent_location(Some(location.clone()), cx);
1058 }
1059 }
1060 });
1061 }
1062 if tool_call.resolved_locations != resolved_locations {
1063 tool_call.resolved_locations = resolved_locations;
1064 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1065 }
1066 })
1067 })
1068 .detach();
1069 }
1070
1071 pub fn request_tool_call_authorization(
1072 &mut self,
1073 tool_call: acp::ToolCallUpdate,
1074 options: Vec<acp::PermissionOption>,
1075 cx: &mut Context<Self>,
1076 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1077 let (tx, rx) = oneshot::channel();
1078
1079 let status = ToolCallStatus::WaitingForConfirmation {
1080 options,
1081 respond_tx: tx,
1082 };
1083
1084 self.upsert_tool_call_inner(tool_call, status, cx)?;
1085 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1086 Ok(rx)
1087 }
1088
1089 pub fn authorize_tool_call(
1090 &mut self,
1091 id: acp::ToolCallId,
1092 option_id: acp::PermissionOptionId,
1093 option_kind: acp::PermissionOptionKind,
1094 cx: &mut Context<Self>,
1095 ) {
1096 let Some((ix, call)) = self.tool_call_mut(&id) else {
1097 return;
1098 };
1099
1100 let new_status = match option_kind {
1101 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1102 ToolCallStatus::Rejected
1103 }
1104 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1105 ToolCallStatus::InProgress
1106 }
1107 };
1108
1109 let curr_status = mem::replace(&mut call.status, new_status);
1110
1111 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1112 respond_tx.send(option_id).log_err();
1113 } else if cfg!(debug_assertions) {
1114 panic!("tried to authorize an already authorized tool call");
1115 }
1116
1117 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1118 }
1119
1120 /// Returns true if the last turn is awaiting tool authorization
1121 pub fn waiting_for_tool_confirmation(&self) -> bool {
1122 for entry in self.entries.iter().rev() {
1123 match &entry {
1124 AgentThreadEntry::ToolCall(call) => match call.status {
1125 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1126 ToolCallStatus::Pending
1127 | ToolCallStatus::InProgress
1128 | ToolCallStatus::Completed
1129 | ToolCallStatus::Failed
1130 | ToolCallStatus::Rejected
1131 | ToolCallStatus::Canceled => continue,
1132 },
1133 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1134 // Reached the beginning of the turn
1135 return false;
1136 }
1137 }
1138 }
1139 false
1140 }
1141
1142 pub fn plan(&self) -> &Plan {
1143 &self.plan
1144 }
1145
1146 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1147 let new_entries_len = request.entries.len();
1148 let mut new_entries = request.entries.into_iter();
1149
1150 // Reuse existing markdown to prevent flickering
1151 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1152 let PlanEntry {
1153 content,
1154 priority,
1155 status,
1156 } = old;
1157 content.update(cx, |old, cx| {
1158 old.replace(new.content, cx);
1159 });
1160 *priority = new.priority;
1161 *status = new.status;
1162 }
1163 for new in new_entries {
1164 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1165 }
1166 self.plan.entries.truncate(new_entries_len);
1167
1168 cx.notify();
1169 }
1170
1171 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1172 self.plan
1173 .entries
1174 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1175 cx.notify();
1176 }
1177
1178 #[cfg(any(test, feature = "test-support"))]
1179 pub fn send_raw(
1180 &mut self,
1181 message: &str,
1182 cx: &mut Context<Self>,
1183 ) -> BoxFuture<'static, Result<()>> {
1184 self.send(
1185 vec![acp::ContentBlock::Text(acp::TextContent {
1186 text: message.to_string(),
1187 annotations: None,
1188 })],
1189 cx,
1190 )
1191 }
1192
1193 pub fn send(
1194 &mut self,
1195 message: Vec<acp::ContentBlock>,
1196 cx: &mut Context<Self>,
1197 ) -> BoxFuture<'static, Result<()>> {
1198 let block = ContentBlock::new_combined(
1199 message.clone(),
1200 self.project.read(cx).languages().clone(),
1201 cx,
1202 );
1203 let request = acp::PromptRequest {
1204 prompt: message.clone(),
1205 session_id: self.session_id.clone(),
1206 };
1207 let git_store = self.project.read(cx).git_store().clone();
1208
1209 let message_id = if self
1210 .connection
1211 .session_editor(&self.session_id, cx)
1212 .is_some()
1213 {
1214 Some(UserMessageId::new())
1215 } else {
1216 None
1217 };
1218
1219 self.run_turn(cx, async move |this, cx| {
1220 this.update(cx, |this, cx| {
1221 this.push_entry(
1222 AgentThreadEntry::UserMessage(UserMessage {
1223 id: message_id.clone(),
1224 content: block,
1225 chunks: message,
1226 checkpoint: None,
1227 }),
1228 cx,
1229 );
1230 })
1231 .ok();
1232
1233 let old_checkpoint = git_store
1234 .update(cx, |git, cx| git.checkpoint(cx))?
1235 .await
1236 .context("failed to get old checkpoint")
1237 .log_err();
1238 this.update(cx, |this, cx| {
1239 if let Some((_ix, message)) = this.last_user_message() {
1240 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1241 git_checkpoint,
1242 show: false,
1243 });
1244 }
1245 this.connection.prompt(message_id, request, cx)
1246 })?
1247 .await
1248 })
1249 }
1250
1251 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1252 self.run_turn(cx, async move |this, cx| {
1253 this.update(cx, |this, cx| {
1254 this.connection
1255 .resume(&this.session_id, cx)
1256 .map(|resume| resume.run(cx))
1257 })?
1258 .context("resuming a session is not supported")?
1259 .await
1260 })
1261 }
1262
1263 fn run_turn(
1264 &mut self,
1265 cx: &mut Context<Self>,
1266 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1267 ) -> BoxFuture<'static, Result<()>> {
1268 self.clear_completed_plan_entries(cx);
1269
1270 let (tx, rx) = oneshot::channel();
1271 let cancel_task = self.cancel(cx);
1272
1273 self.send_task = Some(cx.spawn(async move |this, cx| {
1274 cancel_task.await;
1275 tx.send(f(this, cx).await).ok();
1276 }));
1277
1278 cx.spawn(async move |this, cx| {
1279 let response = rx.await;
1280
1281 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1282 .await?;
1283
1284 this.update(cx, |this, cx| {
1285 this.project
1286 .update(cx, |project, cx| project.set_agent_location(None, cx));
1287 match response {
1288 Ok(Err(e)) => {
1289 this.send_task.take();
1290 cx.emit(AcpThreadEvent::Error);
1291 Err(e)
1292 }
1293 result => {
1294 let canceled = matches!(
1295 result,
1296 Ok(Ok(acp::PromptResponse {
1297 stop_reason: acp::StopReason::Canceled
1298 }))
1299 );
1300
1301 // We only take the task if the current prompt wasn't canceled.
1302 //
1303 // This prompt may have been canceled because another one was sent
1304 // while it was still generating. In these cases, dropping `send_task`
1305 // would cause the next generation to be canceled.
1306 if !canceled {
1307 this.send_task.take();
1308 }
1309
1310 cx.emit(AcpThreadEvent::Stopped);
1311 Ok(())
1312 }
1313 }
1314 })?
1315 })
1316 .boxed()
1317 }
1318
1319 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1320 let Some(send_task) = self.send_task.take() else {
1321 return Task::ready(());
1322 };
1323
1324 for entry in self.entries.iter_mut() {
1325 if let AgentThreadEntry::ToolCall(call) = entry {
1326 let cancel = matches!(
1327 call.status,
1328 ToolCallStatus::Pending
1329 | ToolCallStatus::WaitingForConfirmation { .. }
1330 | ToolCallStatus::InProgress
1331 );
1332
1333 if cancel {
1334 call.status = ToolCallStatus::Canceled;
1335 }
1336 }
1337 }
1338
1339 self.connection.cancel(&self.session_id, cx);
1340
1341 // Wait for the send task to complete
1342 cx.foreground_executor().spawn(send_task)
1343 }
1344
1345 /// Rewinds this thread to before the entry at `index`, removing it and all
1346 /// subsequent entries while reverting any changes made from that point.
1347 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1348 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1349 return Task::ready(Err(anyhow!("not supported")));
1350 };
1351 let Some(message) = self.user_message(&id) else {
1352 return Task::ready(Err(anyhow!("message not found")));
1353 };
1354
1355 let checkpoint = message
1356 .checkpoint
1357 .as_ref()
1358 .map(|c| c.git_checkpoint.clone());
1359
1360 let git_store = self.project.read(cx).git_store().clone();
1361 cx.spawn(async move |this, cx| {
1362 if let Some(checkpoint) = checkpoint {
1363 git_store
1364 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1365 .await?;
1366 }
1367
1368 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1369 .await?;
1370 this.update(cx, |this, cx| {
1371 if let Some((ix, _)) = this.user_message_mut(&id) {
1372 let range = ix..this.entries.len();
1373 this.entries.truncate(ix);
1374 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1375 }
1376 })
1377 })
1378 }
1379
1380 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1381 let git_store = self.project.read(cx).git_store().clone();
1382
1383 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1384 if let Some(checkpoint) = message.checkpoint.as_ref() {
1385 checkpoint.git_checkpoint.clone()
1386 } else {
1387 return Task::ready(Ok(()));
1388 }
1389 } else {
1390 return Task::ready(Ok(()));
1391 };
1392
1393 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1394 cx.spawn(async move |this, cx| {
1395 let new_checkpoint = new_checkpoint
1396 .await
1397 .context("failed to get new checkpoint")
1398 .log_err();
1399 if let Some(new_checkpoint) = new_checkpoint {
1400 let equal = git_store
1401 .update(cx, |git, cx| {
1402 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1403 })?
1404 .await
1405 .unwrap_or(true);
1406 this.update(cx, |this, cx| {
1407 let (ix, message) = this.last_user_message().context("no user message")?;
1408 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1409 checkpoint.show = !equal;
1410 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1411 anyhow::Ok(())
1412 })??;
1413 }
1414
1415 Ok(())
1416 })
1417 }
1418
1419 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1420 self.entries
1421 .iter_mut()
1422 .enumerate()
1423 .rev()
1424 .find_map(|(ix, entry)| {
1425 if let AgentThreadEntry::UserMessage(message) = entry {
1426 Some((ix, message))
1427 } else {
1428 None
1429 }
1430 })
1431 }
1432
1433 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1434 self.entries.iter().find_map(|entry| {
1435 if let AgentThreadEntry::UserMessage(message) = entry {
1436 if message.id.as_ref() == Some(id) {
1437 Some(message)
1438 } else {
1439 None
1440 }
1441 } else {
1442 None
1443 }
1444 })
1445 }
1446
1447 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1448 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1449 if let AgentThreadEntry::UserMessage(message) = entry {
1450 if message.id.as_ref() == Some(id) {
1451 Some((ix, message))
1452 } else {
1453 None
1454 }
1455 } else {
1456 None
1457 }
1458 })
1459 }
1460
1461 pub fn read_text_file(
1462 &self,
1463 path: PathBuf,
1464 line: Option<u32>,
1465 limit: Option<u32>,
1466 reuse_shared_snapshot: bool,
1467 cx: &mut Context<Self>,
1468 ) -> Task<Result<String>> {
1469 let project = self.project.clone();
1470 let action_log = self.action_log.clone();
1471 cx.spawn(async move |this, cx| {
1472 let load = project.update(cx, |project, cx| {
1473 let path = project
1474 .project_path_for_absolute_path(&path, cx)
1475 .context("invalid path")?;
1476 anyhow::Ok(project.open_buffer(path, cx))
1477 });
1478 let buffer = load??.await?;
1479
1480 let snapshot = if reuse_shared_snapshot {
1481 this.read_with(cx, |this, _| {
1482 this.shared_buffers.get(&buffer.clone()).cloned()
1483 })
1484 .log_err()
1485 .flatten()
1486 } else {
1487 None
1488 };
1489
1490 let snapshot = if let Some(snapshot) = snapshot {
1491 snapshot
1492 } else {
1493 action_log.update(cx, |action_log, cx| {
1494 action_log.buffer_read(buffer.clone(), cx);
1495 })?;
1496 project.update(cx, |project, cx| {
1497 let position = buffer
1498 .read(cx)
1499 .snapshot()
1500 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1501 project.set_agent_location(
1502 Some(AgentLocation {
1503 buffer: buffer.downgrade(),
1504 position,
1505 }),
1506 cx,
1507 );
1508 })?;
1509
1510 buffer.update(cx, |buffer, _| buffer.snapshot())?
1511 };
1512
1513 this.update(cx, |this, _| {
1514 let text = snapshot.text();
1515 this.shared_buffers.insert(buffer.clone(), snapshot);
1516 if line.is_none() && limit.is_none() {
1517 return Ok(text);
1518 }
1519 let limit = limit.unwrap_or(u32::MAX) as usize;
1520 let Some(line) = line else {
1521 return Ok(text.lines().take(limit).collect::<String>());
1522 };
1523
1524 let count = text.lines().count();
1525 if count < line as usize {
1526 anyhow::bail!("There are only {} lines", count);
1527 }
1528 Ok(text
1529 .lines()
1530 .skip(line as usize + 1)
1531 .take(limit)
1532 .collect::<String>())
1533 })?
1534 })
1535 }
1536
1537 pub fn write_text_file(
1538 &self,
1539 path: PathBuf,
1540 content: String,
1541 cx: &mut Context<Self>,
1542 ) -> Task<Result<()>> {
1543 let project = self.project.clone();
1544 let action_log = self.action_log.clone();
1545 cx.spawn(async move |this, cx| {
1546 let load = project.update(cx, |project, cx| {
1547 let path = project
1548 .project_path_for_absolute_path(&path, cx)
1549 .context("invalid path")?;
1550 anyhow::Ok(project.open_buffer(path, cx))
1551 });
1552 let buffer = load??.await?;
1553 let snapshot = this.update(cx, |this, cx| {
1554 this.shared_buffers
1555 .get(&buffer)
1556 .cloned()
1557 .unwrap_or_else(|| buffer.read(cx).snapshot())
1558 })?;
1559 let edits = cx
1560 .background_executor()
1561 .spawn(async move {
1562 let old_text = snapshot.text();
1563 text_diff(old_text.as_str(), &content)
1564 .into_iter()
1565 .map(|(range, replacement)| {
1566 (
1567 snapshot.anchor_after(range.start)
1568 ..snapshot.anchor_before(range.end),
1569 replacement,
1570 )
1571 })
1572 .collect::<Vec<_>>()
1573 })
1574 .await;
1575 cx.update(|cx| {
1576 project.update(cx, |project, cx| {
1577 project.set_agent_location(
1578 Some(AgentLocation {
1579 buffer: buffer.downgrade(),
1580 position: edits
1581 .last()
1582 .map(|(range, _)| range.end)
1583 .unwrap_or(Anchor::MIN),
1584 }),
1585 cx,
1586 );
1587 });
1588
1589 action_log.update(cx, |action_log, cx| {
1590 action_log.buffer_read(buffer.clone(), cx);
1591 });
1592 buffer.update(cx, |buffer, cx| {
1593 buffer.edit(edits, None, cx);
1594 });
1595 action_log.update(cx, |action_log, cx| {
1596 action_log.buffer_edited(buffer.clone(), cx);
1597 });
1598 })?;
1599 project
1600 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1601 .await
1602 })
1603 }
1604
1605 pub fn to_markdown(&self, cx: &App) -> String {
1606 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1607 }
1608
1609 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1610 cx.emit(AcpThreadEvent::ServerExited(status));
1611 }
1612}
1613
1614fn markdown_for_raw_output(
1615 raw_output: &serde_json::Value,
1616 language_registry: &Arc<LanguageRegistry>,
1617 cx: &mut App,
1618) -> Option<Entity<Markdown>> {
1619 match raw_output {
1620 serde_json::Value::Null => None,
1621 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1622 Markdown::new(
1623 value.to_string().into(),
1624 Some(language_registry.clone()),
1625 None,
1626 cx,
1627 )
1628 })),
1629 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1630 Markdown::new(
1631 value.to_string().into(),
1632 Some(language_registry.clone()),
1633 None,
1634 cx,
1635 )
1636 })),
1637 serde_json::Value::String(value) => Some(cx.new(|cx| {
1638 Markdown::new(
1639 value.clone().into(),
1640 Some(language_registry.clone()),
1641 None,
1642 cx,
1643 )
1644 })),
1645 value => Some(cx.new(|cx| {
1646 Markdown::new(
1647 format!("```json\n{}\n```", value).into(),
1648 Some(language_registry.clone()),
1649 None,
1650 cx,
1651 )
1652 })),
1653 }
1654}
1655
1656#[cfg(test)]
1657mod tests {
1658 use super::*;
1659 use anyhow::anyhow;
1660 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1661 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1662 use indoc::indoc;
1663 use project::{FakeFs, Fs};
1664 use rand::Rng as _;
1665 use serde_json::json;
1666 use settings::SettingsStore;
1667 use smol::stream::StreamExt as _;
1668 use std::{
1669 any::Any,
1670 cell::RefCell,
1671 path::Path,
1672 rc::Rc,
1673 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1674 time::Duration,
1675 };
1676 use util::path;
1677
    /// Common setup for every test in this module.
    ///
    /// Initializes the logger (ignoring "already initialized" errors), installs a test
    /// `SettingsStore` as a global, then initializes project settings and `language`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1687
    // How `push_user_content_block` groups chunks:
    // - consecutive user chunks extend the current user message (which takes the
    //   supplied message id);
    // - a user chunk arriving after an assistant message starts a new user message.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1779
    // Two consecutive `AgentThoughtChunk` updates must be merged into a single
    // `<thinking>` section of one assistant message in the markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1842
    // Sequence under test:
    // 1. The agent reads /tmp/foo.
    // 2. The user prepends a line to the open buffer.
    // 3. The agent writes its new contents, which were based on the earlier read.
    // Expected: the user's edit survives in the buffer and in the file saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent has finished its read, so the user edit
        // lands between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1921
    // Sequence under test:
    // 1. The agent starts an in-progress tool call.
    // 2. The turn is canceled, so the tool call becomes `Canceled`.
    // 3. A later `ToolCallUpdate` from the agent reports `Completed`.
    // Expected: the update is still applied and the tool call ends up `Completed`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the local cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2021
    // An edit tool call that arrives already `Completed` (carrying a diff) must not
    // count as a pending edit tool call once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2075
    // Checkpoint behavior:
    // - A user message is rendered as "(checkpoint)" only when its turn changed the
    //   working tree. The fake agent creates a new file per turn while
    //   `simulate_changes` is set.
    // - Rewinding to a message truncates the thread from that message onwards and
    //   restores its checkpoint, removing files created after it.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    // Each turn creates a new file while changes are simulated.
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the prompt back in upper case as the assistant reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2249
2250 async fn run_until_first_tool_call(
2251 thread: &Entity<AcpThread>,
2252 cx: &mut TestAppContext,
2253 ) -> usize {
2254 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2255
2256 let subscription = cx.update(|cx| {
2257 cx.subscribe(thread, move |thread, _, cx| {
2258 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2259 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2260 return tx.try_send(ix).unwrap();
2261 }
2262 }
2263 })
2264 });
2265
2266 select! {
2267 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2268 panic!("Timeout waiting for tool call")
2269 }
2270 ix = rx.next().fuse() => {
2271 drop(subscription);
2272 ix.unwrap()
2273 }
2274 }
2275 }
2276
    /// In-memory `AgentConnection` used by these tests.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods that `authenticate` accepts.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by session id. `prompt` and `cancel`
        /// look threads up here.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional handler `prompt` runs for each request. When unset, `prompt`
        /// immediately answers with `StopReason::EndTurn`.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2292
2293 impl FakeAgentConnection {
2294 fn new() -> Self {
2295 Self {
2296 auth_methods: Vec::new(),
2297 on_user_message: None,
2298 sessions: Arc::default(),
2299 }
2300 }
2301
2302 #[expect(unused)]
2303 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2304 self.auth_methods = auth_methods;
2305 self
2306 }
2307
2308 fn on_user_message(
2309 mut self,
2310 handler: impl Fn(
2311 acp::PromptRequest,
2312 WeakEntity<AcpThread>,
2313 AsyncApp,
2314 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2315 + 'static,
2316 ) -> Self {
2317 self.on_user_message.replace(Rc::new(handler));
2318 self
2319 }
2320 }
2321
2322 impl AgentConnection for FakeAgentConnection {
2323 fn auth_methods(&self) -> &[acp::AuthMethod] {
2324 &self.auth_methods
2325 }
2326
2327 fn new_thread(
2328 self: Rc<Self>,
2329 project: Entity<Project>,
2330 _cwd: &Path,
2331 cx: &mut gpui::App,
2332 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2333 let session_id = acp::SessionId(
2334 rand::thread_rng()
2335 .sample_iter(&rand::distributions::Alphanumeric)
2336 .take(7)
2337 .map(char::from)
2338 .collect::<String>()
2339 .into(),
2340 );
2341 let thread =
2342 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2343 self.sessions.lock().insert(session_id, thread.downgrade());
2344 Task::ready(Ok(thread))
2345 }
2346
2347 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2348 if self.auth_methods().iter().any(|m| m.id == method) {
2349 Task::ready(Ok(()))
2350 } else {
2351 Task::ready(Err(anyhow!("Invalid Auth Method")))
2352 }
2353 }
2354
2355 fn prompt(
2356 &self,
2357 _id: Option<UserMessageId>,
2358 params: acp::PromptRequest,
2359 cx: &mut App,
2360 ) -> Task<gpui::Result<acp::PromptResponse>> {
2361 let sessions = self.sessions.lock();
2362 let thread = sessions.get(¶ms.session_id).unwrap();
2363 if let Some(handler) = &self.on_user_message {
2364 let handler = handler.clone();
2365 let thread = thread.clone();
2366 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2367 } else {
2368 Task::ready(Ok(acp::PromptResponse {
2369 stop_reason: acp::StopReason::EndTurn,
2370 }))
2371 }
2372 }
2373
2374 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2375 let sessions = self.sessions.lock();
2376 let thread = sessions.get(session_id).unwrap().clone();
2377
2378 cx.spawn(async move |cx| {
2379 thread
2380 .update(cx, |thread, cx| thread.cancel(cx))
2381 .unwrap()
2382 .await
2383 })
2384 .detach();
2385 }
2386
2387 fn session_editor(
2388 &self,
2389 session_id: &acp::SessionId,
2390 _cx: &mut App,
2391 ) -> Option<Rc<dyn AgentSessionEditor>> {
2392 Some(Rc::new(FakeAgentSessionEditor {
2393 _session_id: session_id.clone(),
2394 }))
2395 }
2396
2397 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2398 self
2399 }
2400 }
2401
    /// Session editor returned by `FakeAgentConnection::session_editor`; every
    /// truncation succeeds without doing anything.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// No-op: immediately resolves successfully.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2411}