1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9pub use terminal::*;
10
11use action_log::ActionLog;
12use agent_client_protocol as acp;
13use anyhow::{Context as _, Result, anyhow};
14use editor::Bias;
15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use itertools::Itertools;
18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
19use markdown::Markdown;
20use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
21use std::collections::HashMap;
22use std::error::Error;
23use std::fmt::{Formatter, Write};
24use std::ops::Range;
25use std::process::ExitStatus;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
29use ui::App;
30use util::ResultExt;
31
/// A user turn in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier for this message, when one was supplied (e.g. `send` creates
    /// one only if the connection exposes a session editor).
    pub id: Option<UserMessageId>,
    /// The rendered content, built by appending every chunk in order.
    pub content: ContentBlock,
    /// The raw ACP content blocks that make up this message.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint taken before the prompt was sent, if one was captured.
    pub checkpoint: Option<Checkpoint>,
}
39
/// A git checkpoint attached to a user message.
#[derive(Debug)]
pub struct Checkpoint {
    git_checkpoint: GitStoreCheckpoint,
    /// Whether the checkpoint should be surfaced; `UserMessage::to_markdown`
    /// labels the message heading "(checkpoint)" when this is set.
    pub show: bool,
}
45
46impl UserMessage {
47 fn to_markdown(&self, cx: &App) -> String {
48 let mut markdown = String::new();
49 if self
50 .checkpoint
51 .as_ref()
52 .map_or(false, |checkpoint| checkpoint.show)
53 {
54 writeln!(markdown, "## User (checkpoint)").unwrap();
55 } else {
56 writeln!(markdown, "## User").unwrap();
57 }
58 writeln!(markdown).unwrap();
59 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
60 writeln!(markdown).unwrap();
61 markdown
62 }
63}
64
/// An assistant turn, made of message and thought chunks in arrival order.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    pub chunks: Vec<AssistantMessageChunk>,
}
69
70impl AssistantMessage {
71 pub fn to_markdown(&self, cx: &App) -> String {
72 format!(
73 "## Assistant\n\n{}\n\n",
74 self.chunks
75 .iter()
76 .map(|chunk| chunk.to_markdown(cx))
77 .join("\n\n")
78 )
79 }
80}
81
/// One contiguous piece of an assistant message.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular output shown as the assistant's reply.
    Message { block: ContentBlock },
    /// Reasoning output; exported wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
