1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
29use ui::App;
30use util::ResultExt;
31
32#[derive(Debug)]
33pub struct UserMessage {
34 pub id: Option<UserMessageId>,
35 pub content: ContentBlock,
36 pub chunks: Vec<acp::ContentBlock>,
37 pub checkpoint: Option<Checkpoint>,
38}
39
40#[derive(Debug)]
41pub struct Checkpoint {
42 git_checkpoint: GitStoreCheckpoint,
43 pub show: bool,
44}
45
46impl UserMessage {
47 fn to_markdown(&self, cx: &App) -> String {
48 let mut markdown = String::new();
49 if self
50 .checkpoint
51 .as_ref()
52 .is_some_and(|checkpoint| checkpoint.show)
53 {
54 writeln!(markdown, "## User (checkpoint)").unwrap();
55 } else {
56 writeln!(markdown, "## User").unwrap();
57 }
58 writeln!(markdown).unwrap();
59 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
60 writeln!(markdown).unwrap();
61 markdown
62 }
63}
64
65#[derive(Debug, PartialEq)]
66pub struct AssistantMessage {
67 pub chunks: Vec<AssistantMessageChunk>,
68}
69
70impl AssistantMessage {
71 pub fn to_markdown(&self, cx: &App) -> String {
72 format!(
73 "## Assistant\n\n{}\n\n",
74 self.chunks
75 .iter()
76 .map(|chunk| chunk.to_markdown(cx))
77 .join("\n\n")
78 )
79 }
80}
81
82#[derive(Debug, PartialEq)]
83pub enum AssistantMessageChunk {
84 Message { block: ContentBlock },
85 Thought { block: ContentBlock },
86}
87
88impl AssistantMessageChunk {
89 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
90 Self::Message {
91 block: ContentBlock::new(chunk.into(), language_registry, cx),
92 }
93 }
94
95 fn to_markdown(&self, cx: &App) -> String {
96 match self {
97 Self::Message { block } => block.to_markdown(cx).to_string(),
98 Self::Thought { block } => {
99 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
100 }
101 }
102 }
103}
104
105#[derive(Debug)]
106pub enum AgentThreadEntry {
107 UserMessage(UserMessage),
108 AssistantMessage(AssistantMessage),
109 ToolCall(ToolCall),
110}
111
112impl AgentThreadEntry {
113 pub fn to_markdown(&self, cx: &App) -> String {
114 match self {
115 Self::UserMessage(message) => message.to_markdown(cx),
116 Self::AssistantMessage(message) => message.to_markdown(cx),
117 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
118 }
119 }
120
121 pub fn user_message(&self) -> Option<&UserMessage> {
122 if let AgentThreadEntry::UserMessage(message) = self {
123 Some(message)
124 } else {
125 None
126 }
127 }
128
129 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
130 if let AgentThreadEntry::ToolCall(call) = self {
131 itertools::Either::Left(call.diffs())
132 } else {
133 itertools::Either::Right(std::iter::empty())
134 }
135 }
136
137 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.terminals())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
146 if let AgentThreadEntry::ToolCall(ToolCall {
147 locations,
148 resolved_locations,
149 ..
150 }) = self
151 {
152 Some((
153 locations.get(ix)?.clone(),
154 resolved_locations.get(ix)?.clone()?,
155 ))
156 } else {
157 None
158 }
159 }
160}
161
162#[derive(Debug)]
163pub struct ToolCall {
164 pub id: acp::ToolCallId,
165 pub label: Entity<Markdown>,
166 pub kind: acp::ToolKind,
167 pub content: Vec<ToolCallContent>,
168 pub status: ToolCallStatus,
169 pub locations: Vec<acp::ToolCallLocation>,
170 pub resolved_locations: Vec<Option<AgentLocation>>,
171 pub raw_input: Option<serde_json::Value>,
172 pub raw_output: Option<serde_json::Value>,
173}
174
175impl ToolCall {
176 fn from_acp(
177 tool_call: acp::ToolCall,
178 status: ToolCallStatus,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) -> Self {
182 Self {
183 id: tool_call.id,
184 label: cx.new(|cx| {
185 Markdown::new(
186 tool_call.title.into(),
187 Some(language_registry.clone()),
188 None,
189 cx,
190 )
191 }),
192 kind: tool_call.kind,
193 content: tool_call
194 .content
195 .into_iter()
196 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
197 .collect(),
198 locations: tool_call.locations,
199 resolved_locations: Vec::default(),
200 status,
201 raw_input: tool_call.raw_input,
202 raw_output: tool_call.raw_output,
203 }
204 }
205
206 fn update_fields(
207 &mut self,
208 fields: acp::ToolCallUpdateFields,
209 language_registry: Arc<LanguageRegistry>,
210 cx: &mut App,
211 ) {
212 let acp::ToolCallUpdateFields {
213 kind,
214 status,
215 title,
216 content,
217 locations,
218 raw_input,
219 raw_output,
220 } = fields;
221
222 if let Some(kind) = kind {
223 self.kind = kind;
224 }
225
226 if let Some(status) = status {
227 self.status = status.into();
228 }
229
230 if let Some(title) = title {
231 self.label.update(cx, |label, cx| {
232 label.replace(title, cx);
233 });
234 }
235
236 if let Some(content) = content {
237 self.content = content
238 .into_iter()
239 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
240 .collect();
241 }
242
243 if let Some(locations) = locations {
244 self.locations = locations;
245 }
246
247 if let Some(raw_input) = raw_input {
248 self.raw_input = Some(raw_input);
249 }
250
251 if let Some(raw_output) = raw_output {
252 if self.content.is_empty()
253 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
254 {
255 self.content
256 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
257 markdown,
258 }));
259 }
260 self.raw_output = Some(raw_output);
261 }
262 }
263
264 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
265 self.content.iter().filter_map(|content| match content {
266 ToolCallContent::Diff(diff) => Some(diff),
267 ToolCallContent::ContentBlock(_) => None,
268 ToolCallContent::Terminal(_) => None,
269 })
270 }
271
272 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
273 self.content.iter().filter_map(|content| match content {
274 ToolCallContent::Terminal(terminal) => Some(terminal),
275 ToolCallContent::ContentBlock(_) => None,
276 ToolCallContent::Diff(_) => None,
277 })
278 }
279
280 fn to_markdown(&self, cx: &App) -> String {
281 let mut markdown = format!(
282 "**Tool Call: {}**\nStatus: {}\n\n",
283 self.label.read(cx).source(),
284 self.status
285 );
286 for content in &self.content {
287 markdown.push_str(content.to_markdown(cx).as_str());
288 markdown.push_str("\n\n");
289 }
290 markdown
291 }
292
293 async fn resolve_location(
294 location: acp::ToolCallLocation,
295 project: WeakEntity<Project>,
296 cx: &mut AsyncApp,
297 ) -> Option<AgentLocation> {
298 let buffer = project
299 .update(cx, |project, cx| {
300 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
301 Some(project.open_buffer(path, cx))
302 } else {
303 None
304 }
305 })
306 .ok()??;
307 let buffer = buffer.await.log_err()?;
308 let position = buffer
309 .update(cx, |buffer, _| {
310 if let Some(row) = location.line {
311 let snapshot = buffer.snapshot();
312 let column = snapshot.indent_size_for_line(row).len;
313 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
314 snapshot.anchor_before(point)
315 } else {
316 Anchor::MIN
317 }
318 })
319 .ok()?;
320
321 Some(AgentLocation {
322 buffer: buffer.downgrade(),
323 position,
324 })
325 }
326
327 fn resolve_locations(
328 &self,
329 project: Entity<Project>,
330 cx: &mut App,
331 ) -> Task<Vec<Option<AgentLocation>>> {
332 let locations = self.locations.clone();
333 project.update(cx, |_, cx| {
334 cx.spawn(async move |project, cx| {
335 let mut new_locations = Vec::new();
336 for location in locations {
337 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
338 }
339 new_locations
340 })
341 })
342 }
343}
344
345#[derive(Debug)]
346pub enum ToolCallStatus {
347 /// The tool call hasn't started running yet, but we start showing it to
348 /// the user.
349 Pending,
350 /// The tool call is waiting for confirmation from the user.
351 WaitingForConfirmation {
352 options: Vec<acp::PermissionOption>,
353 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
354 },
355 /// The tool call is currently running.
356 InProgress,
357 /// The tool call completed successfully.
358 Completed,
359 /// The tool call failed.
360 Failed,
361 /// The user rejected the tool call.
362 Rejected,
363 /// The user canceled generation so the tool call was canceled.
364 Canceled,
365}
366
367impl From<acp::ToolCallStatus> for ToolCallStatus {
368 fn from(status: acp::ToolCallStatus) -> Self {
369 match status {
370 acp::ToolCallStatus::Pending => Self::Pending,
371 acp::ToolCallStatus::InProgress => Self::InProgress,
372 acp::ToolCallStatus::Completed => Self::Completed,
373 acp::ToolCallStatus::Failed => Self::Failed,
374 }
375 }
376}
377
378impl Display for ToolCallStatus {
379 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
380 write!(
381 f,
382 "{}",
383 match self {
384 ToolCallStatus::Pending => "Pending",
385 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
386 ToolCallStatus::InProgress => "In Progress",
387 ToolCallStatus::Completed => "Completed",
388 ToolCallStatus::Failed => "Failed",
389 ToolCallStatus::Rejected => "Rejected",
390 ToolCallStatus::Canceled => "Canceled",
391 }
392 )
393 }
394}
395
396#[derive(Debug, PartialEq, Clone)]
397pub enum ContentBlock {
398 Empty,
399 Markdown { markdown: Entity<Markdown> },
400 ResourceLink { resource_link: acp::ResourceLink },
401}
402
403impl ContentBlock {
404 pub fn new(
405 block: acp::ContentBlock,
406 language_registry: &Arc<LanguageRegistry>,
407 cx: &mut App,
408 ) -> Self {
409 let mut this = Self::Empty;
410 this.append(block, language_registry, cx);
411 this
412 }
413
414 pub fn new_combined(
415 blocks: impl IntoIterator<Item = acp::ContentBlock>,
416 language_registry: Arc<LanguageRegistry>,
417 cx: &mut App,
418 ) -> Self {
419 let mut this = Self::Empty;
420 for block in blocks {
421 this.append(block, &language_registry, cx);
422 }
423 this
424 }
425
426 pub fn append(
427 &mut self,
428 block: acp::ContentBlock,
429 language_registry: &Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) {
432 if matches!(self, ContentBlock::Empty)
433 && let acp::ContentBlock::ResourceLink(resource_link) = block
434 {
435 *self = ContentBlock::ResourceLink { resource_link };
436 return;
437 }
438
439 let new_content = self.block_string_contents(block);
440
441 match self {
442 ContentBlock::Empty => {
443 *self = Self::create_markdown_block(new_content, language_registry, cx);
444 }
445 ContentBlock::Markdown { markdown } => {
446 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
447 }
448 ContentBlock::ResourceLink { resource_link } => {
449 let existing_content = Self::resource_link_md(&resource_link.uri);
450 let combined = format!("{}\n{}", existing_content, new_content);
451
452 *self = Self::create_markdown_block(combined, language_registry, cx);
453 }
454 }
455 }
456
457 fn create_markdown_block(
458 content: String,
459 language_registry: &Arc<LanguageRegistry>,
460 cx: &mut App,
461 ) -> ContentBlock {
462 ContentBlock::Markdown {
463 markdown: cx
464 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
465 }
466 }
467
468 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
469 match block {
470 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
471 acp::ContentBlock::ResourceLink(resource_link) => {
472 Self::resource_link_md(&resource_link.uri)
473 }
474 acp::ContentBlock::Resource(acp::EmbeddedResource {
475 resource:
476 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
477 uri,
478 ..
479 }),
480 ..
481 }) => Self::resource_link_md(&uri),
482 acp::ContentBlock::Image(image) => Self::image_md(&image),
483 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
484 }
485 }
486
487 fn resource_link_md(uri: &str) -> String {
488 if let Some(uri) = MentionUri::parse(uri).log_err() {
489 uri.as_link().to_string()
490 } else {
491 uri.to_string()
492 }
493 }
494
495 fn image_md(_image: &acp::ImageContent) -> String {
496 "`Image`".into()
497 }
498
499 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
500 match self {
501 ContentBlock::Empty => "",
502 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
503 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
504 }
505 }
506
507 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
508 match self {
509 ContentBlock::Empty => None,
510 ContentBlock::Markdown { markdown } => Some(markdown),
511 ContentBlock::ResourceLink { .. } => None,
512 }
513 }
514
515 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
516 match self {
517 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
518 _ => None,
519 }
520 }
521}
522
523#[derive(Debug)]
524pub enum ToolCallContent {
525 ContentBlock(ContentBlock),
526 Diff(Entity<Diff>),
527 Terminal(Entity<Terminal>),
528}
529
530impl ToolCallContent {
531 pub fn from_acp(
532 content: acp::ToolCallContent,
533 language_registry: Arc<LanguageRegistry>,
534 cx: &mut App,
535 ) -> Self {
536 match content {
537 acp::ToolCallContent::Content { content } => {
538 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
539 }
540 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
541 Diff::finalized(
542 diff.path,
543 diff.old_text,
544 diff.new_text,
545 language_registry,
546 cx,
547 )
548 })),
549 }
550 }
551
552 pub fn to_markdown(&self, cx: &App) -> String {
553 match self {
554 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
555 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
556 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
557 }
558 }
559}
560
561#[derive(Debug, PartialEq)]
562pub enum ToolCallUpdate {
563 UpdateFields(acp::ToolCallUpdate),
564 UpdateDiff(ToolCallUpdateDiff),
565 UpdateTerminal(ToolCallUpdateTerminal),
566}
567
568impl ToolCallUpdate {
569 fn id(&self) -> &acp::ToolCallId {
570 match self {
571 Self::UpdateFields(update) => &update.id,
572 Self::UpdateDiff(diff) => &diff.id,
573 Self::UpdateTerminal(terminal) => &terminal.id,
574 }
575 }
576}
577
578impl From<acp::ToolCallUpdate> for ToolCallUpdate {
579 fn from(update: acp::ToolCallUpdate) -> Self {
580 Self::UpdateFields(update)
581 }
582}
583
584impl From<ToolCallUpdateDiff> for ToolCallUpdate {
585 fn from(diff: ToolCallUpdateDiff) -> Self {
586 Self::UpdateDiff(diff)
587 }
588}
589
590#[derive(Debug, PartialEq)]
591pub struct ToolCallUpdateDiff {
592 pub id: acp::ToolCallId,
593 pub diff: Entity<Diff>,
594}
595
596impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
597 fn from(terminal: ToolCallUpdateTerminal) -> Self {
598 Self::UpdateTerminal(terminal)
599 }
600}
601
602#[derive(Debug, PartialEq)]
603pub struct ToolCallUpdateTerminal {
604 pub id: acp::ToolCallId,
605 pub terminal: Entity<Terminal>,
606}
607
608#[derive(Debug, Default)]
609pub struct Plan {
610 pub entries: Vec<PlanEntry>,
611}
612
613#[derive(Debug)]
614pub struct PlanStats<'a> {
615 pub in_progress_entry: Option<&'a PlanEntry>,
616 pub pending: u32,
617 pub completed: u32,
618}
619
620impl Plan {
621 pub fn is_empty(&self) -> bool {
622 self.entries.is_empty()
623 }
624
625 pub fn stats(&self) -> PlanStats<'_> {
626 let mut stats = PlanStats {
627 in_progress_entry: None,
628 pending: 0,
629 completed: 0,
630 };
631
632 for entry in &self.entries {
633 match &entry.status {
634 acp::PlanEntryStatus::Pending => {
635 stats.pending += 1;
636 }
637 acp::PlanEntryStatus::InProgress => {
638 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
639 }
640 acp::PlanEntryStatus::Completed => {
641 stats.completed += 1;
642 }
643 }
644 }
645
646 stats
647 }
648}
649
650#[derive(Debug)]
651pub struct PlanEntry {
652 pub content: Entity<Markdown>,
653 pub priority: acp::PlanEntryPriority,
654 pub status: acp::PlanEntryStatus,
655}
656
657impl PlanEntry {
658 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
659 Self {
660 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
661 priority: entry.priority,
662 status: entry.status,
663 }
664 }
665}
666
/// Retry progress, forwarded to listeners via `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Text of the error that triggered the retry.
    pub last_error: SharedString,
    // NOTE(review): presumably the current attempt number; whether it is 0- or 1-based
    // is not visible here — confirm with the producer.
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm.
    pub duration: Duration,
}

