1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
29use ui::App;
30use util::ResultExt;
31
/// A message authored by the user, as it appears in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Optional message id. `AcpThread::send` only assigns one when the
    /// connection exposes a session editor for the session.
    pub id: Option<UserMessageId>,
    /// Rendered content accumulated from `chunks`.
    pub content: ContentBlock,
    /// The raw ACP content blocks that make up the message, in arrival order.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint captured before the prompt was sent, if one was taken.
    pub checkpoint: Option<Checkpoint>,
}
39
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    // Snapshot of the project's git store, taken in `AcpThread::send` before
    // the prompt is forwarded to the connection.
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint is surfaced; `UserMessage::to_markdown` adds a
    /// "(checkpoint)" suffix to the heading when this is set.
    pub show: bool,
}
45
46impl UserMessage {
47 fn to_markdown(&self, cx: &App) -> String {
48 let mut markdown = String::new();
49 if self
50 .checkpoint
51 .as_ref()
52 .map_or(false, |checkpoint| checkpoint.show)
53 {
54 writeln!(markdown, "## User (checkpoint)").unwrap();
55 } else {
56 writeln!(markdown, "## User").unwrap();
57 }
58 writeln!(markdown).unwrap();
59 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
60 writeln!(markdown).unwrap();
61 markdown
62 }
63}
64
/// A message produced by the agent, made up of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; consecutive content of the same kind is merged
    /// into a single chunk by `AcpThread::push_assistant_content_block`.
    pub chunks: Vec<AssistantMessageChunk>,
}
69
70impl AssistantMessage {
71 pub fn to_markdown(&self, cx: &App) -> String {
72 format!(
73 "## Assistant\n\n{}\n\n",
74 self.chunks
75 .iter()
76 .map(|chunk| chunk.to_markdown(cx))
77 .join("\n\n")
78 )
79 }
80}
81
/// A contiguous run of assistant output of a single kind.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Agent "thinking" content; wrapped in `<thinking>` tags when exported
    /// to Markdown.
    Thought { block: ContentBlock },
}
87
88impl AssistantMessageChunk {
89 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
90 Self::Message {
91 block: ContentBlock::new(chunk.into(), language_registry, cx),
92 }
93 }
94
95 fn to_markdown(&self, cx: &App) -> String {
96 match self {
97 Self::Message { block } => block.to_markdown(cx).to_string(),
98 Self::Thought { block } => {
99 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
100 }
101 }
102 }
103}
104
/// One entry in a thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool call made by the agent.
    ToolCall(ToolCall),
}
111
112impl AgentThreadEntry {
113 pub fn to_markdown(&self, cx: &App) -> String {
114 match self {
115 Self::UserMessage(message) => message.to_markdown(cx),
116 Self::AssistantMessage(message) => message.to_markdown(cx),
117 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
118 }
119 }
120
121 pub fn user_message(&self) -> Option<&UserMessage> {
122 if let AgentThreadEntry::UserMessage(message) = self {
123 Some(message)
124 } else {
125 None
126 }
127 }
128
129 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
130 if let AgentThreadEntry::ToolCall(call) = self {
131 itertools::Either::Left(call.diffs())
132 } else {
133 itertools::Either::Right(std::iter::empty())
134 }
135 }
136
137 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.terminals())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
146 if let AgentThreadEntry::ToolCall(ToolCall {
147 locations,
148 resolved_locations,
149 ..
150 }) = self
151 {
152 Some((
153 locations.get(ix)?.clone(),
154 resolved_locations.get(ix)?.clone()?,
155 ))
156 } else {
157 None
158 }
159 }
160}
161
/// A tool call made by the agent, together with its rendered state.
#[derive(Debug)]
pub struct ToolCall {
    /// Identifier used to route updates to this call.
    pub id: acp::ToolCallId,
    /// The call's title, rendered as Markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content shown for the call (content blocks, diffs, or terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// `locations` resolved to buffer positions, filled asynchronously (one
    /// entry per location, in order); `None` where resolution failed.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input reported by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output reported by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
174
175impl ToolCall {
176 fn from_acp(
177 tool_call: acp::ToolCall,
178 status: ToolCallStatus,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) -> Self {
182 Self {
183 id: tool_call.id,
184 label: cx.new(|cx| {
185 Markdown::new(
186 tool_call.title.into(),
187 Some(language_registry.clone()),
188 None,
189 cx,
190 )
191 }),
192 kind: tool_call.kind,
193 content: tool_call
194 .content
195 .into_iter()
196 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
197 .collect(),
198 locations: tool_call.locations,
199 resolved_locations: Vec::default(),
200 status,
201 raw_input: tool_call.raw_input,
202 raw_output: tool_call.raw_output,
203 }
204 }
205
206 fn update_fields(
207 &mut self,
208 fields: acp::ToolCallUpdateFields,
209 language_registry: Arc<LanguageRegistry>,
210 cx: &mut App,
211 ) {
212 let acp::ToolCallUpdateFields {
213 kind,
214 status,
215 title,
216 content,
217 locations,
218 raw_input,
219 raw_output,
220 } = fields;
221
222 if let Some(kind) = kind {
223 self.kind = kind;
224 }
225
226 if let Some(status) = status {
227 self.status = status.into();
228 }
229
230 if let Some(title) = title {
231 self.label.update(cx, |label, cx| {
232 label.replace(title, cx);
233 });
234 }
235
236 if let Some(content) = content {
237 self.content = content
238 .into_iter()
239 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
240 .collect();
241 }
242
243 if let Some(locations) = locations {
244 self.locations = locations;
245 }
246
247 if let Some(raw_input) = raw_input {
248 self.raw_input = Some(raw_input);
249 }
250
251 if let Some(raw_output) = raw_output {
252 if self.content.is_empty()
253 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
254 {
255 self.content
256 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
257 markdown,
258 }));
259 }
260 self.raw_output = Some(raw_output);
261 }
262 }
263
264 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
265 self.content.iter().filter_map(|content| match content {
266 ToolCallContent::Diff(diff) => Some(diff),
267 ToolCallContent::ContentBlock(_) => None,
268 ToolCallContent::Terminal(_) => None,
269 })
270 }
271
272 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
273 self.content.iter().filter_map(|content| match content {
274 ToolCallContent::Terminal(terminal) => Some(terminal),
275 ToolCallContent::ContentBlock(_) => None,
276 ToolCallContent::Diff(_) => None,
277 })
278 }
279
280 fn to_markdown(&self, cx: &App) -> String {
281 let mut markdown = format!(
282 "**Tool Call: {}**\nStatus: {}\n\n",
283 self.label.read(cx).source(),
284 self.status
285 );
286 for content in &self.content {
287 markdown.push_str(content.to_markdown(cx).as_str());
288 markdown.push_str("\n\n");
289 }
290 markdown
291 }
292
293 async fn resolve_location(
294 location: acp::ToolCallLocation,
295 project: WeakEntity<Project>,
296 cx: &mut AsyncApp,
297 ) -> Option<AgentLocation> {
298 let buffer = project
299 .update(cx, |project, cx| {
300 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
301 Some(project.open_buffer(path, cx))
302 } else {
303 None
304 }
305 })
306 .ok()??;
307 let buffer = buffer.await.log_err()?;
308 let position = buffer
309 .update(cx, |buffer, _| {
310 if let Some(row) = location.line {
311 let snapshot = buffer.snapshot();
312 let column = snapshot.indent_size_for_line(row).len;
313 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
314 snapshot.anchor_before(point)
315 } else {
316 Anchor::MIN
317 }
318 })
319 .ok()?;
320
321 Some(AgentLocation {
322 buffer: buffer.downgrade(),
323 position,
324 })
325 }
326
327 fn resolve_locations(
328 &self,
329 project: Entity<Project>,
330 cx: &mut App,
331 ) -> Task<Vec<Option<AgentLocation>>> {
332 let locations = self.locations.clone();
333 project.update(cx, |_, cx| {
334 cx.spawn(async move |project, cx| {
335 let mut new_locations = Vec::new();
336 for location in locations {
337 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
338 }
339 new_locations
340 })
341 })
342 }
343}
344
/// Lifecycle state of a tool call as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Choices offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Receives the chosen option id once the user responds.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
366
367impl From<acp::ToolCallStatus> for ToolCallStatus {
368 fn from(status: acp::ToolCallStatus) -> Self {
369 match status {
370 acp::ToolCallStatus::Pending => Self::Pending,
371 acp::ToolCallStatus::InProgress => Self::InProgress,
372 acp::ToolCallStatus::Completed => Self::Completed,
373 acp::ToolCallStatus::Failed => Self::Failed,
374 }
375 }
376}
377
378impl Display for ToolCallStatus {
379 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
380 write!(
381 f,
382 "{}",
383 match self {
384 ToolCallStatus::Pending => "Pending",
385 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
386 ToolCallStatus::InProgress => "In Progress",
387 ToolCallStatus::Completed => "Completed",
388 ToolCallStatus::Failed => "Failed",
389 ToolCallStatus::Rejected => "Rejected",
390 ToolCallStatus::Canceled => "Canceled",
391 }
392 )
393 }
394}
395
/// Renderable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered as Markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link that has not been merged with other content.
    ResourceLink { resource_link: acp::ResourceLink },
}
402
403impl ContentBlock {
404 pub fn new(
405 block: acp::ContentBlock,
406 language_registry: &Arc<LanguageRegistry>,
407 cx: &mut App,
408 ) -> Self {
409 let mut this = Self::Empty;
410 this.append(block, language_registry, cx);
411 this
412 }
413
414 pub fn new_combined(
415 blocks: impl IntoIterator<Item = acp::ContentBlock>,
416 language_registry: Arc<LanguageRegistry>,
417 cx: &mut App,
418 ) -> Self {
419 let mut this = Self::Empty;
420 for block in blocks {
421 this.append(block, &language_registry, cx);
422 }
423 this
424 }
425
426 pub fn append(
427 &mut self,
428 block: acp::ContentBlock,
429 language_registry: &Arc<LanguageRegistry>,
430 cx: &mut App,
431 ) {
432 if matches!(self, ContentBlock::Empty)
433 && let acp::ContentBlock::ResourceLink(resource_link) = block
434 {
435 *self = ContentBlock::ResourceLink { resource_link };
436 return;
437 }
438
439 let new_content = self.block_string_contents(block);
440
441 match self {
442 ContentBlock::Empty => {
443 *self = Self::create_markdown_block(new_content, language_registry, cx);
444 }
445 ContentBlock::Markdown { markdown } => {
446 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
447 }
448 ContentBlock::ResourceLink { resource_link } => {
449 let existing_content = Self::resource_link_md(&resource_link.uri);
450 let combined = format!("{}\n{}", existing_content, new_content);
451
452 *self = Self::create_markdown_block(combined, language_registry, cx);
453 }
454 }
455 }
456
457 fn create_markdown_block(
458 content: String,
459 language_registry: &Arc<LanguageRegistry>,
460 cx: &mut App,
461 ) -> ContentBlock {
462 ContentBlock::Markdown {
463 markdown: cx
464 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
465 }
466 }
467
468 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
469 match block {
470 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
471 acp::ContentBlock::ResourceLink(resource_link) => {
472 Self::resource_link_md(&resource_link.uri)
473 }
474 acp::ContentBlock::Resource(acp::EmbeddedResource {
475 resource:
476 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
477 uri,
478 ..
479 }),
480 ..
481 }) => Self::resource_link_md(&uri),
482 acp::ContentBlock::Image(image) => Self::image_md(&image),
483 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
484 }
485 }
486
487 fn resource_link_md(uri: &str) -> String {
488 if let Some(uri) = MentionUri::parse(uri).log_err() {
489 uri.as_link().to_string()
490 } else {
491 uri.to_string()
492 }
493 }
494
495 fn image_md(_image: &acp::ImageContent) -> String {
496 "`Image`".into()
497 }
498
499 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
500 match self {
501 ContentBlock::Empty => "",
502 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
503 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
504 }
505 }
506
507 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
508 match self {
509 ContentBlock::Empty => None,
510 ContentBlock::Markdown { markdown } => Some(markdown),
511 ContentBlock::ResourceLink { .. } => None,
512 }
513 }
514
515 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
516 match self {
517 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
518 _ => None,
519 }
520 }
521}
522
/// One item of content displayed for a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text, Markdown, or resource-link content.
    ContentBlock(ContentBlock),
    /// A file diff.
    Diff(Entity<Diff>),
    /// A terminal.
    Terminal(Entity<Terminal>),
}
529
530impl ToolCallContent {
531 pub fn from_acp(
532 content: acp::ToolCallContent,
533 language_registry: Arc<LanguageRegistry>,
534 cx: &mut App,
535 ) -> Self {
536 match content {
537 acp::ToolCallContent::Content { content } => {
538 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
539 }
540 acp::ToolCallContent::Diff { diff } => {
541 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
542 }
543 }
544 }
545
546 pub fn to_markdown(&self, cx: &App) -> String {
547 match self {
548 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
549 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
550 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
551 }
552 }
553}
554
/// An update targeting an existing tool call (see `AcpThread::update_tool_call`).
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A partial field update received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
561
562impl ToolCallUpdate {
563 fn id(&self) -> &acp::ToolCallId {
564 match self {
565 Self::UpdateFields(update) => &update.id,
566 Self::UpdateDiff(diff) => &diff.id,
567 Self::UpdateTerminal(terminal) => &terminal.id,
568 }
569 }
570}
571
572impl From<acp::ToolCallUpdate> for ToolCallUpdate {
573 fn from(update: acp::ToolCallUpdate) -> Self {
574 Self::UpdateFields(update)
575 }
576}
577
578impl From<ToolCallUpdateDiff> for ToolCallUpdate {
579 fn from(diff: ToolCallUpdateDiff) -> Self {
580 Self::UpdateDiff(diff)
581 }
582}
583
/// Replaces a tool call's content with a single diff.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content.
    pub diff: Entity<Diff>,
}
589
590impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
591 fn from(terminal: ToolCallUpdateTerminal) -> Self {
592 Self::UpdateTerminal(terminal)
593 }
594}
595
/// Replaces a tool call's content with a single terminal.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content.
    pub terminal: Entity<Terminal>,
}
601
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
606
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry, in plan order, whose status is in-progress.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
613
614impl Plan {
615 pub fn is_empty(&self) -> bool {
616 self.entries.is_empty()
617 }
618
619 pub fn stats(&self) -> PlanStats<'_> {
620 let mut stats = PlanStats {
621 in_progress_entry: None,
622 pending: 0,
623 completed: 0,
624 };
625
626 for entry in &self.entries {
627 match &entry.status {
628 acp::PlanEntryStatus::Pending => {
629 stats.pending += 1;
630 }
631 acp::PlanEntryStatus::InProgress => {
632 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
633 }
634 acp::PlanEntryStatus::Completed => {
635 stats.completed += 1;
636 }
637 }
638 }
639
640 stats
641 }
642}
643
/// A single step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// The step's text, rendered as Markdown.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
650
651impl PlanEntry {
652 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
653 Self {
654 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
655 priority: entry.priority,
656 status: entry.status,
657 }
658 }
659}
660
/// Information about a retry attempt, forwarded to observers via
/// `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// The error that triggered the retry.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    // NOTE(review): presumably the delay before this retry — confirm with the producer.
    pub duration: Duration,
}
669
/// A single agent session: its transcript, plan, and the connection that drives it.
pub struct AcpThread {
    title: SharedString,
    // Transcript entries, oldest first.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): only initialized in the visible code; presumably used further down the file.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // `Some` while a turn is running; `status()` reports `Idle` when `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
681
/// Events emitted by an [`AcpThread`] to its observers.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the thread.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    // NOTE(review): not emitted in the visible code; presumably entries in this range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call is now waiting for the user's permission.
    ToolAuthorizationRequired,
    /// A request is being retried.
    Retry(RetryStatus),
    // NOTE(review): the variants below are not emitted in the visible code.
    Stopped,
    Error,
    ServerExited(ExitStatus),
}
693
// Lets `AcpThread` emit `AcpThreadEvent`s (e.g. `NewEntry`, `EntryUpdated`) via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
695
/// Coarse state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running and the latest tool call awaits user permission.
    WaitingForToolConfirmation,
    /// A turn is running.
    Generating,
}
702
/// Error describing why loading failed; displayed via its `Display` impl.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// Unsupported setup. Only `error_message` is shown by `Display`; the
    /// upgrade fields presumably describe a remedy — confirm with callers.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The server exited with this status code.
    Exited(i32),
    /// Any other failure, described by the message.
    Other(SharedString),
}
713
714impl Display for LoadError {
715 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
716 match self {
717 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
718 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
719 LoadError::Other(msg) => write!(f, "{}", msg),
720 }
721 }
722}
723
// Makes `LoadError` usable as a `std::error::Error`; its message comes from the `Display` impl.
impl Error for LoadError {}
725
726impl AcpThread {
727 pub fn new(
728 title: impl Into<SharedString>,
729 connection: Rc<dyn AgentConnection>,
730 project: Entity<Project>,
731 session_id: acp::SessionId,
732 cx: &mut Context<Self>,
733 ) -> Self {
734 let action_log = cx.new(|_| ActionLog::new(project.clone()));
735
736 Self {
737 action_log,
738 shared_buffers: Default::default(),
739 entries: Default::default(),
740 plan: Default::default(),
741 title: title.into(),
742 project,
743 send_task: None,
744 connection,
745 session_id,
746 }
747 }
748
749 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
750 &self.connection
751 }
752
753 pub fn action_log(&self) -> &Entity<ActionLog> {
754 &self.action_log
755 }
756
757 pub fn project(&self) -> &Entity<Project> {
758 &self.project
759 }
760
761 pub fn title(&self) -> SharedString {
762 self.title.clone()
763 }
764
765 pub fn entries(&self) -> &[AgentThreadEntry] {
766 &self.entries
767 }
768
769 pub fn session_id(&self) -> &acp::SessionId {
770 &self.session_id
771 }
772
773 pub fn status(&self) -> ThreadStatus {
774 if self.send_task.is_some() {
775 if self.waiting_for_tool_confirmation() {
776 ThreadStatus::WaitingForToolConfirmation
777 } else {
778 ThreadStatus::Generating
779 }
780 } else {
781 ThreadStatus::Idle
782 }
783 }
784
785 pub fn has_pending_edit_tool_calls(&self) -> bool {
786 for entry in self.entries.iter().rev() {
787 match entry {
788 AgentThreadEntry::UserMessage(_) => return false,
789 AgentThreadEntry::ToolCall(
790 call @ ToolCall {
791 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
792 ..
793 },
794 ) if call.diffs().next().is_some() => {
795 return true;
796 }
797 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
798 }
799 }
800
801 false
802 }
803
804 pub fn used_tools_since_last_user_message(&self) -> bool {
805 for entry in self.entries.iter().rev() {
806 match entry {
807 AgentThreadEntry::UserMessage(..) => return false,
808 AgentThreadEntry::AssistantMessage(..) => continue,
809 AgentThreadEntry::ToolCall(..) => return true,
810 }
811 }
812
813 false
814 }
815
816 pub fn handle_session_update(
817 &mut self,
818 update: acp::SessionUpdate,
819 cx: &mut Context<Self>,
820 ) -> Result<(), acp::Error> {
821 match update {
822 acp::SessionUpdate::UserMessageChunk { content } => {
823 self.push_user_content_block(None, content, cx);
824 }
825 acp::SessionUpdate::AgentMessageChunk { content } => {
826 self.push_assistant_content_block(content, false, cx);
827 }
828 acp::SessionUpdate::AgentThoughtChunk { content } => {
829 self.push_assistant_content_block(content, true, cx);
830 }
831 acp::SessionUpdate::ToolCall(tool_call) => {
832 self.upsert_tool_call(tool_call, cx)?;
833 }
834 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
835 self.update_tool_call(tool_call_update, cx)?;
836 }
837 acp::SessionUpdate::Plan(plan) => {
838 self.update_plan(plan, cx);
839 }
840 }
841 Ok(())
842 }
843
844 pub fn push_user_content_block(
845 &mut self,
846 message_id: Option<UserMessageId>,
847 chunk: acp::ContentBlock,
848 cx: &mut Context<Self>,
849 ) {
850 let language_registry = self.project.read(cx).languages().clone();
851 let entries_len = self.entries.len();
852
853 if let Some(last_entry) = self.entries.last_mut()
854 && let AgentThreadEntry::UserMessage(UserMessage {
855 id,
856 content,
857 chunks,
858 ..
859 }) = last_entry
860 {
861 *id = message_id.or(id.take());
862 content.append(chunk.clone(), &language_registry, cx);
863 chunks.push(chunk);
864 let idx = entries_len - 1;
865 cx.emit(AcpThreadEvent::EntryUpdated(idx));
866 } else {
867 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
868 self.push_entry(
869 AgentThreadEntry::UserMessage(UserMessage {
870 id: message_id,
871 content,
872 chunks: vec![chunk],
873 checkpoint: None,
874 }),
875 cx,
876 );
877 }
878 }
879
880 pub fn push_assistant_content_block(
881 &mut self,
882 chunk: acp::ContentBlock,
883 is_thought: bool,
884 cx: &mut Context<Self>,
885 ) {
886 let language_registry = self.project.read(cx).languages().clone();
887 let entries_len = self.entries.len();
888 if let Some(last_entry) = self.entries.last_mut()
889 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
890 {
891 let idx = entries_len - 1;
892 cx.emit(AcpThreadEvent::EntryUpdated(idx));
893 match (chunks.last_mut(), is_thought) {
894 (Some(AssistantMessageChunk::Message { block }), false)
895 | (Some(AssistantMessageChunk::Thought { block }), true) => {
896 block.append(chunk, &language_registry, cx)
897 }
898 _ => {
899 let block = ContentBlock::new(chunk, &language_registry, cx);
900 if is_thought {
901 chunks.push(AssistantMessageChunk::Thought { block })
902 } else {
903 chunks.push(AssistantMessageChunk::Message { block })
904 }
905 }
906 }
907 } else {
908 let block = ContentBlock::new(chunk, &language_registry, cx);
909 let chunk = if is_thought {
910 AssistantMessageChunk::Thought { block }
911 } else {
912 AssistantMessageChunk::Message { block }
913 };
914
915 self.push_entry(
916 AgentThreadEntry::AssistantMessage(AssistantMessage {
917 chunks: vec![chunk],
918 }),
919 cx,
920 );
921 }
922 }
923
924 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
925 self.entries.push(entry);
926 cx.emit(AcpThreadEvent::NewEntry);
927 }
928
929 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
930 cx.emit(AcpThreadEvent::Retry(status));
931 }
932
933 pub fn update_tool_call(
934 &mut self,
935 update: impl Into<ToolCallUpdate>,
936 cx: &mut Context<Self>,
937 ) -> Result<()> {
938 let update = update.into();
939 let languages = self.project.read(cx).languages().clone();
940
941 let (ix, current_call) = self
942 .tool_call_mut(update.id())
943 .context("Tool call not found")?;
944 match update {
945 ToolCallUpdate::UpdateFields(update) => {
946 let location_updated = update.fields.locations.is_some();
947 current_call.update_fields(update.fields, languages, cx);
948 if location_updated {
949 self.resolve_locations(update.id.clone(), cx);
950 }
951 }
952 ToolCallUpdate::UpdateDiff(update) => {
953 current_call.content.clear();
954 current_call
955 .content
956 .push(ToolCallContent::Diff(update.diff));
957 }
958 ToolCallUpdate::UpdateTerminal(update) => {
959 current_call.content.clear();
960 current_call
961 .content
962 .push(ToolCallContent::Terminal(update.terminal));
963 }
964 }
965
966 cx.emit(AcpThreadEvent::EntryUpdated(ix));
967
968 Ok(())
969 }
970
971 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
972 pub fn upsert_tool_call(
973 &mut self,
974 tool_call: acp::ToolCall,
975 cx: &mut Context<Self>,
976 ) -> Result<(), acp::Error> {
977 let status = tool_call.status.into();
978 self.upsert_tool_call_inner(tool_call.into(), status, cx)
979 }
980
981 /// Fails if id does not match an existing entry.
982 pub fn upsert_tool_call_inner(
983 &mut self,
984 tool_call_update: acp::ToolCallUpdate,
985 status: ToolCallStatus,
986 cx: &mut Context<Self>,
987 ) -> Result<(), acp::Error> {
988 let language_registry = self.project.read(cx).languages().clone();
989 let id = tool_call_update.id.clone();
990
991 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
992 current_call.update_fields(tool_call_update.fields, language_registry, cx);
993 current_call.status = status;
994
995 cx.emit(AcpThreadEvent::EntryUpdated(ix));
996 } else {
997 let call =
998 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
999 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1000 };
1001
1002 self.resolve_locations(id, cx);
1003 Ok(())
1004 }
1005
1006 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1007 // The tool call we are looking for is typically the last one, or very close to the end.
1008 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1009 self.entries
1010 .iter_mut()
1011 .enumerate()
1012 .rev()
1013 .find_map(|(index, tool_call)| {
1014 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1015 && &tool_call.id == id
1016 {
1017 Some((index, tool_call))
1018 } else {
1019 None
1020 }
1021 })
1022 }
1023
    /// Re-resolves the locations of the tool call identified by `id`.
    ///
    /// Resolution runs in a spawned task. When it finishes, if the project
    /// already has an agent location, that location is moved to the last
    /// resolved location, except when the move would only step backwards
    /// within the same row of the same buffer. The call's
    /// `resolved_locations` are then replaced, emitting
    /// [`AcpThreadEvent::EntryUpdated`] if they changed.
    ///
    /// Does nothing if no tool call with `id` exists.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // The tool call may have been removed while resolution was in flight.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        // NOTE(review): when the project has no agent location yet, none is
                        // set here; confirm that this is intended.
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of line
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                // Only notify observers when the resolved locations actually changed.
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1069
1070 pub fn request_tool_call_authorization(
1071 &mut self,
1072 tool_call: acp::ToolCallUpdate,
1073 options: Vec<acp::PermissionOption>,
1074 cx: &mut Context<Self>,
1075 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1076 let (tx, rx) = oneshot::channel();
1077
1078 let status = ToolCallStatus::WaitingForConfirmation {
1079 options,
1080 respond_tx: tx,
1081 };
1082
1083 self.upsert_tool_call_inner(tool_call, status, cx)?;
1084 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1085 Ok(rx)
1086 }
1087
1088 pub fn authorize_tool_call(
1089 &mut self,
1090 id: acp::ToolCallId,
1091 option_id: acp::PermissionOptionId,
1092 option_kind: acp::PermissionOptionKind,
1093 cx: &mut Context<Self>,
1094 ) {
1095 let Some((ix, call)) = self.tool_call_mut(&id) else {
1096 return;
1097 };
1098
1099 let new_status = match option_kind {
1100 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1101 ToolCallStatus::Rejected
1102 }
1103 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1104 ToolCallStatus::InProgress
1105 }
1106 };
1107
1108 let curr_status = mem::replace(&mut call.status, new_status);
1109
1110 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1111 respond_tx.send(option_id).log_err();
1112 } else if cfg!(debug_assertions) {
1113 panic!("tried to authorize an already authorized tool call");
1114 }
1115
1116 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1117 }
1118
1119 /// Returns true if the last turn is awaiting tool authorization
1120 pub fn waiting_for_tool_confirmation(&self) -> bool {
1121 for entry in self.entries.iter().rev() {
1122 match &entry {
1123 AgentThreadEntry::ToolCall(call) => match call.status {
1124 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1125 ToolCallStatus::Pending
1126 | ToolCallStatus::InProgress
1127 | ToolCallStatus::Completed
1128 | ToolCallStatus::Failed
1129 | ToolCallStatus::Rejected
1130 | ToolCallStatus::Canceled => continue,
1131 },
1132 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1133 // Reached the beginning of the turn
1134 return false;
1135 }
1136 }
1137 }
1138 false
1139 }
1140
    /// The plan most recently reported by the agent (see `update_plan`).
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1144
1145 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1146 let new_entries_len = request.entries.len();
1147 let mut new_entries = request.entries.into_iter();
1148
1149 // Reuse existing markdown to prevent flickering
1150 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1151 let PlanEntry {
1152 content,
1153 priority,
1154 status,
1155 } = old;
1156 content.update(cx, |old, cx| {
1157 old.replace(new.content, cx);
1158 });
1159 *priority = new.priority;
1160 *status = new.status;
1161 }
1162 for new in new_entries {
1163 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1164 }
1165 self.plan.entries.truncate(new_entries_len);
1166
1167 cx.notify();
1168 }
1169
1170 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1171 self.plan
1172 .entries
1173 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1174 cx.notify();
1175 }
1176
1177 #[cfg(any(test, feature = "test-support"))]
1178 pub fn send_raw(
1179 &mut self,
1180 message: &str,
1181 cx: &mut Context<Self>,
1182 ) -> BoxFuture<'static, Result<()>> {
1183 self.send(
1184 vec![acp::ContentBlock::Text(acp::TextContent {
1185 text: message.to_string(),
1186 annotations: None,
1187 })],
1188 cx,
1189 )
1190 }
1191
1192 pub fn send(
1193 &mut self,
1194 message: Vec<acp::ContentBlock>,
1195 cx: &mut Context<Self>,
1196 ) -> BoxFuture<'static, Result<()>> {
1197 let block = ContentBlock::new_combined(
1198 message.clone(),
1199 self.project.read(cx).languages().clone(),
1200 cx,
1201 );
1202 let request = acp::PromptRequest {
1203 prompt: message.clone(),
1204 session_id: self.session_id.clone(),
1205 };
1206 let git_store = self.project.read(cx).git_store().clone();
1207
1208 let message_id = if self
1209 .connection
1210 .session_editor(&self.session_id, cx)
1211 .is_some()
1212 {
1213 Some(UserMessageId::new())
1214 } else {
1215 None
1216 };
1217
1218 self.run_turn(cx, async move |this, cx| {
1219 this.update(cx, |this, cx| {
1220 this.push_entry(
1221 AgentThreadEntry::UserMessage(UserMessage {
1222 id: message_id.clone(),
1223 content: block,
1224 chunks: message,
1225 checkpoint: None,
1226 }),
1227 cx,
1228 );
1229 })
1230 .ok();
1231
1232 let old_checkpoint = git_store
1233 .update(cx, |git, cx| git.checkpoint(cx))?
1234 .await
1235 .context("failed to get old checkpoint")
1236 .log_err();
1237 this.update(cx, |this, cx| {
1238 if let Some((_ix, message)) = this.last_user_message() {
1239 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1240 git_checkpoint,
1241 show: false,
1242 });
1243 }
1244 this.connection.prompt(message_id, request, cx)
1245 })?
1246 .await
1247 })
1248 }
1249
1250 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1251 self.run_turn(cx, async move |this, cx| {
1252 this.update(cx, |this, cx| {
1253 this.connection
1254 .resume(&this.session_id, cx)
1255 .map(|resume| resume.run(cx))
1256 })?
1257 .context("resuming a session is not supported")?
1258 .await
1259 })
1260 }
1261
1262 fn run_turn(
1263 &mut self,
1264 cx: &mut Context<Self>,
1265 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1266 ) -> BoxFuture<'static, Result<()>> {
1267 self.clear_completed_plan_entries(cx);
1268
1269 let (tx, rx) = oneshot::channel();
1270 let cancel_task = self.cancel(cx);
1271
1272 self.send_task = Some(cx.spawn(async move |this, cx| {
1273 cancel_task.await;
1274 tx.send(f(this, cx).await).ok();
1275 }));
1276
1277 cx.spawn(async move |this, cx| {
1278 let response = rx.await;
1279
1280 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1281 .await?;
1282
1283 this.update(cx, |this, cx| {
1284 this.project
1285 .update(cx, |project, cx| project.set_agent_location(None, cx));
1286 match response {
1287 Ok(Err(e)) => {
1288 this.send_task.take();
1289 cx.emit(AcpThreadEvent::Error);
1290 Err(e)
1291 }
1292 result => {
1293 let canceled = matches!(
1294 result,
1295 Ok(Ok(acp::PromptResponse {
1296 stop_reason: acp::StopReason::Canceled
1297 }))
1298 );
1299
1300 // We only take the task if the current prompt wasn't canceled.
1301 //
1302 // This prompt may have been canceled because another one was sent
1303 // while it was still generating. In these cases, dropping `send_task`
1304 // would cause the next generation to be canceled.
1305 if !canceled {
1306 this.send_task.take();
1307 }
1308
1309 cx.emit(AcpThreadEvent::Stopped);
1310 Ok(())
1311 }
1312 }
1313 })?
1314 })
1315 .boxed()
1316 }
1317
1318 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1319 let Some(send_task) = self.send_task.take() else {
1320 return Task::ready(());
1321 };
1322
1323 for entry in self.entries.iter_mut() {
1324 if let AgentThreadEntry::ToolCall(call) = entry {
1325 let cancel = matches!(
1326 call.status,
1327 ToolCallStatus::Pending
1328 | ToolCallStatus::WaitingForConfirmation { .. }
1329 | ToolCallStatus::InProgress
1330 );
1331
1332 if cancel {
1333 call.status = ToolCallStatus::Canceled;
1334 }
1335 }
1336 }
1337
1338 self.connection.cancel(&self.session_id, cx);
1339
1340 // Wait for the send task to complete
1341 cx.foreground_executor().spawn(send_task)
1342 }
1343
1344 /// Rewinds this thread to before the entry at `index`, removing it and all
1345 /// subsequent entries while reverting any changes made from that point.
1346 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1347 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1348 return Task::ready(Err(anyhow!("not supported")));
1349 };
1350 let Some(message) = self.user_message(&id) else {
1351 return Task::ready(Err(anyhow!("message not found")));
1352 };
1353
1354 let checkpoint = message
1355 .checkpoint
1356 .as_ref()
1357 .map(|c| c.git_checkpoint.clone());
1358
1359 let git_store = self.project.read(cx).git_store().clone();
1360 cx.spawn(async move |this, cx| {
1361 if let Some(checkpoint) = checkpoint {
1362 git_store
1363 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1364 .await?;
1365 }
1366
1367 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1368 .await?;
1369 this.update(cx, |this, cx| {
1370 if let Some((ix, _)) = this.user_message_mut(&id) {
1371 let range = ix..this.entries.len();
1372 this.entries.truncate(ix);
1373 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1374 }
1375 })
1376 })
1377 }
1378
1379 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1380 let git_store = self.project.read(cx).git_store().clone();
1381
1382 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1383 if let Some(checkpoint) = message.checkpoint.as_ref() {
1384 checkpoint.git_checkpoint.clone()
1385 } else {
1386 return Task::ready(Ok(()));
1387 }
1388 } else {
1389 return Task::ready(Ok(()));
1390 };
1391
1392 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1393 cx.spawn(async move |this, cx| {
1394 let new_checkpoint = new_checkpoint
1395 .await
1396 .context("failed to get new checkpoint")
1397 .log_err();
1398 if let Some(new_checkpoint) = new_checkpoint {
1399 let equal = git_store
1400 .update(cx, |git, cx| {
1401 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1402 })?
1403 .await
1404 .unwrap_or(true);
1405 this.update(cx, |this, cx| {
1406 let (ix, message) = this.last_user_message().context("no user message")?;
1407 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1408 checkpoint.show = !equal;
1409 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1410 anyhow::Ok(())
1411 })??;
1412 }
1413
1414 Ok(())
1415 })
1416 }
1417
1418 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1419 self.entries
1420 .iter_mut()
1421 .enumerate()
1422 .rev()
1423 .find_map(|(ix, entry)| {
1424 if let AgentThreadEntry::UserMessage(message) = entry {
1425 Some((ix, message))
1426 } else {
1427 None
1428 }
1429 })
1430 }
1431
1432 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1433 self.entries.iter().find_map(|entry| {
1434 if let AgentThreadEntry::UserMessage(message) = entry {
1435 if message.id.as_ref() == Some(id) {
1436 Some(message)
1437 } else {
1438 None
1439 }
1440 } else {
1441 None
1442 }
1443 })
1444 }
1445
1446 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1447 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1448 if let AgentThreadEntry::UserMessage(message) = entry {
1449 if message.id.as_ref() == Some(id) {
1450 Some((ix, message))
1451 } else {
1452 None
1453 }
1454 } else {
1455 None
1456 }
1457 })
1458 }
1459
1460 pub fn read_text_file(
1461 &self,
1462 path: PathBuf,
1463 line: Option<u32>,
1464 limit: Option<u32>,
1465 reuse_shared_snapshot: bool,
1466 cx: &mut Context<Self>,
1467 ) -> Task<Result<String>> {
1468 let project = self.project.clone();
1469 let action_log = self.action_log.clone();
1470 cx.spawn(async move |this, cx| {
1471 let load = project.update(cx, |project, cx| {
1472 let path = project
1473 .project_path_for_absolute_path(&path, cx)
1474 .context("invalid path")?;
1475 anyhow::Ok(project.open_buffer(path, cx))
1476 });
1477 let buffer = load??.await?;
1478
1479 let snapshot = if reuse_shared_snapshot {
1480 this.read_with(cx, |this, _| {
1481 this.shared_buffers.get(&buffer.clone()).cloned()
1482 })
1483 .log_err()
1484 .flatten()
1485 } else {
1486 None
1487 };
1488
1489 let snapshot = if let Some(snapshot) = snapshot {
1490 snapshot
1491 } else {
1492 action_log.update(cx, |action_log, cx| {
1493 action_log.buffer_read(buffer.clone(), cx);
1494 })?;
1495 project.update(cx, |project, cx| {
1496 let position = buffer
1497 .read(cx)
1498 .snapshot()
1499 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1500 project.set_agent_location(
1501 Some(AgentLocation {
1502 buffer: buffer.downgrade(),
1503 position,
1504 }),
1505 cx,
1506 );
1507 })?;
1508
1509 buffer.update(cx, |buffer, _| buffer.snapshot())?
1510 };
1511
1512 this.update(cx, |this, _| {
1513 let text = snapshot.text();
1514 this.shared_buffers.insert(buffer.clone(), snapshot);
1515 if line.is_none() && limit.is_none() {
1516 return Ok(text);
1517 }
1518 let limit = limit.unwrap_or(u32::MAX) as usize;
1519 let Some(line) = line else {
1520 return Ok(text.lines().take(limit).collect::<String>());
1521 };
1522
1523 let count = text.lines().count();
1524 if count < line as usize {
1525 anyhow::bail!("There are only {} lines", count);
1526 }
1527 Ok(text
1528 .lines()
1529 .skip(line as usize + 1)
1530 .take(limit)
1531 .collect::<String>())
1532 })?
1533 })
1534 }
1535
1536 pub fn write_text_file(
1537 &self,
1538 path: PathBuf,
1539 content: String,
1540 cx: &mut Context<Self>,
1541 ) -> Task<Result<()>> {
1542 let project = self.project.clone();
1543 let action_log = self.action_log.clone();
1544 cx.spawn(async move |this, cx| {
1545 let load = project.update(cx, |project, cx| {
1546 let path = project
1547 .project_path_for_absolute_path(&path, cx)
1548 .context("invalid path")?;
1549 anyhow::Ok(project.open_buffer(path, cx))
1550 });
1551 let buffer = load??.await?;
1552 let snapshot = this.update(cx, |this, cx| {
1553 this.shared_buffers
1554 .get(&buffer)
1555 .cloned()
1556 .unwrap_or_else(|| buffer.read(cx).snapshot())
1557 })?;
1558 let edits = cx
1559 .background_executor()
1560 .spawn(async move {
1561 let old_text = snapshot.text();
1562 text_diff(old_text.as_str(), &content)
1563 .into_iter()
1564 .map(|(range, replacement)| {
1565 (
1566 snapshot.anchor_after(range.start)
1567 ..snapshot.anchor_before(range.end),
1568 replacement,
1569 )
1570 })
1571 .collect::<Vec<_>>()
1572 })
1573 .await;
1574 cx.update(|cx| {
1575 project.update(cx, |project, cx| {
1576 project.set_agent_location(
1577 Some(AgentLocation {
1578 buffer: buffer.downgrade(),
1579 position: edits
1580 .last()
1581 .map(|(range, _)| range.end)
1582 .unwrap_or(Anchor::MIN),
1583 }),
1584 cx,
1585 );
1586 });
1587
1588 action_log.update(cx, |action_log, cx| {
1589 action_log.buffer_read(buffer.clone(), cx);
1590 });
1591 buffer.update(cx, |buffer, cx| {
1592 buffer.edit(edits, None, cx);
1593 });
1594 action_log.update(cx, |action_log, cx| {
1595 action_log.buffer_edited(buffer.clone(), cx);
1596 });
1597 })?;
1598 project
1599 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1600 .await
1601 })
1602 }
1603
1604 pub fn to_markdown(&self, cx: &App) -> String {
1605 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1606 }
1607
1608 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1609 cx.emit(AcpThreadEvent::ServerExited(status));
1610 }
1611}
1612
1613fn markdown_for_raw_output(
1614 raw_output: &serde_json::Value,
1615 language_registry: &Arc<LanguageRegistry>,
1616 cx: &mut App,
1617) -> Option<Entity<Markdown>> {
1618 match raw_output {
1619 serde_json::Value::Null => None,
1620 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1621 Markdown::new(
1622 value.to_string().into(),
1623 Some(language_registry.clone()),
1624 None,
1625 cx,
1626 )
1627 })),
1628 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1629 Markdown::new(
1630 value.to_string().into(),
1631 Some(language_registry.clone()),
1632 None,
1633 cx,
1634 )
1635 })),
1636 serde_json::Value::String(value) => Some(cx.new(|cx| {
1637 Markdown::new(
1638 value.clone().into(),
1639 Some(language_registry.clone()),
1640 None,
1641 cx,
1642 )
1643 })),
1644 value => Some(cx.new(|cx| {
1645 Markdown::new(
1646 format!("```json\n{}\n```", value).into(),
1647 Some(language_registry.clone()),
1648 None,
1649 cx,
1650 )
1651 })),
1652 }
1653}
1654
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::{FakeFs, Fs};
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{
        any::Any,
        cell::RefCell,
        path::Path,
        rc::Rc,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
        time::Duration,
    };
    use util::path;

    /// Installs the logger, a test settings store, project settings, and
    /// language support for a test.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }

    /// Waits (up to 10 seconds) until the thread contains a tool call and
    /// returns the index of the first one.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // Every later event re-sends the same index. Once the
                        // bounded channel is full those sends are redundant, so
                        // ignore the error rather than panicking in the callback.
                        tx.try_send(ix).ok();
                        return;
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory `AgentConnection` whose prompt behavior is scripted per test.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Sets the handler invoked for every prompt sent to this connection.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }

        fn session_editor(
            &self,
            session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(FakeAgentSessionEditor {
                _session_id: session_id.clone(),
            }))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose truncation always succeeds without doing anything.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}