87
88impl AssistantMessageChunk {
89 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
90 Self::Message {
91 block: ContentBlock::new(chunk.into(), language_registry, cx),
92 }
93 }
94
95 fn to_markdown(&self, cx: &App) -> String {
96 match self {
97 Self::Message { block } => block.to_markdown(cx).to_string(),
98 Self::Thought { block } => {
99 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
100 }
101 }
102 }
103}
104
/// A single entry in an [`AcpThread`]'s history.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
111
112impl AgentThreadEntry {
113 pub fn to_markdown(&self, cx: &App) -> String {
114 match self {
115 Self::UserMessage(message) => message.to_markdown(cx),
116 Self::AssistantMessage(message) => message.to_markdown(cx),
117 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
118 }
119 }
120
121 pub fn user_message(&self) -> Option<&UserMessage> {
122 if let AgentThreadEntry::UserMessage(message) = self {
123 Some(message)
124 } else {
125 None
126 }
127 }
128
129 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
130 if let AgentThreadEntry::ToolCall(call) = self {
131 itertools::Either::Left(call.diffs())
132 } else {
133 itertools::Either::Right(std::iter::empty())
134 }
135 }
136
137 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.terminals())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
146 if let AgentThreadEntry::ToolCall(ToolCall {
147 locations,
148 resolved_locations,
149 ..
150 }) = self
151 {
152 Some((
153 locations.get(ix)?.clone(),
154 resolved_locations.get(ix)?.clone()?,
155 ))
156 } else {
157 None
158 }
159 }
160}
161
/// A tool invocation reported by the agent.
#[derive(Debug)]
pub struct ToolCall {
    pub id: acp::ToolCallId,
    /// The tool call title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolution results indexed like `locations`; `None` where a location
    /// could not be resolved. Filled in asynchronously by `resolve_locations`.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    pub raw_input: Option<serde_json::Value>,
    pub raw_output: Option<serde_json::Value>,
}
174
175impl ToolCall {
176 fn from_acp(
177 tool_call: acp::ToolCall,
178 status: ToolCallStatus,
179 language_registry: Arc<LanguageRegistry>,
180 cx: &mut App,
181 ) -> Self {
182 Self {
183 id: tool_call.id,
184 label: cx.new(|cx| {
185 Markdown::new(
186 tool_call.title.into(),
187 Some(language_registry.clone()),
188 None,
189 cx,
190 )
191 }),
192 kind: tool_call.kind,
193 content: tool_call
194 .content
195 .into_iter()
196 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
197 .collect(),
198 locations: tool_call.locations,
199 resolved_locations: Vec::default(),
200 status,
201 raw_input: tool_call.raw_input,
202 raw_output: tool_call.raw_output,
203 }
204 }
205
206 fn update_fields(
207 &mut self,
208 fields: acp::ToolCallUpdateFields,
209 language_registry: Arc<LanguageRegistry>,
210 cx: &mut App,
211 ) {
212 let acp::ToolCallUpdateFields {
213 kind,
214 status,
215 title,
216 content,
217 locations,
218 raw_input,
219 raw_output,
220 } = fields;
221
222 if let Some(kind) = kind {
223 self.kind = kind;
224 }
225
226 if let Some(status) = status {
227 self.status = status.into();
228 }
229
230 if let Some(title) = title {
231 self.label.update(cx, |label, cx| {
232 label.replace(title, cx);
233 });
234 }
235
236 if let Some(content) = content {
237 self.content = content
238 .into_iter()
239 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
240 .collect();
241 }
242
243 if let Some(locations) = locations {
244 self.locations = locations;
245 }
246
247 if let Some(raw_input) = raw_input {
248 self.raw_input = Some(raw_input);
249 }
250
251 if let Some(raw_output) = raw_output {
252 if self.content.is_empty() {
253 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
254 {
255 self.content
256 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
257 markdown,
258 }));
259 }
260 }
261 self.raw_output = Some(raw_output);
262 }
263 }
264
265 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
266 self.content.iter().filter_map(|content| match content {
267 ToolCallContent::Diff(diff) => Some(diff),
268 ToolCallContent::ContentBlock(_) => None,
269 ToolCallContent::Terminal(_) => None,
270 })
271 }
272
273 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
274 self.content.iter().filter_map(|content| match content {
275 ToolCallContent::Terminal(terminal) => Some(terminal),
276 ToolCallContent::ContentBlock(_) => None,
277 ToolCallContent::Diff(_) => None,
278 })
279 }
280
281 fn to_markdown(&self, cx: &App) -> String {
282 let mut markdown = format!(
283 "**Tool Call: {}**\nStatus: {}\n\n",
284 self.label.read(cx).source(),
285 self.status
286 );
287 for content in &self.content {
288 markdown.push_str(content.to_markdown(cx).as_str());
289 markdown.push_str("\n\n");
290 }
291 markdown
292 }
293
294 async fn resolve_location(
295 location: acp::ToolCallLocation,
296 project: WeakEntity<Project>,
297 cx: &mut AsyncApp,
298 ) -> Option<AgentLocation> {
299 let buffer = project
300 .update(cx, |project, cx| {
301 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
302 Some(project.open_buffer(path, cx))
303 } else {
304 None
305 }
306 })
307 .ok()??;
308 let buffer = buffer.await.log_err()?;
309 let position = buffer
310 .update(cx, |buffer, _| {
311 if let Some(row) = location.line {
312 let snapshot = buffer.snapshot();
313 let column = snapshot.indent_size_for_line(row).len;
314 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
315 snapshot.anchor_before(point)
316 } else {
317 Anchor::MIN
318 }
319 })
320 .ok()?;
321
322 Some(AgentLocation {
323 buffer: buffer.downgrade(),
324 position,
325 })
326 }
327
328 fn resolve_locations(
329 &self,
330 project: Entity<Project>,
331 cx: &mut App,
332 ) -> Task<Vec<Option<AgentLocation>>> {
333 let locations = self.locations.clone();
334 project.update(cx, |_, cx| {
335 cx.spawn(async move |project, cx| {
336 let mut new_locations = Vec::new();
337 for location in locations {
338 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
339 }
340 new_locations
341 })
342 })
343 }
344}
345
/// Lifecycle state of a [`ToolCall`] as tracked locally.
///
/// Extends the ACP statuses with the local-only states
/// `WaitingForConfirmation`, `Rejected` and `Canceled`.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call has not started running yet, but it is already shown to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        /// Receives the option the user picked (see `authorize_tool_call`).
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation, so the tool call was canceled.
    Canceled,
}
367
/// Maps each ACP status onto the local variant of the same name.
impl From<acp::ToolCallStatus> for ToolCallStatus {
    fn from(status: acp::ToolCallStatus) -> Self {
        match status {
            acp::ToolCallStatus::Pending => Self::Pending,
            acp::ToolCallStatus::InProgress => Self::InProgress,
            acp::ToolCallStatus::Completed => Self::Completed,
            acp::ToolCallStatus::Failed => Self::Failed,
        }
    }
}
378
379impl Display for ToolCallStatus {
380 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
381 write!(
382 f,
383 "{}",
384 match self {
385 ToolCallStatus::Pending => "Pending",
386 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
387 ToolCallStatus::InProgress => "In Progress",
388 ToolCallStatus::Completed => "Completed",
389 ToolCallStatus::Failed => "Failed",
390 ToolCallStatus::Rejected => "Rejected",
391 ToolCallStatus::Canceled => "Canceled",
392 }
393 )
394 }
395}
396
/// Renderable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Rendered markdown text.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link, kept as-is until more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
403
404impl ContentBlock {
405 pub fn new(
406 block: acp::ContentBlock,
407 language_registry: &Arc<LanguageRegistry>,
408 cx: &mut App,
409 ) -> Self {
410 let mut this = Self::Empty;
411 this.append(block, language_registry, cx);
412 this
413 }
414
415 pub fn new_combined(
416 blocks: impl IntoIterator<Item = acp::ContentBlock>,
417 language_registry: Arc<LanguageRegistry>,
418 cx: &mut App,
419 ) -> Self {
420 let mut this = Self::Empty;
421 for block in blocks {
422 this.append(block, &language_registry, cx);
423 }
424 this
425 }
426
427 pub fn append(
428 &mut self,
429 block: acp::ContentBlock,
430 language_registry: &Arc<LanguageRegistry>,
431 cx: &mut App,
432 ) {
433 if matches!(self, ContentBlock::Empty) {
434 if let acp::ContentBlock::ResourceLink(resource_link) = block {
435 *self = ContentBlock::ResourceLink { resource_link };
436 return;
437 }
438 }
439
440 let new_content = self.block_string_contents(block);
441
442 match self {
443 ContentBlock::Empty => {
444 *self = Self::create_markdown_block(new_content, language_registry, cx);
445 }
446 ContentBlock::Markdown { markdown } => {
447 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
448 }
449 ContentBlock::ResourceLink { resource_link } => {
450 let existing_content = Self::resource_link_md(&resource_link.uri);
451 let combined = format!("{}\n{}", existing_content, new_content);
452
453 *self = Self::create_markdown_block(combined, language_registry, cx);
454 }
455 }
456 }
457
458 fn create_markdown_block(
459 content: String,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) -> ContentBlock {
463 ContentBlock::Markdown {
464 markdown: cx
465 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
466 }
467 }
468
469 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
470 match block {
471 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
472 acp::ContentBlock::ResourceLink(resource_link) => {
473 Self::resource_link_md(&resource_link.uri)
474 }
475 acp::ContentBlock::Resource(acp::EmbeddedResource {
476 resource:
477 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
478 uri,
479 ..
480 }),
481 ..
482 }) => Self::resource_link_md(&uri),
483 acp::ContentBlock::Image(image) => Self::image_md(&image),
484 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
485 }
486 }
487
488 fn resource_link_md(uri: &str) -> String {
489 if let Some(uri) = MentionUri::parse(uri).log_err() {
490 uri.as_link().to_string()
491 } else {
492 uri.to_string()
493 }
494 }
495
496 fn image_md(_image: &acp::ImageContent) -> String {
497 "`Image`".into()
498 }
499
500 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
501 match self {
502 ContentBlock::Empty => "",
503 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
504 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
505 }
506 }
507
508 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
509 match self {
510 ContentBlock::Empty => None,
511 ContentBlock::Markdown { markdown } => Some(markdown),
512 ContentBlock::ResourceLink { .. } => None,
513 }
514 }
515
516 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
517 match self {
518 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
519 _ => None,
520 }
521 }
522}
523
/// One item of a tool call's displayed content.
#[derive(Debug)]
pub enum ToolCallContent {
    ContentBlock(ContentBlock),
    Diff(Entity<Diff>),
    /// Not produced by `from_acp`; installed via `ToolCallUpdate::UpdateTerminal`.
    Terminal(Entity<Terminal>),
}
530
531impl ToolCallContent {
532 pub fn from_acp(
533 content: acp::ToolCallContent,
534 language_registry: Arc<LanguageRegistry>,
535 cx: &mut App,
536 ) -> Self {
537 match content {
538 acp::ToolCallContent::Content { content } => {
539 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
540 }
541 acp::ToolCallContent::Diff { diff } => {
542 Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
543 }
544 }
545 }
546
547 pub fn to_markdown(&self, cx: &App) -> String {
548 match self {
549 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
550 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
551 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
552 }
553 }
554}
555
/// An update to an existing tool call, applied by `AcpThread::update_tool_call`.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field-level update received over ACP.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the tool call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the tool call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}