/// A conversation with an agent over an [`AgentConnection`] for one session.
pub struct AcpThread {
    title: SharedString,
    /// Thread entries in display order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably tracks buffer
    // snapshots shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while a turn is running; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}

/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended (emitted by `push_entry`).
    NewEntry,
    /// The title changed (emitted by `update_title`).
    TitleUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    // Presumably the indices of removed entries; not emitted in the visible code.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A retry is in progress (emitted by `update_retry_status`).
    Retry(RetryStatus),
    Stopped,
    Error,
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Coarse state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    Idle,
    WaitingForToolConfirmation,
    Generating,
}
709
710#[derive(Debug, Clone)]
711pub enum LoadError {
712 Unsupported {
713 error_message: SharedString,
714 upgrade_message: SharedString,
715 upgrade_command: String,
716 },
717 Exited(i32),
718 Other(SharedString),
719}
720
721impl Display for LoadError {
722 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
723 match self {
724 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
725 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
726 LoadError::Other(msg) => write!(f, "{}", msg),
727 }
728 }
729}
730
731impl Error for LoadError {}
732
733impl AcpThread {
734 pub fn new(
735 title: impl Into<SharedString>,
736 connection: Rc<dyn AgentConnection>,
737 project: Entity<Project>,
738 action_log: Entity<ActionLog>,
739 session_id: acp::SessionId,
740 ) -> Self {
741 Self {
742 action_log,
743 shared_buffers: Default::default(),
744 entries: Default::default(),
745 plan: Default::default(),
746 title: title.into(),
747 project,
748 send_task: None,
749 connection,
750 session_id,
751 }
752 }
753
754 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
755 &self.connection
756 }
757
758 pub fn action_log(&self) -> &Entity<ActionLog> {
759 &self.action_log
760 }
761
762 pub fn project(&self) -> &Entity<Project> {
763 &self.project
764 }
765
766 pub fn title(&self) -> SharedString {
767 self.title.clone()
768 }
769
770 pub fn entries(&self) -> &[AgentThreadEntry] {
771 &self.entries
772 }
773
774 pub fn session_id(&self) -> &acp::SessionId {
775 &self.session_id
776 }
777
778 pub fn status(&self) -> ThreadStatus {
779 if self.send_task.is_some() {
780 if self.waiting_for_tool_confirmation() {
781 ThreadStatus::WaitingForToolConfirmation
782 } else {
783 ThreadStatus::Generating
784 }
785 } else {
786 ThreadStatus::Idle
787 }
788 }
789
790 pub fn has_pending_edit_tool_calls(&self) -> bool {
791 for entry in self.entries.iter().rev() {
792 match entry {
793 AgentThreadEntry::UserMessage(_) => return false,
794 AgentThreadEntry::ToolCall(
795 call @ ToolCall {
796 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
797 ..
798 },
799 ) if call.diffs().next().is_some() => {
800 return true;
801 }
802 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
803 }
804 }
805
806 false
807 }
808
809 pub fn used_tools_since_last_user_message(&self) -> bool {
810 for entry in self.entries.iter().rev() {
811 match entry {
812 AgentThreadEntry::UserMessage(..) => return false,
813 AgentThreadEntry::AssistantMessage(..) => continue,
814 AgentThreadEntry::ToolCall(..) => return true,
815 }
816 }
817
818 false
819 }
820
821 pub fn handle_session_update(
822 &mut self,
823 update: acp::SessionUpdate,
824 cx: &mut Context<Self>,
825 ) -> Result<(), acp::Error> {
826 match update {
827 acp::SessionUpdate::UserMessageChunk { content } => {
828 self.push_user_content_block(None, content, cx);
829 }
830 acp::SessionUpdate::AgentMessageChunk { content } => {
831 self.push_assistant_content_block(content, false, cx);
832 }
833 acp::SessionUpdate::AgentThoughtChunk { content } => {
834 self.push_assistant_content_block(content, true, cx);
835 }
836 acp::SessionUpdate::ToolCall(tool_call) => {
837 self.upsert_tool_call(tool_call, cx)?;
838 }
839 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
840 self.update_tool_call(tool_call_update, cx)?;
841 }
842 acp::SessionUpdate::Plan(plan) => {
843 self.update_plan(plan, cx);
844 }
845 }
846 Ok(())
847 }
848
849 pub fn push_user_content_block(
850 &mut self,
851 message_id: Option<UserMessageId>,
852 chunk: acp::ContentBlock,
853 cx: &mut Context<Self>,
854 ) {
855 let language_registry = self.project.read(cx).languages().clone();
856 let entries_len = self.entries.len();
857
858 if let Some(last_entry) = self.entries.last_mut()
859 && let AgentThreadEntry::UserMessage(UserMessage {
860 id,
861 content,
862 chunks,
863 ..
864 }) = last_entry
865 {
866 *id = message_id.or(id.take());
867 content.append(chunk.clone(), &language_registry, cx);
868 chunks.push(chunk);
869 let idx = entries_len - 1;
870 cx.emit(AcpThreadEvent::EntryUpdated(idx));
871 } else {
872 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
873 self.push_entry(
874 AgentThreadEntry::UserMessage(UserMessage {
875 id: message_id,
876 content,
877 chunks: vec![chunk],
878 checkpoint: None,
879 }),
880 cx,
881 );
882 }
883 }
884
885 pub fn push_assistant_content_block(
886 &mut self,
887 chunk: acp::ContentBlock,
888 is_thought: bool,
889 cx: &mut Context<Self>,
890 ) {
891 let language_registry = self.project.read(cx).languages().clone();
892 let entries_len = self.entries.len();
893 if let Some(last_entry) = self.entries.last_mut()
894 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
895 {
896 let idx = entries_len - 1;
897 cx.emit(AcpThreadEvent::EntryUpdated(idx));
898 match (chunks.last_mut(), is_thought) {
899 (Some(AssistantMessageChunk::Message { block }), false)
900 | (Some(AssistantMessageChunk::Thought { block }), true) => {
901 block.append(chunk, &language_registry, cx)
902 }
903 _ => {
904 let block = ContentBlock::new(chunk, &language_registry, cx);
905 if is_thought {
906 chunks.push(AssistantMessageChunk::Thought { block })
907 } else {
908 chunks.push(AssistantMessageChunk::Message { block })
909 }
910 }
911 }
912 } else {
913 let block = ContentBlock::new(chunk, &language_registry, cx);
914 let chunk = if is_thought {
915 AssistantMessageChunk::Thought { block }
916 } else {
917 AssistantMessageChunk::Message { block }
918 };
919
920 self.push_entry(
921 AgentThreadEntry::AssistantMessage(AssistantMessage {
922 chunks: vec![chunk],
923 }),
924 cx,
925 );
926 }
927 }
928
929 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
930 self.entries.push(entry);
931 cx.emit(AcpThreadEvent::NewEntry);
932 }
933
934 pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
935 self.title = title;
936 cx.emit(AcpThreadEvent::TitleUpdated);
937 Ok(())
938 }
939
940 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
941 cx.emit(AcpThreadEvent::Retry(status));
942 }
943
944 pub fn update_tool_call(
945 &mut self,
946 update: impl Into<ToolCallUpdate>,
947 cx: &mut Context<Self>,
948 ) -> Result<()> {
949 let update = update.into();
950 let languages = self.project.read(cx).languages().clone();
951
952 let (ix, current_call) = self
953 .tool_call_mut(update.id())
954 .context("Tool call not found")?;
955 match update {
956 ToolCallUpdate::UpdateFields(update) => {
957 let location_updated = update.fields.locations.is_some();
958 current_call.update_fields(update.fields, languages, cx);
959 if location_updated {
960 self.resolve_locations(update.id.clone(), cx);
961 }
962 }
963 ToolCallUpdate::UpdateDiff(update) => {
964 current_call.content.clear();
965 current_call
966 .content
967 .push(ToolCallContent::Diff(update.diff));
968 }
969 ToolCallUpdate::UpdateTerminal(update) => {
970 current_call.content.clear();
971 current_call
972 .content
973 .push(ToolCallContent::Terminal(update.terminal));
974 }
975 }
976
977 cx.emit(AcpThreadEvent::EntryUpdated(ix));
978
979 Ok(())
980 }
981
982 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
983 pub fn upsert_tool_call(
984 &mut self,
985 tool_call: acp::ToolCall,
986 cx: &mut Context<Self>,
987 ) -> Result<(), acp::Error> {
988 let status = tool_call.status.into();
989 self.upsert_tool_call_inner(tool_call.into(), status, cx)
990 }
991
992 /// Fails if id does not match an existing entry.
993 pub fn upsert_tool_call_inner(
994 &mut self,
995 tool_call_update: acp::ToolCallUpdate,
996 status: ToolCallStatus,
997 cx: &mut Context<Self>,
998 ) -> Result<(), acp::Error> {
999 let language_registry = self.project.read(cx).languages().clone();
1000 let id = tool_call_update.id.clone();
1001
1002 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1003 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1004 current_call.status = status;
1005
1006 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1007 } else {
1008 let call =
1009 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1010 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1011 };
1012
1013 self.resolve_locations(id, cx);
1014 Ok(())
1015 }
1016
1017 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1018 // The tool call we are looking for is typically the last one, or very close to the end.
1019 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1020 self.entries
1021 .iter_mut()
1022 .enumerate()
1023 .rev()
1024 .find_map(|(index, tool_call)| {
1025 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1026 && &tool_call.id == id
1027 {
1028 Some((index, tool_call))
1029 } else {
1030 None
1031 }
1032 })
1033 }
1034
1035 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1036 let project = self.project.clone();
1037 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1038 return;
1039 };
1040 let task = tool_call.resolve_locations(project, cx);
1041 cx.spawn(async move |this, cx| {
1042 let resolved_locations = task.await;
1043 this.update(cx, |this, cx| {
1044 let project = this.project.clone();
1045 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1046 return;
1047 };
1048 if let Some(Some(location)) = resolved_locations.last() {
1049 project.update(cx, |project, cx| {
1050 if let Some(agent_location) = project.agent_location() {
1051 let should_ignore = agent_location.buffer == location.buffer
1052 && location
1053 .buffer
1054 .update(cx, |buffer, _| {
1055 let snapshot = buffer.snapshot();
1056 let old_position =
1057 agent_location.position.to_point(&snapshot);
1058 let new_position = location.position.to_point(&snapshot);
1059 // ignore this so that when we get updates from the edit tool
1060 // the position doesn't reset to the startof line
1061 old_position.row == new_position.row
1062 && old_position.column > new_position.column
1063 })
1064 .ok()
1065 .unwrap_or_default();
1066 if !should_ignore {
1067 project.set_agent_location(Some(location.clone()), cx);
1068 }
1069 }
1070 });
1071 }
1072 if tool_call.resolved_locations != resolved_locations {
1073 tool_call.resolved_locations = resolved_locations;
1074 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1075 }
1076 })
1077 })
1078 .detach();
1079 }
1080
1081 pub fn request_tool_call_authorization(
1082 &mut self,
1083 tool_call: acp::ToolCallUpdate,
1084 options: Vec<acp::PermissionOption>,
1085 cx: &mut Context<Self>,
1086 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1087 let (tx, rx) = oneshot::channel();
1088
1089 let status = ToolCallStatus::WaitingForConfirmation {
1090 options,
1091 respond_tx: tx,
1092 };
1093
1094 self.upsert_tool_call_inner(tool_call, status, cx)?;
1095 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1096 Ok(rx)
1097 }
1098
1099 pub fn authorize_tool_call(
1100 &mut self,
1101 id: acp::ToolCallId,
1102 option_id: acp::PermissionOptionId,
1103 option_kind: acp::PermissionOptionKind,
1104 cx: &mut Context<Self>,
1105 ) {
1106 let Some((ix, call)) = self.tool_call_mut(&id) else {
1107 return;
1108 };
1109
1110 let new_status = match option_kind {
1111 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1112 ToolCallStatus::Rejected
1113 }
1114 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1115 ToolCallStatus::InProgress
1116 }
1117 };
1118
1119 let curr_status = mem::replace(&mut call.status, new_status);
1120
1121 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1122 respond_tx.send(option_id).log_err();
1123 } else if cfg!(debug_assertions) {
1124 panic!("tried to authorize an already authorized tool call");
1125 }
1126
1127 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1128 }
1129
1130 /// Returns true if the last turn is awaiting tool authorization
1131 pub fn waiting_for_tool_confirmation(&self) -> bool {
1132 for entry in self.entries.iter().rev() {
1133 match &entry {
1134 AgentThreadEntry::ToolCall(call) => match call.status {
1135 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1136 ToolCallStatus::Pending
1137 | ToolCallStatus::InProgress
1138 | ToolCallStatus::Completed
1139 | ToolCallStatus::Failed
1140 | ToolCallStatus::Rejected
1141 | ToolCallStatus::Canceled => continue,
1142 },
1143 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1144 // Reached the beginning of the turn
1145 return false;
1146 }
1147 }
1148 }
1149 false
1150 }
1151
1152 pub fn plan(&self) -> &Plan {
1153 &self.plan
1154 }
1155
1156 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1157 let new_entries_len = request.entries.len();
1158 let mut new_entries = request.entries.into_iter();
1159
1160 // Reuse existing markdown to prevent flickering
1161 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1162 let PlanEntry {
1163 content,
1164 priority,
1165 status,
1166 } = old;
1167 content.update(cx, |old, cx| {
1168 old.replace(new.content, cx);
1169 });
1170 *priority = new.priority;
1171 *status = new.status;
1172 }
1173 for new in new_entries {
1174 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1175 }
1176 self.plan.entries.truncate(new_entries_len);
1177
1178 cx.notify();
1179 }
1180
1181 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1182 self.plan
1183 .entries
1184 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1185 cx.notify();
1186 }
1187
1188 #[cfg(any(test, feature = "test-support"))]
1189 pub fn send_raw(
1190 &mut self,
1191 message: &str,
1192 cx: &mut Context<Self>,
1193 ) -> BoxFuture<'static, Result<()>> {
1194 self.send(
1195 vec![acp::ContentBlock::Text(acp::TextContent {
1196 text: message.to_string(),
1197 annotations: None,
1198 })],
1199 cx,
1200 )
1201 }
1202
1203 pub fn send(
1204 &mut self,
1205 message: Vec<acp::ContentBlock>,
1206 cx: &mut Context<Self>,
1207 ) -> BoxFuture<'static, Result<()>> {
1208 let block = ContentBlock::new_combined(
1209 message.clone(),
1210 self.project.read(cx).languages().clone(),
1211 cx,
1212 );
1213 let request = acp::PromptRequest {
1214 prompt: message.clone(),
1215 session_id: self.session_id.clone(),
1216 };
1217 let git_store = self.project.read(cx).git_store().clone();
1218
1219 let message_id = if self
1220 .connection
1221 .session_editor(&self.session_id, cx)
1222 .is_some()
1223 {
1224 Some(UserMessageId::new())
1225 } else {
1226 None
1227 };
1228
1229 self.run_turn(cx, async move |this, cx| {
1230 this.update(cx, |this, cx| {
1231 this.push_entry(
1232 AgentThreadEntry::UserMessage(UserMessage {
1233 id: message_id.clone(),
1234 content: block,
1235 chunks: message,
1236 checkpoint: None,
1237 }),
1238 cx,
1239 );
1240 })
1241 .ok();
1242
1243 let old_checkpoint = git_store
1244 .update(cx, |git, cx| git.checkpoint(cx))?
1245 .await
1246 .context("failed to get old checkpoint")
1247 .log_err();
1248 this.update(cx, |this, cx| {
1249 if let Some((_ix, message)) = this.last_user_message() {
1250 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1251 git_checkpoint,
1252 show: false,
1253 });
1254 }
1255 this.connection.prompt(message_id, request, cx)
1256 })?
1257 .await
1258 })
1259 }
1260
1261 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1262 self.run_turn(cx, async move |this, cx| {
1263 this.update(cx, |this, cx| {
1264 this.connection
1265 .resume(&this.session_id, cx)
1266 .map(|resume| resume.run(cx))
1267 })?
1268 .context("resuming a session is not supported")?
1269 .await
1270 })
1271 }
1272
    /// Starts a new turn whose body is `f`, returning a future that resolves
    /// once the turn has been fully processed.
    ///
    /// Before `f` runs, completed plan entries are cleared and any in-flight
    /// turn is canceled and awaited. Afterwards the last user message's
    /// checkpoint visibility is refreshed, the agent location is cleared, and
    /// either `AcpThreadEvent::Error` is emitted (and the error returned) or
    /// `AcpThreadEvent::Stopped` is emitted.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn itself lives in `send_task` (so `cancel` can take it); its
        // result is relayed through `tx` to the future returned below.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` means `send_task` was dropped before `f` finished;
            // that case falls through to the second match arm below.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1328