impl ToolCallUpdate {
    /// The id of the tool call this update targets.
    fn id(&self) -> &acp::ToolCallId {
        match self {
            Self::UpdateFields(update) => &update.id,
            Self::UpdateDiff(diff) => &diff.id,
            Self::UpdateTerminal(terminal) => &terminal.id,
        }
    }
}

impl From<acp::ToolCallUpdate> for ToolCallUpdate {
    fn from(update: acp::ToolCallUpdate) -> Self {
        Self::UpdateFields(update)
    }
}

impl From<ToolCallUpdateDiff> for ToolCallUpdate {
    fn from(diff: ToolCallUpdateDiff) -> Self {
        Self::UpdateDiff(diff)
    }
}

/// Payload for [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    pub id: acp::ToolCallId,
    pub diff: Entity<Diff>,
}

impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
    fn from(terminal: ToolCallUpdateTerminal) -> Self {
        Self::UpdateTerminal(terminal)
    }
}

/// Payload for [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    pub id: acp::ToolCallId,
    pub terminal: Entity<Terminal>,
}
602
/// The agent's current plan.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}

/// Summary counts for a [`Plan`], produced by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry (in plan order) whose status is `InProgress`.
    pub in_progress_entry: Option<&'a PlanEntry>,
    pub pending: u32,
    pub completed: u32,
}
614
615impl Plan {
616 pub fn is_empty(&self) -> bool {
617 self.entries.is_empty()
618 }
619
620 pub fn stats(&self) -> PlanStats<'_> {
621 let mut stats = PlanStats {
622 in_progress_entry: None,
623 pending: 0,
624 completed: 0,
625 };
626
627 for entry in &self.entries {
628 match &entry.status {
629 acp::PlanEntryStatus::Pending => {
630 stats.pending += 1;
631 }
632 acp::PlanEntryStatus::InProgress => {
633 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
634 }
635 acp::PlanEntryStatus::Completed => {
636 stats.completed += 1;
637 }
638 }
639 }
640
641 stats
642 }
643}
644
/// One step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry text rendered as markdown (reused across updates).
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
651
652impl PlanEntry {
653 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
654 Self {
655 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
656 priority: entry.priority,
657 status: entry.status,
658 }
659 }
660}
661
/// Retry information forwarded to listeners via [`AcpThreadEvent::Retry`].
// NOTE(review): field meanings below are inferred from names; confirm with the
// code that constructs `RetryStatus`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number.
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    /// Presumably the delay before the next attempt.
    pub duration: Duration,
}
670
/// A conversation with an agent over ACP.
///
/// Holds the ordered entries (user messages, assistant messages and tool
/// calls), the agent's current plan, and the connection/session used to talk
/// to the agent.
pub struct AcpThread {
    title: SharedString,
    /// Thread entries, oldest first (`push_entry` appends).
    entries: Vec<AgentThreadEntry>,
    /// The agent's latest plan, replaced via `update_plan`.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): only initialized in the visible code; presumably snapshots
    // of buffers shared with the agent. Confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight turn, if any. `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
682
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the thread (emitted by `push_entry`).
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// Entries in this index range were removed (emitted outside the visible code).
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    /// The request is being retried.
    Retry(RetryStatus),
    // NOTE(review): the following are emitted outside the visible code; the
    // descriptions are inferred from the names.
    /// The current turn stopped.
    Stopped,
    /// The current turn ended with an error.
    Error,
    /// The agent server process exited.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
696
/// Coarse thread state, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running and the latest turn has a tool call awaiting permission.
    WaitingForToolConfirmation,
    /// A turn is running.
    Generating,
}
703
/// Error produced when an agent cannot be loaded.
// NOTE(review): variant meanings are inferred from names and the `Display`
// impl; confirm with the code that produces `LoadError`.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent is unsupported (likely a version mismatch), with guidance on
    /// how to upgrade.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The agent server exited with this status code.
    Exited(i32),
    /// Any other failure, described by the message.
    Other(SharedString),
}
714
715impl Display for LoadError {
716 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
717 match self {
718 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
719 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
720 LoadError::Other(msg) => write!(f, "{}", msg),
721 }
722 }
723}
724
725impl Error for LoadError {}
726
727impl AcpThread {
728 pub fn new(
729 title: impl Into<SharedString>,
730 connection: Rc<dyn AgentConnection>,
731 project: Entity<Project>,
732 session_id: acp::SessionId,
733 cx: &mut Context<Self>,
734 ) -> Self {
735 let action_log = cx.new(|_| ActionLog::new(project.clone()));
736
737 Self {
738 action_log,
739 shared_buffers: Default::default(),
740 entries: Default::default(),
741 plan: Default::default(),
742 title: title.into(),
743 project,
744 send_task: None,
745 connection,
746 session_id,
747 }
748 }
749
    /// The agent connection this thread talks to.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }

    /// The action log created for this thread in `new`.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Returns a clone of the thread title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }

    /// All entries in the thread, oldest first.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }

    /// The ACP session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
773
774 pub fn status(&self) -> ThreadStatus {
775 if self.send_task.is_some() {
776 if self.waiting_for_tool_confirmation() {
777 ThreadStatus::WaitingForToolConfirmation
778 } else {
779 ThreadStatus::Generating
780 }
781 } else {
782 ThreadStatus::Idle
783 }
784 }
785
786 pub fn has_pending_edit_tool_calls(&self) -> bool {
787 for entry in self.entries.iter().rev() {
788 match entry {
789 AgentThreadEntry::UserMessage(_) => return false,
790 AgentThreadEntry::ToolCall(
791 call @ ToolCall {
792 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
793 ..
794 },
795 ) if call.diffs().next().is_some() => {
796 return true;
797 }
798 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
799 }
800 }
801
802 false
803 }
804
805 pub fn used_tools_since_last_user_message(&self) -> bool {
806 for entry in self.entries.iter().rev() {
807 match entry {
808 AgentThreadEntry::UserMessage(..) => return false,
809 AgentThreadEntry::AssistantMessage(..) => continue,
810 AgentThreadEntry::ToolCall(..) => return true,
811 }
812 }
813
814 false
815 }
816
817 pub fn handle_session_update(
818 &mut self,
819 update: acp::SessionUpdate,
820 cx: &mut Context<Self>,
821 ) -> Result<(), acp::Error> {
822 match update {
823 acp::SessionUpdate::UserMessageChunk { content } => {
824 self.push_user_content_block(None, content, cx);
825 }
826 acp::SessionUpdate::AgentMessageChunk { content } => {
827 self.push_assistant_content_block(content, false, cx);
828 }
829 acp::SessionUpdate::AgentThoughtChunk { content } => {
830 self.push_assistant_content_block(content, true, cx);
831 }
832 acp::SessionUpdate::ToolCall(tool_call) => {
833 self.upsert_tool_call(tool_call, cx)?;
834 }
835 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
836 self.update_tool_call(tool_call_update, cx)?;
837 }
838 acp::SessionUpdate::Plan(plan) => {
839 self.update_plan(plan, cx);
840 }
841 }
842 Ok(())
843 }
844
845 pub fn push_user_content_block(
846 &mut self,
847 message_id: Option<UserMessageId>,
848 chunk: acp::ContentBlock,
849 cx: &mut Context<Self>,
850 ) {
851 let language_registry = self.project.read(cx).languages().clone();
852 let entries_len = self.entries.len();
853
854 if let Some(last_entry) = self.entries.last_mut()
855 && let AgentThreadEntry::UserMessage(UserMessage {
856 id,
857 content,
858 chunks,
859 ..
860 }) = last_entry
861 {
862 *id = message_id.or(id.take());
863 content.append(chunk.clone(), &language_registry, cx);
864 chunks.push(chunk);
865 let idx = entries_len - 1;
866 cx.emit(AcpThreadEvent::EntryUpdated(idx));
867 } else {
868 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
869 self.push_entry(
870 AgentThreadEntry::UserMessage(UserMessage {
871 id: message_id,
872 content,
873 chunks: vec![chunk],
874 checkpoint: None,
875 }),
876 cx,
877 );
878 }
879 }
880
881 pub fn push_assistant_content_block(
882 &mut self,
883 chunk: acp::ContentBlock,
884 is_thought: bool,
885 cx: &mut Context<Self>,
886 ) {
887 let language_registry = self.project.read(cx).languages().clone();
888 let entries_len = self.entries.len();
889 if let Some(last_entry) = self.entries.last_mut()
890 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
891 {
892 let idx = entries_len - 1;
893 cx.emit(AcpThreadEvent::EntryUpdated(idx));
894 match (chunks.last_mut(), is_thought) {
895 (Some(AssistantMessageChunk::Message { block }), false)
896 | (Some(AssistantMessageChunk::Thought { block }), true) => {
897 block.append(chunk, &language_registry, cx)
898 }
899 _ => {
900 let block = ContentBlock::new(chunk, &language_registry, cx);
901 if is_thought {
902 chunks.push(AssistantMessageChunk::Thought { block })
903 } else {
904 chunks.push(AssistantMessageChunk::Message { block })
905 }
906 }
907 }
908 } else {
909 let block = ContentBlock::new(chunk, &language_registry, cx);
910 let chunk = if is_thought {
911 AssistantMessageChunk::Thought { block }
912 } else {
913 AssistantMessageChunk::Message { block }
914 };
915
916 self.push_entry(
917 AgentThreadEntry::AssistantMessage(AssistantMessage {
918 chunks: vec![chunk],
919 }),
920 cx,
921 );
922 }
923 }
924
    /// Appends `entry` to the thread and emits [`AcpThreadEvent::NewEntry`].
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }

    /// Forwards retry progress to listeners; the thread itself stores nothing.
    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::Retry(status));
    }
933
934 pub fn update_tool_call(
935 &mut self,
936 update: impl Into<ToolCallUpdate>,
937 cx: &mut Context<Self>,
938 ) -> Result<()> {
939 let update = update.into();
940 let languages = self.project.read(cx).languages().clone();
941
942 let (ix, current_call) = self
943 .tool_call_mut(update.id())
944 .context("Tool call not found")?;
945 match update {
946 ToolCallUpdate::UpdateFields(update) => {
947 let location_updated = update.fields.locations.is_some();
948 current_call.update_fields(update.fields, languages, cx);
949 if location_updated {
950 self.resolve_locations(update.id.clone(), cx);
951 }
952 }
953 ToolCallUpdate::UpdateDiff(update) => {
954 current_call.content.clear();
955 current_call
956 .content
957 .push(ToolCallContent::Diff(update.diff));
958 }
959 ToolCallUpdate::UpdateTerminal(update) => {
960 current_call.content.clear();
961 current_call
962 .content
963 .push(ToolCallContent::Terminal(update.terminal));
964 }
965 }
966
967 cx.emit(AcpThreadEvent::EntryUpdated(ix));
968
969 Ok(())
970 }
971
972 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
973 pub fn upsert_tool_call(
974 &mut self,
975 tool_call: acp::ToolCall,
976 cx: &mut Context<Self>,
977 ) -> Result<(), acp::Error> {
978 let status = tool_call.status.into();
979 self.upsert_tool_call_inner(tool_call.into(), status, cx)
980 }
981
982 /// Fails if id does not match an existing entry.
983 pub fn upsert_tool_call_inner(
984 &mut self,
985 tool_call_update: acp::ToolCallUpdate,
986 status: ToolCallStatus,
987 cx: &mut Context<Self>,
988 ) -> Result<(), acp::Error> {
989 let language_registry = self.project.read(cx).languages().clone();
990 let id = tool_call_update.id.clone();
991
992 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
993 current_call.update_fields(tool_call_update.fields, language_registry, cx);
994 current_call.status = status;
995
996 cx.emit(AcpThreadEvent::EntryUpdated(ix));
997 } else {
998 let call =
999 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1000 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1001 };
1002
1003 self.resolve_locations(id, cx);
1004 Ok(())
1005 }
1006
1007 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1008 // The tool call we are looking for is typically the last one, or very close to the end.
1009 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1010 self.entries
1011 .iter_mut()
1012 .enumerate()
1013 .rev()
1014 .find_map(|(index, tool_call)| {
1015 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1016 && &tool_call.id == id
1017 {
1018 Some((index, tool_call))
1019 } else {
1020 None
1021 }
1022 })
1023 }
1024
1025 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1026 let project = self.project.clone();
1027 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1028 return;
1029 };
1030 let task = tool_call.resolve_locations(project, cx);
1031 cx.spawn(async move |this, cx| {
1032 let resolved_locations = task.await;
1033 this.update(cx, |this, cx| {
1034 let project = this.project.clone();
1035 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1036 return;
1037 };
1038 if let Some(Some(location)) = resolved_locations.last() {
1039 project.update(cx, |project, cx| {
1040 if let Some(agent_location) = project.agent_location() {
1041 let should_ignore = agent_location.buffer == location.buffer
1042 && location
1043 .buffer
1044 .update(cx, |buffer, _| {
1045 let snapshot = buffer.snapshot();
1046 let old_position =
1047 agent_location.position.to_point(&snapshot);
1048 let new_position = location.position.to_point(&snapshot);
1049 // ignore this so that when we get updates from the edit tool
1050 // the position doesn't reset to the startof line
1051 old_position.row == new_position.row
1052 && old_position.column > new_position.column
1053 })
1054 .ok()
1055 .unwrap_or_default();
1056 if !should_ignore {
1057 project.set_agent_location(Some(location.clone()), cx);
1058 }
1059 }
1060 });
1061 }
1062 if tool_call.resolved_locations != resolved_locations {
1063 tool_call.resolved_locations = resolved_locations;
1064 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1065 }
1066 })
1067 })
1068 .detach();
1069 }
1070
1071 pub fn request_tool_call_authorization(
1072 &mut self,
1073 tool_call: acp::ToolCallUpdate,
1074 options: Vec<acp::PermissionOption>,
1075 cx: &mut Context<Self>,
1076 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1077 let (tx, rx) = oneshot::channel();
1078
1079 let status = ToolCallStatus::WaitingForConfirmation {
1080 options,
1081 respond_tx: tx,
1082 };
1083
1084 self.upsert_tool_call_inner(tool_call, status, cx)?;
1085 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1086 Ok(rx)
1087 }
1088
1089 pub fn authorize_tool_call(
1090 &mut self,
1091 id: acp::ToolCallId,
1092 option_id: acp::PermissionOptionId,
1093 option_kind: acp::PermissionOptionKind,
1094 cx: &mut Context<Self>,
1095 ) {
1096 let Some((ix, call)) = self.tool_call_mut(&id) else {
1097 return;
1098 };
1099
1100 let new_status = match option_kind {
1101 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1102 ToolCallStatus::Rejected
1103 }
1104 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1105 ToolCallStatus::InProgress
1106 }
1107 };
1108
1109 let curr_status = mem::replace(&mut call.status, new_status);
1110
1111 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1112 respond_tx.send(option_id).log_err();
1113 } else if cfg!(debug_assertions) {
1114 panic!("tried to authorize an already authorized tool call");
1115 }
1116
1117 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1118 }
1119
1120 /// Returns true if the last turn is awaiting tool authorization
1121 pub fn waiting_for_tool_confirmation(&self) -> bool {
1122 for entry in self.entries.iter().rev() {
1123 match &entry {
1124 AgentThreadEntry::ToolCall(call) => match call.status {
1125 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1126 ToolCallStatus::Pending
1127 | ToolCallStatus::InProgress
1128 | ToolCallStatus::Completed
1129 | ToolCallStatus::Failed
1130 | ToolCallStatus::Rejected
1131 | ToolCallStatus::Canceled => continue,
1132 },
1133 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1134 // Reached the beginning of the turn
1135 return false;
1136 }
1137 }
1138 }
1139 false
1140 }
1141
    /// The agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1145