1329 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1330 let Some(send_task) = self.send_task.take() else {
1331 return Task::ready(());
1332 };
1333
1334 for entry in self.entries.iter_mut() {
1335 if let AgentThreadEntry::ToolCall(call) = entry {
1336 let cancel = matches!(
1337 call.status,
1338 ToolCallStatus::Pending
1339 | ToolCallStatus::WaitingForConfirmation { .. }
1340 | ToolCallStatus::InProgress
1341 );
1342
1343 if cancel {
1344 call.status = ToolCallStatus::Canceled;
1345 }
1346 }
1347 }
1348
1349 self.connection.cancel(&self.session_id, cx);
1350
1351 // Wait for the send task to complete
1352 cx.foreground_executor().spawn(send_task)
1353 }
1354
    /// Rewinds this thread to just before the user message identified by `id`,
    /// removing that message and all subsequent entries while reverting any
    /// changes made from that point.
    ///
    /// If the message has a git checkpoint, it is restored first. Then the
    /// connection's session editor truncates the session at `id`. Finally the
    /// matching entries are dropped locally and `AcpThreadEvent::EntriesRemoved`
    /// is emitted.
    ///
    /// Fails immediately with "not supported" when the connection exposes no
    /// session editor, or with "message not found" when no user message has `id`.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        // Capture the checkpoint up front: `message` borrows `self` and cannot be
        // carried into the spawned task.
        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| session_editor.truncate(id.clone(), cx))?
                .await?;
            // Re-resolve the entry index by id, because `entries` may have
            // changed while the awaits above were pending.
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1389
1390 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1391 let git_store = self.project.read(cx).git_store().clone();
1392
1393 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1394 if let Some(checkpoint) = message.checkpoint.as_ref() {
1395 checkpoint.git_checkpoint.clone()
1396 } else {
1397 return Task::ready(Ok(()));
1398 }
1399 } else {
1400 return Task::ready(Ok(()));
1401 };
1402
1403 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1404 cx.spawn(async move |this, cx| {
1405 let new_checkpoint = new_checkpoint
1406 .await
1407 .context("failed to get new checkpoint")
1408 .log_err();
1409 if let Some(new_checkpoint) = new_checkpoint {
1410 let equal = git_store
1411 .update(cx, |git, cx| {
1412 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1413 })?
1414 .await
1415 .unwrap_or(true);
1416 this.update(cx, |this, cx| {
1417 let (ix, message) = this.last_user_message().context("no user message")?;
1418 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1419 checkpoint.show = !equal;
1420 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1421 anyhow::Ok(())
1422 })??;
1423 }
1424
1425 Ok(())
1426 })
1427 }
1428
1429 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1430 self.entries
1431 .iter_mut()
1432 .enumerate()
1433 .rev()
1434 .find_map(|(ix, entry)| {
1435 if let AgentThreadEntry::UserMessage(message) = entry {
1436 Some((ix, message))
1437 } else {
1438 None
1439 }
1440 })
1441 }
1442
1443 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1444 self.entries.iter().find_map(|entry| {
1445 if let AgentThreadEntry::UserMessage(message) = entry {
1446 if message.id.as_ref() == Some(id) {
1447 Some(message)
1448 } else {
1449 None
1450 }
1451 } else {
1452 None
1453 }
1454 })
1455 }
1456
1457 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1458 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1459 if let AgentThreadEntry::UserMessage(message) = entry {
1460 if message.id.as_ref() == Some(id) {
1461 Some((ix, message))
1462 } else {
1463 None
1464 }
1465 } else {
1466 None
1467 }
1468 })
1469 }
1470
1471 pub fn read_text_file(
1472 &self,
1473 path: PathBuf,
1474 line: Option<u32>,
1475 limit: Option<u32>,
1476 reuse_shared_snapshot: bool,
1477 cx: &mut Context<Self>,
1478 ) -> Task<Result<String>> {
1479 let project = self.project.clone();
1480 let action_log = self.action_log.clone();
1481 cx.spawn(async move |this, cx| {
1482 let load = project.update(cx, |project, cx| {
1483 let path = project
1484 .project_path_for_absolute_path(&path, cx)
1485 .context("invalid path")?;
1486 anyhow::Ok(project.open_buffer(path, cx))
1487 });
1488 let buffer = load??.await?;
1489
1490 let snapshot = if reuse_shared_snapshot {
1491 this.read_with(cx, |this, _| {
1492 this.shared_buffers.get(&buffer.clone()).cloned()
1493 })
1494 .log_err()
1495 .flatten()
1496 } else {
1497 None
1498 };
1499
1500 let snapshot = if let Some(snapshot) = snapshot {
1501 snapshot
1502 } else {
1503 action_log.update(cx, |action_log, cx| {
1504 action_log.buffer_read(buffer.clone(), cx);
1505 })?;
1506 project.update(cx, |project, cx| {
1507 let position = buffer
1508 .read(cx)
1509 .snapshot()
1510 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1511 project.set_agent_location(
1512 Some(AgentLocation {
1513 buffer: buffer.downgrade(),
1514 position,
1515 }),
1516 cx,
1517 );
1518 })?;
1519
1520 buffer.update(cx, |buffer, _| buffer.snapshot())?
1521 };
1522
1523 this.update(cx, |this, _| {
1524 let text = snapshot.text();
1525 this.shared_buffers.insert(buffer.clone(), snapshot);
1526 if line.is_none() && limit.is_none() {
1527 return Ok(text);
1528 }
1529 let limit = limit.unwrap_or(u32::MAX) as usize;
1530 let Some(line) = line else {
1531 return Ok(text.lines().take(limit).collect::<String>());
1532 };
1533
1534 let count = text.lines().count();
1535 if count < line as usize {
1536 anyhow::bail!("There are only {} lines", count);
1537 }
1538 Ok(text
1539 .lines()
1540 .skip(line as usize + 1)
1541 .take(limit)
1542 .collect::<String>())
1543 })?
1544 })
1545 }
1546
    /// Overwrites the file at the absolute `path` with `content` and saves it.
    ///
    /// On the background executor, `content` is diffed against the snapshot
    /// last shared for this buffer (or the buffer's current snapshot if none was
    /// shared). The resulting hunks are applied as anchored edits. Because the
    /// anchors come from that snapshot, edits the user made to the live buffer
    /// in the meantime are kept rather than overwritten.
    ///
    /// The agent location moves to the end of the last edit, or to the start of
    /// the buffer when there is nothing to change. The action log records the
    /// buffer as read before the edit and as edited after it.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last saw, so the diff is relative to it.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1614