1146 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1147 let new_entries_len = request.entries.len();
1148 let mut new_entries = request.entries.into_iter();
1149
1150 // Reuse existing markdown to prevent flickering
1151 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1152 let PlanEntry {
1153 content,
1154 priority,
1155 status,
1156 } = old;
1157 content.update(cx, |old, cx| {
1158 old.replace(new.content, cx);
1159 });
1160 *priority = new.priority;
1161 *status = new.status;
1162 }
1163 for new in new_entries {
1164 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1165 }
1166 self.plan.entries.truncate(new_entries_len);
1167
1168 cx.notify();
1169 }
1170
1171 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1172 self.plan
1173 .entries
1174 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1175 cx.notify();
1176 }
1177
1178 #[cfg(any(test, feature = "test-support"))]
1179 pub fn send_raw(
1180 &mut self,
1181 message: &str,
1182 cx: &mut Context<Self>,
1183 ) -> BoxFuture<'static, Result<()>> {
1184 self.send(
1185 vec![acp::ContentBlock::Text(acp::TextContent {
1186 text: message.to_string(),
1187 annotations: None,
1188 })],
1189 cx,
1190 )
1191 }
1192
1193 pub fn send(
1194 &mut self,
1195 message: Vec<acp::ContentBlock>,
1196 cx: &mut Context<Self>,
1197 ) -> BoxFuture<'static, Result<()>> {
1198 let block = ContentBlock::new_combined(
1199 message.clone(),
1200 self.project.read(cx).languages().clone(),
1201 cx,
1202 );
1203 let request = acp::PromptRequest {
1204 prompt: message.clone(),
1205 session_id: self.session_id.clone(),
1206 };
1207 let git_store = self.project.read(cx).git_store().clone();
1208
1209 let message_id = if self
1210 .connection
1211 .session_editor(&self.session_id, cx)
1212 .is_some()
1213 {
1214 Some(UserMessageId::new())
1215 } else {
1216 None
1217 };
1218
1219 self.run_turn(cx, async move |this, cx| {
1220 this.update(cx, |this, cx| {
1221 this.push_entry(
1222 AgentThreadEntry::UserMessage(UserMessage {
1223 id: message_id.clone(),
1224 content: block,
1225 chunks: message,
1226 checkpoint: None,
1227 }),
1228 cx,
1229 );
1230 })
1231 .ok();
1232
1233 let old_checkpoint = git_store
1234 .update(cx, |git, cx| git.checkpoint(cx))?
1235 .await
1236 .context("failed to get old checkpoint")
1237 .log_err();
1238 this.update(cx, |this, cx| {
1239 if let Some((_ix, message)) = this.last_user_message() {
1240 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1241 git_checkpoint,
1242 show: false,
1243 });
1244 }
1245 this.connection.prompt(message_id, request, cx)
1246 })?
1247 .await
1248 })
1249 }
1250
1251 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1252 self.run_turn(cx, async move |this, cx| {
1253 this.update(cx, |this, cx| {
1254 this.connection
1255 .resume(&this.session_id, cx)
1256 .map(|resume| resume.run(cx))
1257 })?
1258 .context("resuming a session is not supported")?
1259 .await
1260 })
1261 }
1262
1263 fn run_turn(
1264 &mut self,
1265 cx: &mut Context<Self>,
1266 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1267 ) -> BoxFuture<'static, Result<()>> {
1268 self.clear_completed_plan_entries(cx);
1269
1270 let (tx, rx) = oneshot::channel();
1271 let cancel_task = self.cancel(cx);
1272
1273 self.send_task = Some(cx.spawn(async move |this, cx| {
1274 cancel_task.await;
1275 tx.send(f(this, cx).await).ok();
1276 }));
1277
1278 cx.spawn(async move |this, cx| {
1279 let response = rx.await;
1280
1281 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1282 .await?;
1283
1284 this.update(cx, |this, cx| {
1285 match response {
1286 Ok(Err(e)) => {
1287 this.send_task.take();
1288 cx.emit(AcpThreadEvent::Error);
1289 Err(e)
1290 }
1291 result => {
1292 let canceled = matches!(
1293 result,
1294 Ok(Ok(acp::PromptResponse {
1295 stop_reason: acp::StopReason::Canceled
1296 }))
1297 );
1298
1299 // We only take the task if the current prompt wasn't canceled.
1300 //
1301 // This prompt may have been canceled because another one was sent
1302 // while it was still generating. In these cases, dropping `send_task`
1303 // would cause the next generation to be canceled.
1304 if !canceled {
1305 this.send_task.take();
1306 }
1307
1308 cx.emit(AcpThreadEvent::Stopped);
1309 Ok(())
1310 }
1311 }
1312 })?
1313 })
1314 .boxed()
1315 }
1316
1317 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1318 let Some(send_task) = self.send_task.take() else {
1319 return Task::ready(());
1320 };
1321
1322 for entry in self.entries.iter_mut() {
1323 if let AgentThreadEntry::ToolCall(call) = entry {
1324 let cancel = matches!(
1325 call.status,
1326 ToolCallStatus::Pending
1327 | ToolCallStatus::WaitingForConfirmation { .. }
1328 | ToolCallStatus::InProgress
1329 );
1330
1331 if cancel {
1332 call.status = ToolCallStatus::Canceled;
1333 }
1334 }
1335 }
1336
1337 self.connection.cancel(&self.session_id, cx);
1338
1339 // Wait for the send task to complete
1340 cx.foreground_executor().spawn(send_task)
1341 }
1342
1343 /// Rewinds this thread to before the entry at `index`, removing it and all
1344 /// subsequent entries while reverting any changes made from that point.
1345 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1346 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1347 return Task::ready(Err(anyhow!("not supported")));
1348 };
1349 let Some(message) = self.user_message(&id) else {
1350 return Task::ready(Err(anyhow!("message not found")));
1351 };
1352
1353 let checkpoint = message
1354 .checkpoint
1355 .as_ref()
1356 .map(|c| c.git_checkpoint.clone());
1357
1358 let git_store = self.project.read(cx).git_store().clone();
1359 cx.spawn(async move |this, cx| {
1360 if let Some(checkpoint) = checkpoint {
1361 git_store
1362 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1363 .await?;
1364 }
1365
1366 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1367 .await?;
1368 this.update(cx, |this, cx| {
1369 if let Some((ix, _)) = this.user_message_mut(&id) {
1370 let range = ix..this.entries.len();
1371 this.entries.truncate(ix);
1372 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1373 }
1374 })
1375 })
1376 }
1377
1378 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1379 let git_store = self.project.read(cx).git_store().clone();
1380
1381 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1382 if let Some(checkpoint) = message.checkpoint.as_ref() {
1383 checkpoint.git_checkpoint.clone()
1384 } else {
1385 return Task::ready(Ok(()));
1386 }
1387 } else {
1388 return Task::ready(Ok(()));
1389 };
1390
1391 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1392 cx.spawn(async move |this, cx| {
1393 let new_checkpoint = new_checkpoint
1394 .await
1395 .context("failed to get new checkpoint")
1396 .log_err();
1397 if let Some(new_checkpoint) = new_checkpoint {
1398 let equal = git_store
1399 .update(cx, |git, cx| {
1400 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1401 })?
1402 .await
1403 .unwrap_or(true);
1404 this.update(cx, |this, cx| {
1405 let (ix, message) = this.last_user_message().context("no user message")?;
1406 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1407 checkpoint.show = !equal;
1408 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1409 anyhow::Ok(())
1410 })??;
1411 }
1412
1413 Ok(())
1414 })
1415 }
1416
1417 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1418 self.entries
1419 .iter_mut()
1420 .enumerate()
1421 .rev()
1422 .find_map(|(ix, entry)| {
1423 if let AgentThreadEntry::UserMessage(message) = entry {
1424 Some((ix, message))
1425 } else {
1426 None
1427 }
1428 })
1429 }
1430
1431 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1432 self.entries.iter().find_map(|entry| {
1433 if let AgentThreadEntry::UserMessage(message) = entry {
1434 if message.id.as_ref() == Some(id) {
1435 Some(message)
1436 } else {
1437 None
1438 }
1439 } else {
1440 None
1441 }
1442 })
1443 }
1444
1445 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1446 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1447 if let AgentThreadEntry::UserMessage(message) = entry {
1448 if message.id.as_ref() == Some(id) {
1449 Some((ix, message))
1450 } else {
1451 None
1452 }
1453 } else {
1454 None
1455 }
1456 })
1457 }
1458
1459 pub fn read_text_file(
1460 &self,
1461 path: PathBuf,
1462 line: Option<u32>,
1463 limit: Option<u32>,
1464 reuse_shared_snapshot: bool,
1465 cx: &mut Context<Self>,
1466 ) -> Task<Result<String>> {
1467 let project = self.project.clone();
1468 let action_log = self.action_log.clone();
1469 cx.spawn(async move |this, cx| {
1470 let load = project.update(cx, |project, cx| {
1471 let path = project
1472 .project_path_for_absolute_path(&path, cx)
1473 .context("invalid path")?;
1474 anyhow::Ok(project.open_buffer(path, cx))
1475 });
1476 let buffer = load??.await?;
1477
1478 let snapshot = if reuse_shared_snapshot {
1479 this.read_with(cx, |this, _| {
1480 this.shared_buffers.get(&buffer.clone()).cloned()
1481 })
1482 .log_err()
1483 .flatten()
1484 } else {
1485 None
1486 };
1487
1488 let snapshot = if let Some(snapshot) = snapshot {
1489 snapshot
1490 } else {
1491 action_log.update(cx, |action_log, cx| {
1492 action_log.buffer_read(buffer.clone(), cx);
1493 })?;
1494 project.update(cx, |project, cx| {
1495 let position = buffer
1496 .read(cx)
1497 .snapshot()
1498 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1499 project.set_agent_location(
1500 Some(AgentLocation {
1501 buffer: buffer.downgrade(),
1502 position,
1503 }),
1504 cx,
1505 );
1506 })?;
1507
1508 buffer.update(cx, |buffer, _| buffer.snapshot())?
1509 };
1510
1511 this.update(cx, |this, _| {
1512 let text = snapshot.text();
1513 this.shared_buffers.insert(buffer.clone(), snapshot);
1514 if line.is_none() && limit.is_none() {
1515 return Ok(text);
1516 }
1517 let limit = limit.unwrap_or(u32::MAX) as usize;
1518 let Some(line) = line else {
1519 return Ok(text.lines().take(limit).collect::<String>());
1520 };
1521
1522 let count = text.lines().count();
1523 if count < line as usize {
1524 anyhow::bail!("There are only {} lines", count);
1525 }
1526 Ok(text
1527 .lines()
1528 .skip(line as usize + 1)
1529 .take(limit)
1530 .collect::<String>())
1531 })?
1532 })
1533 }
1534
    /// Replaces the contents of the file at the absolute `path` with `content`
    /// on behalf of the agent, then saves it.
    ///
    /// The new content is diffed against the snapshot stored in `shared_buffers`
    /// for this buffer (the one last handed to the agent), falling back to the
    /// live buffer when there is none. The resulting edits are anchored in that
    /// snapshot and applied to the live buffer.
    ///
    /// Side effects:
    /// - the project's agent location is moved to the end of the last edit;
    /// - the action log records a read before the edit and an edit after it.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent actually saw; edits made to the buffer
            // since then (e.g. by the user) are then not part of the diff.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread and turn offset ranges into
            // anchors so they still resolve after concurrent edits.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record the pre-edit state, apply the edits, then record them.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1602
1603 pub fn to_markdown(&self, cx: &App) -> String {
1604 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1605 }
1606
    /// Notifies subscribers that the agent server process exited with `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1610}
1611
1612fn markdown_for_raw_output(
1613 raw_output: &serde_json::Value,
1614 language_registry: &Arc<LanguageRegistry>,
1615 cx: &mut App,
1616) -> Option<Entity<Markdown>> {
1617 match raw_output {
1618 serde_json::Value::Null => None,
1619 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1620 Markdown::new(
1621 value.to_string().into(),
1622 Some(language_registry.clone()),
1623 None,
1624 cx,
1625 )
1626 })),
1627 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1628 Markdown::new(
1629 value.to_string().into(),
1630 Some(language_registry.clone()),
1631 None,
1632 cx,
1633 )
1634 })),
1635 serde_json::Value::String(value) => Some(cx.new(|cx| {
1636 Markdown::new(
1637 value.clone().into(),
1638 Some(language_registry.clone()),
1639 None,
1640 cx,
1641 )
1642 })),
1643 value => Some(cx.new(|cx| {
1644 Markdown::new(
1645 format!("```json\n{}\n```", value).into(),
1646 Some(language_registry.clone()),
1647 None,
1648 cx,
1649 )
1650 })),
1651 }
1652}
1653
1654#[cfg(test)]
1655mod tests {
1656 use super::*;
1657 use anyhow::anyhow;
1658 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1659 use gpui::{AsyncApp, TestAppContext, WeakEntity};
1660 use indoc::indoc;
1661 use project::{FakeFs, Fs};
1662 use rand::Rng as _;
1663 use serde_json::json;
1664 use settings::SettingsStore;
1665 use smol::stream::StreamExt as _;
1666 use std::{
1667 any::Any,
1668 cell::RefCell,
1669 path::Path,
1670 rc::Rc,
1671 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1672 time::Duration,
1673 };
1674 use util::path;
1675
    /// Shared setup for every test.
    ///
    /// - Initializes logging; `try_init` errors are ignored so repeated calls
    ///   are harmless.
    /// - Installs a test `SettingsStore`.
    /// - Registers project settings.
    /// - Initializes the language subsystem.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1685
    /// `push_user_content_block` behaves as follows:
    /// - it appends to the trailing user message and can assign its id;
    /// - after an assistant message, it starts a new user message instead.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1777
    /// Consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` section of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1840
    /// The agent writes a file after reading it, while the user has inserted a
    /// line in the meantime. The agent's write must be applied on top of the
    /// user's edit, and the saved file must contain both.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has read the file, so the user
        // edit happens between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1919
    /// When the agent keeps reporting a tool call after it was canceled locally,
    /// a later `Completed` update still overrides the `Canceled` status.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2019
    /// An edit tool call that arrives already `Completed` must not leave the
    /// thread reporting pending edit tool calls once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2073
    /// Exercises git checkpoints across several turns.
    ///
    /// - A turn that changes files marks its user message "(checkpoint)".
    /// - A turn that changes nothing does not.
    /// - Rewinding to a message truncates the thread and restores the files as
    ///   they were before that message.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // When set, each turn creates a new `/test/file-N` to simulate agent edits.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    // Echo the prompt back in upper case as the assistant reply.
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                ## User (checkpoint)

                ipsum

                ## Assistant

                IPSUM

                ## User

                dolor

                ## Assistant

                DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                ## User (checkpoint)

                Lorem

                ## Assistant

                LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2247
    /// Waits until `thread` emits an event while it contains a tool call entry,
    /// and returns the index of the first such entry.
    ///
    /// Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        return tx.try_send(ix).unwrap();
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                // Stop listening once we have our answer.
                drop(subscription);
                ix.unwrap()
            }
        }
    }