1615 pub fn to_markdown(&self, cx: &App) -> String {
1616 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1617 }
1618
    /// Emits `AcpThreadEvent::ServerExited` carrying the given exit `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1622}
1623
1624fn markdown_for_raw_output(
1625 raw_output: &serde_json::Value,
1626 language_registry: &Arc<LanguageRegistry>,
1627 cx: &mut App,
1628) -> Option<Entity<Markdown>> {
1629 match raw_output {
1630 serde_json::Value::Null => None,
1631 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1632 Markdown::new(
1633 value.to_string().into(),
1634 Some(language_registry.clone()),
1635 None,
1636 cx,
1637 )
1638 })),
1639 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1640 Markdown::new(
1641 value.to_string().into(),
1642 Some(language_registry.clone()),
1643 None,
1644 cx,
1645 )
1646 })),
1647 serde_json::Value::String(value) => Some(cx.new(|cx| {
1648 Markdown::new(
1649 value.clone().into(),
1650 Some(language_registry.clone()),
1651 None,
1652 cx,
1653 )
1654 })),
1655 value => Some(cx.new(|cx| {
1656 Markdown::new(
1657 format!("```json\n{}\n```", value).into(),
1658 Some(language_registry.clone()),
1659 None,
1660 cx,
1661 )
1662 })),
1663 }
1664}
1665
/// Unit tests for `AcpThread`, driven by an in-process fake agent connection.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::{FakeFs, Fs};
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{
        any::Any,
        cell::RefCell,
        path::Path,
        rc::Rc,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
        time::Duration,
    };
    use util::path;

    /// Installs the global settings store and initializes project settings and
    /// languages for a test app.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }

    /// Waits (up to 10 seconds) until `thread` emits an event while it contains
    /// a tool-call entry, and returns the index of the first such entry.
    ///
    /// Panics on timeout.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // Every later event re-sends the same index. Once the
                        // bounded channel is full those sends are redundant, so
                        // ignore the error instead of panicking inside the
                        // subscription.
                        tx.try_send(ix).ok();
                        return;
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-process `AgentConnection` for tests. Prompts are answered by an
    /// optional `on_user_message` handler, and all sessions get a no-op session
    /// editor.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        // Threads created by this connection, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Sets the handler that answers every prompt sent to this connection.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|_cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                )
            });
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }

        fn session_editor(
            &self,
            session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(FakeAgentSessionEditor {
                _session_id: session_id.clone(),
            }))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}