2274
    /// In-memory `AgentConnection` used by the tests in this module.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        // Methods accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        // Threads created by `new_thread`, keyed by their session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        // Optional handler invoked for every prompt; when absent, prompts end
        // the turn immediately with `EndTurn`.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2290
    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions and no prompt handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Builder: sets the auth methods `authenticate` accepts.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Builder: installs the handler that `prompt` runs for each request.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }
2319
2320 impl AgentConnection for FakeAgentConnection {
2321 fn auth_methods(&self) -> &[acp::AuthMethod] {
2322 &self.auth_methods
2323 }
2324
2325 fn new_thread(
2326 self: Rc<Self>,
2327 project: Entity<Project>,
2328 _cwd: &Path,
2329 cx: &mut gpui::App,
2330 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2331 let session_id = acp::SessionId(
2332 rand::thread_rng()
2333 .sample_iter(&rand::distributions::Alphanumeric)
2334 .take(7)
2335 .map(char::from)
2336 .collect::<String>()
2337 .into(),
2338 );
2339 let thread =
2340 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2341 self.sessions.lock().insert(session_id, thread.downgrade());
2342 Task::ready(Ok(thread))
2343 }
2344
2345 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2346 if self.auth_methods().iter().any(|m| m.id == method) {
2347 Task::ready(Ok(()))
2348 } else {
2349 Task::ready(Err(anyhow!("Invalid Auth Method")))
2350 }
2351 }
2352
2353 fn prompt(
2354 &self,
2355 _id: Option<UserMessageId>,
2356 params: acp::PromptRequest,
2357 cx: &mut App,
2358 ) -> Task<gpui::Result<acp::PromptResponse>> {
2359 let sessions = self.sessions.lock();
2360 let thread = sessions.get(¶ms.session_id).unwrap();
2361 if let Some(handler) = &self.on_user_message {
2362 let handler = handler.clone();
2363 let thread = thread.clone();
2364 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2365 } else {
2366 Task::ready(Ok(acp::PromptResponse {
2367 stop_reason: acp::StopReason::EndTurn,
2368 }))
2369 }
2370 }
2371
2372 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2373 let sessions = self.sessions.lock();
2374 let thread = sessions.get(session_id).unwrap().clone();
2375
2376 cx.spawn(async move |cx| {
2377 thread
2378 .update(cx, |thread, cx| thread.cancel(cx))
2379 .unwrap()
2380 .await
2381 })
2382 .detach();
2383 }
2384
2385 fn session_editor(
2386 &self,
2387 session_id: &acp::SessionId,
2388 _cx: &mut App,
2389 ) -> Option<Rc<dyn AgentSessionEditor>> {
2390 Some(Rc::new(FakeAgentSessionEditor {
2391 _session_id: session_id.clone(),
2392 }))
2393 }
2394
2395 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2396 self
2397 }
2398 }
2399
    /// Session editor handed out by `FakeAgentConnection::session_editor`.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }
2403
    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// No agent-side history to trim; always succeeds immediately.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2409}