1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9use serde::{Deserialize, Serialize};
10pub use terminal::*;
11
12use action_log::ActionLog;
13use agent_client_protocol as acp;
14use anyhow::{Context as _, Result, anyhow};
15use chrono::{DateTime, Utc};
16use editor::Bias;
17use futures::{FutureExt, channel::oneshot, future::BoxFuture};
18use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
19use itertools::Itertools;
20use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
21use markdown::Markdown;
22use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
23use std::collections::HashMap;
24use std::error::Error;
25use std::fmt::{Formatter, Write};
26use std::ops::Range;
27use std::process::ExitStatus;
28use std::rc::Rc;
29use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
30use ui::App;
31use util::ResultExt;
32
/// A message authored by the user within an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of this message, when one has been assigned.
    pub id: Option<UserMessageId>,
    /// Display form of the message, built by appending the chunks in order.
    pub content: ContentBlock,
    /// The raw protocol content blocks that make up the message.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
40
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The checkpoint value obtained from the project's git store.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the markdown export labels the message "(checkpoint)".
    pub show: bool,
}
46
47impl UserMessage {
48 fn to_markdown(&self, cx: &App) -> String {
49 let mut markdown = String::new();
50 if self
51 .checkpoint
52 .as_ref()
53 .map_or(false, |checkpoint| checkpoint.show)
54 {
55 writeln!(markdown, "## User (checkpoint)").unwrap();
56 } else {
57 writeln!(markdown, "## User").unwrap();
58 }
59 writeln!(markdown).unwrap();
60 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
61 writeln!(markdown).unwrap();
62 markdown
63 }
64}
65
/// A message produced by the agent, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in the order they were received.
    pub chunks: Vec<AssistantMessageChunk>,
}
70
71impl AssistantMessage {
72 pub fn to_markdown(&self, cx: &App) -> String {
73 format!(
74 "## Assistant\n\n{}\n\n",
75 self.chunks
76 .iter()
77 .map(|chunk| chunk.to_markdown(cx))
78 .join("\n\n")
79 )
80 }
81}
82
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered inside `<thinking>` tags in markdown exports.
    Thought { block: ContentBlock },
}
88
89impl AssistantMessageChunk {
90 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
91 Self::Message {
92 block: ContentBlock::new(chunk.into(), language_registry, cx),
93 }
94 }
95
96 fn to_markdown(&self, cx: &App) -> String {
97 match self {
98 Self::Message { block } => block.to_markdown(cx).to_string(),
99 Self::Thought { block } => {
100 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
101 }
102 }
103 }
104}
105
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
112
113impl AgentThreadEntry {
114 pub fn to_markdown(&self, cx: &App) -> String {
115 match self {
116 Self::UserMessage(message) => message.to_markdown(cx),
117 Self::AssistantMessage(message) => message.to_markdown(cx),
118 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
119 }
120 }
121
122 pub fn user_message(&self) -> Option<&UserMessage> {
123 if let AgentThreadEntry::UserMessage(message) = self {
124 Some(message)
125 } else {
126 None
127 }
128 }
129
130 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
131 if let AgentThreadEntry::ToolCall(call) = self {
132 itertools::Either::Left(call.diffs())
133 } else {
134 itertools::Either::Right(std::iter::empty())
135 }
136 }
137
138 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.terminals())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
147 if let AgentThreadEntry::ToolCall(ToolCall {
148 locations,
149 resolved_locations,
150 ..
151 }) = self
152 {
153 Some((
154 locations.get(ix)?.clone(),
155 resolved_locations.get(ix)?.clone()?,
156 ))
157 } else {
158 None
159 }
160 }
161}
162
/// A tool invocation tracked in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier used to look the call up for later updates.
    pub id: acp::ToolCallId,
    /// Markdown entity holding the tool call's title.
    pub label: Entity<Markdown>,
    /// The kind of tool, as reported over the protocol.
    pub kind: acp::ToolKind,
    /// Content attached to the call: content blocks, diffs, or terminals.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle state of the call.
    pub status: ToolCallStatus,
    /// Locations reported for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Result of resolving `locations`; an entry is `None` when resolution failed.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
175
176impl ToolCall {
177 fn from_acp(
178 tool_call: acp::ToolCall,
179 status: ToolCallStatus,
180 language_registry: Arc<LanguageRegistry>,
181 cx: &mut App,
182 ) -> Self {
183 Self {
184 id: tool_call.id,
185 label: cx.new(|cx| {
186 Markdown::new(
187 tool_call.title.into(),
188 Some(language_registry.clone()),
189 None,
190 cx,
191 )
192 }),
193 kind: tool_call.kind,
194 content: tool_call
195 .content
196 .into_iter()
197 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
198 .collect(),
199 locations: tool_call.locations,
200 resolved_locations: Vec::default(),
201 status,
202 raw_input: tool_call.raw_input,
203 raw_output: tool_call.raw_output,
204 }
205 }
206
207 fn update_fields(
208 &mut self,
209 fields: acp::ToolCallUpdateFields,
210 language_registry: Arc<LanguageRegistry>,
211 cx: &mut App,
212 ) {
213 let acp::ToolCallUpdateFields {
214 kind,
215 status,
216 title,
217 content,
218 locations,
219 raw_input,
220 raw_output,
221 } = fields;
222
223 if let Some(kind) = kind {
224 self.kind = kind;
225 }
226
227 if let Some(status) = status {
228 self.status = status.into();
229 }
230
231 if let Some(title) = title {
232 self.label.update(cx, |label, cx| {
233 label.replace(title, cx);
234 });
235 }
236
237 if let Some(content) = content {
238 self.content = content
239 .into_iter()
240 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
241 .collect();
242 }
243
244 if let Some(locations) = locations {
245 self.locations = locations;
246 }
247
248 if let Some(raw_input) = raw_input {
249 self.raw_input = Some(raw_input);
250 }
251
252 if let Some(raw_output) = raw_output {
253 if self.content.is_empty() {
254 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
255 {
256 self.content
257 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
258 markdown,
259 }));
260 }
261 }
262 self.raw_output = Some(raw_output);
263 }
264 }
265
266 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
267 self.content.iter().filter_map(|content| match content {
268 ToolCallContent::Diff(diff) => Some(diff),
269 ToolCallContent::ContentBlock(_) => None,
270 ToolCallContent::Terminal(_) => None,
271 })
272 }
273
274 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
275 self.content.iter().filter_map(|content| match content {
276 ToolCallContent::Terminal(terminal) => Some(terminal),
277 ToolCallContent::ContentBlock(_) => None,
278 ToolCallContent::Diff(_) => None,
279 })
280 }
281
282 fn to_markdown(&self, cx: &App) -> String {
283 let mut markdown = format!(
284 "**Tool Call: {}**\nStatus: {}\n\n",
285 self.label.read(cx).source(),
286 self.status
287 );
288 for content in &self.content {
289 markdown.push_str(content.to_markdown(cx).as_str());
290 markdown.push_str("\n\n");
291 }
292 markdown
293 }
294
295 async fn resolve_location(
296 location: acp::ToolCallLocation,
297 project: WeakEntity<Project>,
298 cx: &mut AsyncApp,
299 ) -> Option<AgentLocation> {
300 let buffer = project
301 .update(cx, |project, cx| {
302 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
303 Some(project.open_buffer(path, cx))
304 } else {
305 None
306 }
307 })
308 .ok()??;
309 let buffer = buffer.await.log_err()?;
310 let position = buffer
311 .update(cx, |buffer, _| {
312 if let Some(row) = location.line {
313 let snapshot = buffer.snapshot();
314 let column = snapshot.indent_size_for_line(row).len;
315 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
316 snapshot.anchor_before(point)
317 } else {
318 Anchor::MIN
319 }
320 })
321 .ok()?;
322
323 Some(AgentLocation {
324 buffer: buffer.downgrade(),
325 position,
326 })
327 }
328
329 fn resolve_locations(
330 &self,
331 project: Entity<Project>,
332 cx: &mut App,
333 ) -> Task<Vec<Option<AgentLocation>>> {
334 let locations = self.locations.clone();
335 project.update(cx, |_, cx| {
336 cx.spawn(async move |project, cx| {
337 let mut new_locations = Vec::new();
338 for location in locations {
339 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
340 }
341 new_locations
342 })
343 })
344 }
345}
346
/// Lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options the user can choose from.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the chosen option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
368
369impl From<acp::ToolCallStatus> for ToolCallStatus {
370 fn from(status: acp::ToolCallStatus) -> Self {
371 match status {
372 acp::ToolCallStatus::Pending => Self::Pending,
373 acp::ToolCallStatus::InProgress => Self::InProgress,
374 acp::ToolCallStatus::Completed => Self::Completed,
375 acp::ToolCallStatus::Failed => Self::Failed,
376 }
377 }
378}
379
380impl Display for ToolCallStatus {
381 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
382 write!(
383 f,
384 "{}",
385 match self {
386 ToolCallStatus::Pending => "Pending",
387 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
388 ToolCallStatus::InProgress => "In Progress",
389 ToolCallStatus::Completed => "Completed",
390 ToolCallStatus::Failed => "Failed",
391 ToolCallStatus::Rejected => "Rejected",
392 ToolCallStatus::Canceled => "Canceled",
393 }
394 )
395 }
396}
397
/// Displayable content accumulated from protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content rendered through a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; converted to markdown once more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
404
405impl ContentBlock {
406 pub fn new(
407 block: acp::ContentBlock,
408 language_registry: &Arc<LanguageRegistry>,
409 cx: &mut App,
410 ) -> Self {
411 let mut this = Self::Empty;
412 this.append(block, language_registry, cx);
413 this
414 }
415
416 pub fn new_combined(
417 blocks: impl IntoIterator<Item = acp::ContentBlock>,
418 language_registry: Arc<LanguageRegistry>,
419 cx: &mut App,
420 ) -> Self {
421 let mut this = Self::Empty;
422 for block in blocks {
423 this.append(block, &language_registry, cx);
424 }
425 this
426 }
427
428 pub fn append(
429 &mut self,
430 block: acp::ContentBlock,
431 language_registry: &Arc<LanguageRegistry>,
432 cx: &mut App,
433 ) {
434 if matches!(self, ContentBlock::Empty) {
435 if let acp::ContentBlock::ResourceLink(resource_link) = block {
436 *self = ContentBlock::ResourceLink { resource_link };
437 return;
438 }
439 }
440
441 let new_content = self.block_string_contents(block);
442
443 match self {
444 ContentBlock::Empty => {
445 *self = Self::create_markdown_block(new_content, language_registry, cx);
446 }
447 ContentBlock::Markdown { markdown } => {
448 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
449 }
450 ContentBlock::ResourceLink { resource_link } => {
451 let existing_content = Self::resource_link_md(&resource_link.uri);
452 let combined = format!("{}\n{}", existing_content, new_content);
453
454 *self = Self::create_markdown_block(combined, language_registry, cx);
455 }
456 }
457 }
458
459 fn create_markdown_block(
460 content: String,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) -> ContentBlock {
464 ContentBlock::Markdown {
465 markdown: cx
466 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
467 }
468 }
469
470 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
471 match block {
472 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
473 acp::ContentBlock::ResourceLink(resource_link) => {
474 Self::resource_link_md(&resource_link.uri)
475 }
476 acp::ContentBlock::Resource(acp::EmbeddedResource {
477 resource:
478 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
479 uri,
480 ..
481 }),
482 ..
483 }) => Self::resource_link_md(&uri),
484 acp::ContentBlock::Image(image) => Self::image_md(&image),
485 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
486 }
487 }
488
489 fn resource_link_md(uri: &str) -> String {
490 if let Some(uri) = MentionUri::parse(&uri).log_err() {
491 uri.as_link().to_string()
492 } else {
493 uri.to_string()
494 }
495 }
496
497 fn image_md(_image: &acp::ImageContent) -> String {
498 "`Image`".into()
499 }
500
501 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
502 match self {
503 ContentBlock::Empty => "",
504 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
505 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
506 }
507 }
508
509 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
510 match self {
511 ContentBlock::Empty => None,
512 ContentBlock::Markdown { markdown } => Some(markdown),
513 ContentBlock::ResourceLink { .. } => None,
514 }
515 }
516
517 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
518 match self {
519 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
520 _ => None,
521 }
522 }
523}
524
/// A piece of content attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text-like content (markdown or a resource link).
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
531
532impl ToolCallContent {
533 pub fn from_acp(
534 content: acp::ToolCallContent,
535 language_registry: Arc<LanguageRegistry>,
536 cx: &mut App,
537 ) -> Self {
538 match content {
539 acp::ToolCallContent::Content { content } => {
540 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
541 }
542 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
543 Diff::finalized(
544 diff.path,
545 diff.old_text,
546 diff.new_text,
547 language_registry,
548 cx,
549 )
550 })),
551 }
552 }
553
554 pub fn to_markdown(&self, cx: &App) -> String {
555 match self {
556 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
557 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
558 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
559 }
560 }
561}
562
/// An update to apply to an existing [`ToolCall`].
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A partial field update received over the protocol.
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the call's content with a diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the call's content with a terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
569
570impl ToolCallUpdate {
571 fn id(&self) -> &acp::ToolCallId {
572 match self {
573 Self::UpdateFields(update) => &update.id,
574 Self::UpdateDiff(diff) => &diff.id,
575 Self::UpdateTerminal(terminal) => &terminal.id,
576 }
577 }
578}
579
580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
581 fn from(update: acp::ToolCallUpdate) -> Self {
582 Self::UpdateFields(update)
583 }
584}
585
586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
587 fn from(diff: ToolCallUpdateDiff) -> Self {
588 Self::UpdateDiff(diff)
589 }
590}
591
/// Update that sets a tool call's content to a single diff.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The diff entity that becomes the call's content.
    pub diff: Entity<Diff>,
}
597
598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
599 fn from(terminal: ToolCallUpdateTerminal) -> Self {
600 Self::UpdateTerminal(terminal)
601 }
602}
603
/// Update that sets a tool call's content to a single terminal.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal entity that becomes the call's content.
    pub terminal: Entity<Terminal>,
}
609
/// The agent's current plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries, in the order reported by the agent.
    pub entries: Vec<PlanEntry>,
}
614
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
621
622impl Plan {
623 pub fn is_empty(&self) -> bool {
624 self.entries.is_empty()
625 }
626
627 pub fn stats(&self) -> PlanStats<'_> {
628 let mut stats = PlanStats {
629 in_progress_entry: None,
630 pending: 0,
631 completed: 0,
632 };
633
634 for entry in &self.entries {
635 match &entry.status {
636 acp::PlanEntryStatus::Pending => {
637 stats.pending += 1;
638 }
639 acp::PlanEntryStatus::InProgress => {
640 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
641 }
642 acp::PlanEntryStatus::Completed => {
643 stats.completed += 1;
644 }
645 }
646 }
647
648 stats
649 }
650}
651
/// A single step of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown entity holding the entry's text.
    pub content: Entity<Markdown>,
    /// Priority reported by the agent.
    pub priority: acp::PlanEntryPriority,
    /// Progress status reported by the agent.
    pub status: acp::PlanEntryStatus,
}
658
659impl PlanEntry {
660 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
661 Self {
662 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
663 priority: entry.priority,
664 status: entry.status,
665 }
666 }
667}
668
/// Name of an agent server, wrapped in a newtype so it can be used as a
/// hashable, serializable key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AgentServerName(pub SharedString);
671
/// Serializable summary information about a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcpThreadMetadata {
    /// Name of the agent server the thread belongs to.
    pub agent: AgentServerName,
    /// Session identifier of the thread.
    pub id: acp::SessionId,
    /// Title of the thread.
    pub title: SharedString,
    /// Time of the last update, in UTC.
    pub updated_at: DateTime<Utc>,
}
679
/// State of a conversation with an agent over a single session.
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// User messages, assistant messages, and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log created for `project` when the thread is constructed.
    action_log: Entity<ActionLog>,
    /// Buffers paired with a snapshot of each.
    // NOTE(review): how these are populated is not visible in this part of the file.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Task of the turn in progress; `None` means the thread is idle.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Identifier of the agent session.
    session_id: acp::SessionId,
}
691
/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// Carries a range of entry indices.
    // NOTE(review): presumably the removed range; the emitter is outside this view.
    EntriesRemoved(Range<usize>),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    // NOTE(review): the emitters of the variants below are outside this view.
    Stopped,
    Error,
    ServerExited(ExitStatus),
}
703
// Marker impl: declares `AcpThreadEvent` as the event type emitted by `AcpThread`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
705
/// Coarse activity state of an [`AcpThread`].
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is in progress.
    Idle,
    /// A turn is in progress and a tool call awaits user confirmation.
    WaitingForToolConfirmation,
    /// A turn is in progress and nothing awaits confirmation.
    Generating,
}
712
/// Error describing why an agent could not be loaded.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The agent is unsupported; carries a message plus upgrade information.
    Unsupported {
        /// Message shown when the error is displayed.
        error_message: SharedString,
        /// Message describing the upgrade.
        upgrade_message: SharedString,
        /// Command associated with the upgrade.
        upgrade_command: String,
    },
    /// The server exited with the given status code.
    Exited(i32),
    /// Any other failure, described by its message.
    Other(SharedString),
}
723
724impl Display for LoadError {
725 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
726 match self {
727 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
728 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
729 LoadError::Other(msg) => write!(f, "{}", msg),
730 }
731 }
732}
733
// Empty impl: `LoadError` relies on the default `Error` methods.
impl Error for LoadError {}
735
736impl AcpThread {
737 pub fn new(
738 title: impl Into<SharedString>,
739 connection: Rc<dyn AgentConnection>,
740 project: Entity<Project>,
741 session_id: acp::SessionId,
742 cx: &mut Context<Self>,
743 ) -> Self {
744 let action_log = cx.new(|_| ActionLog::new(project.clone()));
745
746 Self {
747 action_log,
748 shared_buffers: Default::default(),
749 entries: Default::default(),
750 plan: Default::default(),
751 title: title.into(),
752 project,
753 send_task: None,
754 connection,
755 session_id,
756 }
757 }
758
    /// Returns the connection to the agent.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
762
    /// Returns the thread's action log entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
766
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
770
    /// Returns a clone of the thread title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
774
    /// Returns all entries of the thread, oldest first.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
778
    /// Returns the identifier of the agent session.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
782
783 pub fn status(&self) -> ThreadStatus {
784 if self.send_task.is_some() {
785 if self.waiting_for_tool_confirmation() {
786 ThreadStatus::WaitingForToolConfirmation
787 } else {
788 ThreadStatus::Generating
789 }
790 } else {
791 ThreadStatus::Idle
792 }
793 }
794
795 pub fn has_pending_edit_tool_calls(&self) -> bool {
796 for entry in self.entries.iter().rev() {
797 match entry {
798 AgentThreadEntry::UserMessage(_) => return false,
799 AgentThreadEntry::ToolCall(
800 call @ ToolCall {
801 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
802 ..
803 },
804 ) if call.diffs().next().is_some() => {
805 return true;
806 }
807 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
808 }
809 }
810
811 false
812 }
813
814 pub fn used_tools_since_last_user_message(&self) -> bool {
815 for entry in self.entries.iter().rev() {
816 match entry {
817 AgentThreadEntry::UserMessage(..) => return false,
818 AgentThreadEntry::AssistantMessage(..) => continue,
819 AgentThreadEntry::ToolCall(..) => return true,
820 }
821 }
822
823 false
824 }
825
826 pub fn handle_session_update(
827 &mut self,
828 update: acp::SessionUpdate,
829 cx: &mut Context<Self>,
830 ) -> Result<(), acp::Error> {
831 match update {
832 acp::SessionUpdate::UserMessageChunk { content } => {
833 self.push_user_content_block(None, content, cx);
834 }
835 acp::SessionUpdate::AgentMessageChunk { content } => {
836 self.push_assistant_content_block(content, false, cx);
837 }
838 acp::SessionUpdate::AgentThoughtChunk { content } => {
839 self.push_assistant_content_block(content, true, cx);
840 }
841 acp::SessionUpdate::ToolCall(tool_call) => {
842 self.upsert_tool_call(tool_call, cx)?;
843 }
844 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
845 self.update_tool_call(tool_call_update, cx)?;
846 }
847 acp::SessionUpdate::Plan(plan) => {
848 self.update_plan(plan, cx);
849 }
850 }
851 Ok(())
852 }
853
854 pub fn push_user_content_block(
855 &mut self,
856 message_id: Option<UserMessageId>,
857 chunk: acp::ContentBlock,
858 cx: &mut Context<Self>,
859 ) {
860 let language_registry = self.project.read(cx).languages().clone();
861 let entries_len = self.entries.len();
862
863 if let Some(last_entry) = self.entries.last_mut()
864 && let AgentThreadEntry::UserMessage(UserMessage {
865 id,
866 content,
867 chunks,
868 ..
869 }) = last_entry
870 {
871 *id = message_id.or(id.take());
872 content.append(chunk.clone(), &language_registry, cx);
873 chunks.push(chunk);
874 let idx = entries_len - 1;
875 cx.emit(AcpThreadEvent::EntryUpdated(idx));
876 } else {
877 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
878 self.push_entry(
879 AgentThreadEntry::UserMessage(UserMessage {
880 id: message_id,
881 content,
882 chunks: vec![chunk],
883 checkpoint: None,
884 }),
885 cx,
886 );
887 }
888 }
889
890 pub fn push_assistant_content_block(
891 &mut self,
892 chunk: acp::ContentBlock,
893 is_thought: bool,
894 cx: &mut Context<Self>,
895 ) {
896 let language_registry = self.project.read(cx).languages().clone();
897 let entries_len = self.entries.len();
898 if let Some(last_entry) = self.entries.last_mut()
899 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
900 {
901 let idx = entries_len - 1;
902 cx.emit(AcpThreadEvent::EntryUpdated(idx));
903 match (chunks.last_mut(), is_thought) {
904 (Some(AssistantMessageChunk::Message { block }), false)
905 | (Some(AssistantMessageChunk::Thought { block }), true) => {
906 block.append(chunk, &language_registry, cx)
907 }
908 _ => {
909 let block = ContentBlock::new(chunk, &language_registry, cx);
910 if is_thought {
911 chunks.push(AssistantMessageChunk::Thought { block })
912 } else {
913 chunks.push(AssistantMessageChunk::Message { block })
914 }
915 }
916 }
917 } else {
918 let block = ContentBlock::new(chunk, &language_registry, cx);
919 let chunk = if is_thought {
920 AssistantMessageChunk::Thought { block }
921 } else {
922 AssistantMessageChunk::Message { block }
923 };
924
925 self.push_entry(
926 AgentThreadEntry::AssistantMessage(AssistantMessage {
927 chunks: vec![chunk],
928 }),
929 cx,
930 );
931 }
932 }
933
    /// Appends `entry` to the thread and emits [`AcpThreadEvent::NewEntry`].
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
938
939 pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
940 dbg!("update title", &title);
941 self.title = title;
942 cx.emit(AcpThreadEvent::TitleUpdated);
943 Ok(())
944 }
945
    /// Applies an update to the tool call it targets and emits `EntryUpdated`
    /// for that entry.
    ///
    /// Field updates are merged into the call (re-resolving locations when the
    /// update carries locations); diff and terminal updates replace the call's
    /// content with that single item.
    ///
    /// # Errors
    ///
    /// Returns an error when no tool call with the update's id exists.
    pub fn update_tool_call(
        &mut self,
        update: impl Into<ToolCallUpdate>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let update = update.into();
        let languages = self.project.read(cx).languages().clone();

        let (ix, current_call) = self
            .tool_call_mut(update.id())
            .context("Tool call not found")?;
        match update {
            ToolCallUpdate::UpdateFields(update) => {
                // Record this before `update.fields` is moved into `update_fields`.
                let location_updated = update.fields.locations.is_some();
                current_call.update_fields(update.fields, languages, cx);
                if location_updated {
                    self.resolve_locations(update.id.clone(), cx);
                }
            }
            ToolCallUpdate::UpdateDiff(update) => {
                current_call.content.clear();
                current_call
                    .content
                    .push(ToolCallContent::Diff(update.diff));
            }
            ToolCallUpdate::UpdateTerminal(update) => {
                current_call.content.clear();
                current_call
                    .content
                    .push(ToolCallContent::Terminal(update.terminal));
            }
        }

        cx.emit(AcpThreadEvent::EntryUpdated(ix));

        Ok(())
    }
983
984 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
985 pub fn upsert_tool_call(
986 &mut self,
987 tool_call: acp::ToolCall,
988 cx: &mut Context<Self>,
989 ) -> Result<(), acp::Error> {
990 let status = tool_call.status.into();
991 self.upsert_tool_call_inner(tool_call.into(), status, cx)
992 }
993
994 /// Fails if id does not match an existing entry.
995 pub fn upsert_tool_call_inner(
996 &mut self,
997 tool_call_update: acp::ToolCallUpdate,
998 status: ToolCallStatus,
999 cx: &mut Context<Self>,
1000 ) -> Result<(), acp::Error> {
1001 let language_registry = self.project.read(cx).languages().clone();
1002 let id = tool_call_update.id.clone();
1003
1004 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1005 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1006 current_call.status = status;
1007
1008 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1009 } else {
1010 let call =
1011 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1012 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1013 };
1014
1015 self.resolve_locations(id, cx);
1016 Ok(())
1017 }
1018
1019 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1020 // The tool call we are looking for is typically the last one, or very close to the end.
1021 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1022 self.entries
1023 .iter_mut()
1024 .enumerate()
1025 .rev()
1026 .find_map(|(index, tool_call)| {
1027 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1028 && &tool_call.id == id
1029 {
1030 Some((index, tool_call))
1031 } else {
1032 None
1033 }
1034 })
1035 }
1036
    /// Resolves the locations of the tool call with the given id in the
    /// background and stores the result on the call.
    ///
    /// Does nothing when no such tool call exists. Once resolution finishes,
    /// the last resolved location may become the project's agent location, and
    /// `EntryUpdated` is emitted if the resolved locations changed.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the call up again: the thread may have changed while resolving.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        // The location is only replaced when the project already has one.
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // Ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of the line.
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1082
1083 pub fn request_tool_call_authorization(
1084 &mut self,
1085 tool_call: acp::ToolCallUpdate,
1086 options: Vec<acp::PermissionOption>,
1087 cx: &mut Context<Self>,
1088 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1089 let (tx, rx) = oneshot::channel();
1090
1091 let status = ToolCallStatus::WaitingForConfirmation {
1092 options,
1093 respond_tx: tx,
1094 };
1095
1096 self.upsert_tool_call_inner(tool_call, status, cx)?;
1097 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1098 Ok(rx)
1099 }
1100
1101 pub fn authorize_tool_call(
1102 &mut self,
1103 id: acp::ToolCallId,
1104 option_id: acp::PermissionOptionId,
1105 option_kind: acp::PermissionOptionKind,
1106 cx: &mut Context<Self>,
1107 ) {
1108 let Some((ix, call)) = self.tool_call_mut(&id) else {
1109 return;
1110 };
1111
1112 let new_status = match option_kind {
1113 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1114 ToolCallStatus::Rejected
1115 }
1116 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1117 ToolCallStatus::InProgress
1118 }
1119 };
1120
1121 let curr_status = mem::replace(&mut call.status, new_status);
1122
1123 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1124 respond_tx.send(option_id).log_err();
1125 } else if cfg!(debug_assertions) {
1126 panic!("tried to authorize an already authorized tool call");
1127 }
1128
1129 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1130 }
1131
1132 /// Returns true if the last turn is awaiting tool authorization
1133 pub fn waiting_for_tool_confirmation(&self) -> bool {
1134 for entry in self.entries.iter().rev() {
1135 match &entry {
1136 AgentThreadEntry::ToolCall(call) => match call.status {
1137 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1138 ToolCallStatus::Pending
1139 | ToolCallStatus::InProgress
1140 | ToolCallStatus::Completed
1141 | ToolCallStatus::Failed
1142 | ToolCallStatus::Rejected
1143 | ToolCallStatus::Canceled => continue,
1144 },
1145 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1146 // Reached the beginning of the turn
1147 return false;
1148 }
1149 }
1150 }
1151 false
1152 }
1153
    /// Returns the agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1157
1158 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1159 let new_entries_len = request.entries.len();
1160 let mut new_entries = request.entries.into_iter();
1161
1162 // Reuse existing markdown to prevent flickering
1163 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1164 let PlanEntry {
1165 content,
1166 priority,
1167 status,
1168 } = old;
1169 content.update(cx, |old, cx| {
1170 old.replace(new.content, cx);
1171 });
1172 *priority = new.priority;
1173 *status = new.status;
1174 }
1175 for new in new_entries {
1176 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1177 }
1178 self.plan.entries.truncate(new_entries_len);
1179
1180 cx.notify();
1181 }
1182
1183 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1184 self.plan
1185 .entries
1186 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1187 cx.notify();
1188 }
1189
/// Test helper: sends `message` as a single plain-text content block via
/// [`Self::send`].
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_string(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1204
1205 pub fn send(
1206 &mut self,
1207 message: Vec<acp::ContentBlock>,
1208 cx: &mut Context<Self>,
1209 ) -> BoxFuture<'static, Result<()>> {
1210 let block = ContentBlock::new_combined(
1211 message.clone(),
1212 self.project.read(cx).languages().clone(),
1213 cx,
1214 );
1215 let request = acp::PromptRequest {
1216 prompt: message.clone(),
1217 session_id: self.session_id.clone(),
1218 };
1219 let git_store = self.project.read(cx).git_store().clone();
1220
1221 let message_id = if self
1222 .connection
1223 .session_editor(&self.session_id, cx)
1224 .is_some()
1225 {
1226 Some(UserMessageId::new())
1227 } else {
1228 None
1229 };
1230
1231 self.run_turn(cx, async move |this, cx| {
1232 this.update(cx, |this, cx| {
1233 this.push_entry(
1234 AgentThreadEntry::UserMessage(UserMessage {
1235 id: message_id.clone(),
1236 content: block,
1237 chunks: message,
1238 checkpoint: None,
1239 }),
1240 cx,
1241 );
1242 })
1243 .ok();
1244
1245 let old_checkpoint = git_store
1246 .update(cx, |git, cx| git.checkpoint(cx))?
1247 .await
1248 .context("failed to get old checkpoint")
1249 .log_err();
1250 this.update(cx, |this, cx| {
1251 if let Some((_ix, message)) = this.last_user_message() {
1252 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1253 git_checkpoint,
1254 show: false,
1255 });
1256 }
1257 this.connection.prompt(message_id, request, cx)
1258 })?
1259 .await
1260 })
1261 }
1262
1263 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1264 self.run_turn(cx, async move |this, cx| {
1265 this.update(cx, |this, cx| {
1266 this.connection
1267 .resume(&this.session_id, cx)
1268 .map(|resume| resume.run(cx))
1269 })?
1270 .context("resuming a session is not supported")?
1271 .await
1272 })
1273 }
1274
1275 fn run_turn(
1276 &mut self,
1277 cx: &mut Context<Self>,
1278 f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1279 ) -> BoxFuture<'static, Result<()>> {
1280 self.clear_completed_plan_entries(cx);
1281
1282 let (tx, rx) = oneshot::channel();
1283 let cancel_task = self.cancel(cx);
1284
1285 self.send_task = Some(cx.spawn(async move |this, cx| {
1286 cancel_task.await;
1287 tx.send(f(this, cx).await).ok();
1288 }));
1289
1290 cx.spawn(async move |this, cx| {
1291 let response = rx.await;
1292
1293 this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1294 .await?;
1295
1296 this.update(cx, |this, cx| {
1297 match response {
1298 Ok(Err(e)) => {
1299 this.send_task.take();
1300 cx.emit(AcpThreadEvent::Error);
1301 Err(e)
1302 }
1303 result => {
1304 let canceled = matches!(
1305 result,
1306 Ok(Ok(acp::PromptResponse {
1307 stop_reason: acp::StopReason::Canceled
1308 }))
1309 );
1310
1311 // We only take the task if the current prompt wasn't canceled.
1312 //
1313 // This prompt may have been canceled because another one was sent
1314 // while it was still generating. In these cases, dropping `send_task`
1315 // would cause the next generation to be canceled.
1316 if !canceled {
1317 this.send_task.take();
1318 }
1319
1320 cx.emit(AcpThreadEvent::Stopped);
1321 Ok(())
1322 }
1323 }
1324 })?
1325 })
1326 .boxed()
1327 }
1328
1329 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1330 let Some(send_task) = self.send_task.take() else {
1331 return Task::ready(());
1332 };
1333
1334 for entry in self.entries.iter_mut() {
1335 if let AgentThreadEntry::ToolCall(call) = entry {
1336 let cancel = matches!(
1337 call.status,
1338 ToolCallStatus::Pending
1339 | ToolCallStatus::WaitingForConfirmation { .. }
1340 | ToolCallStatus::InProgress
1341 );
1342
1343 if cancel {
1344 call.status = ToolCallStatus::Canceled;
1345 }
1346 }
1347 }
1348
1349 self.connection.cancel(&self.session_id, cx);
1350
1351 // Wait for the send task to complete
1352 cx.foreground_executor().spawn(send_task)
1353 }
1354
1355 /// Rewinds this thread to before the entry at `index`, removing it and all
1356 /// subsequent entries while reverting any changes made from that point.
1357 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1358 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1359 return Task::ready(Err(anyhow!("not supported")));
1360 };
1361 let Some(message) = self.user_message(&id) else {
1362 return Task::ready(Err(anyhow!("message not found")));
1363 };
1364
1365 let checkpoint = message
1366 .checkpoint
1367 .as_ref()
1368 .map(|c| c.git_checkpoint.clone());
1369
1370 let git_store = self.project.read(cx).git_store().clone();
1371 cx.spawn(async move |this, cx| {
1372 if let Some(checkpoint) = checkpoint {
1373 git_store
1374 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1375 .await?;
1376 }
1377
1378 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1379 .await?;
1380 this.update(cx, |this, cx| {
1381 if let Some((ix, _)) = this.user_message_mut(&id) {
1382 let range = ix..this.entries.len();
1383 this.entries.truncate(ix);
1384 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1385 }
1386 })
1387 })
1388 }
1389
1390 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1391 let git_store = self.project.read(cx).git_store().clone();
1392
1393 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1394 if let Some(checkpoint) = message.checkpoint.as_ref() {
1395 checkpoint.git_checkpoint.clone()
1396 } else {
1397 return Task::ready(Ok(()));
1398 }
1399 } else {
1400 return Task::ready(Ok(()));
1401 };
1402
1403 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1404 cx.spawn(async move |this, cx| {
1405 let new_checkpoint = new_checkpoint
1406 .await
1407 .context("failed to get new checkpoint")
1408 .log_err();
1409 if let Some(new_checkpoint) = new_checkpoint {
1410 let equal = git_store
1411 .update(cx, |git, cx| {
1412 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1413 })?
1414 .await
1415 .unwrap_or(true);
1416 this.update(cx, |this, cx| {
1417 let (ix, message) = this.last_user_message().context("no user message")?;
1418 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1419 checkpoint.show = !equal;
1420 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1421 anyhow::Ok(())
1422 })??;
1423 }
1424
1425 Ok(())
1426 })
1427 }
1428
1429 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1430 self.entries
1431 .iter_mut()
1432 .enumerate()
1433 .rev()
1434 .find_map(|(ix, entry)| {
1435 if let AgentThreadEntry::UserMessage(message) = entry {
1436 Some((ix, message))
1437 } else {
1438 None
1439 }
1440 })
1441 }
1442
1443 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1444 self.entries.iter().find_map(|entry| {
1445 if let AgentThreadEntry::UserMessage(message) = entry {
1446 if message.id.as_ref() == Some(&id) {
1447 Some(message)
1448 } else {
1449 None
1450 }
1451 } else {
1452 None
1453 }
1454 })
1455 }
1456
1457 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1458 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1459 if let AgentThreadEntry::UserMessage(message) = entry {
1460 if message.id.as_ref() == Some(&id) {
1461 Some((ix, message))
1462 } else {
1463 None
1464 }
1465 } else {
1466 None
1467 }
1468 })
1469 }
1470
1471 pub fn read_text_file(
1472 &self,
1473 path: PathBuf,
1474 line: Option<u32>,
1475 limit: Option<u32>,
1476 reuse_shared_snapshot: bool,
1477 cx: &mut Context<Self>,
1478 ) -> Task<Result<String>> {
1479 let project = self.project.clone();
1480 let action_log = self.action_log.clone();
1481 cx.spawn(async move |this, cx| {
1482 let load = project.update(cx, |project, cx| {
1483 let path = project
1484 .project_path_for_absolute_path(&path, cx)
1485 .context("invalid path")?;
1486 anyhow::Ok(project.open_buffer(path, cx))
1487 });
1488 let buffer = load??.await?;
1489
1490 let snapshot = if reuse_shared_snapshot {
1491 this.read_with(cx, |this, _| {
1492 this.shared_buffers.get(&buffer.clone()).cloned()
1493 })
1494 .log_err()
1495 .flatten()
1496 } else {
1497 None
1498 };
1499
1500 let snapshot = if let Some(snapshot) = snapshot {
1501 snapshot
1502 } else {
1503 action_log.update(cx, |action_log, cx| {
1504 action_log.buffer_read(buffer.clone(), cx);
1505 })?;
1506 project.update(cx, |project, cx| {
1507 let position = buffer
1508 .read(cx)
1509 .snapshot()
1510 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1511 project.set_agent_location(
1512 Some(AgentLocation {
1513 buffer: buffer.downgrade(),
1514 position,
1515 }),
1516 cx,
1517 );
1518 })?;
1519
1520 buffer.update(cx, |buffer, _| buffer.snapshot())?
1521 };
1522
1523 this.update(cx, |this, _| {
1524 let text = snapshot.text();
1525 this.shared_buffers.insert(buffer.clone(), snapshot);
1526 if line.is_none() && limit.is_none() {
1527 return Ok(text);
1528 }
1529 let limit = limit.unwrap_or(u32::MAX) as usize;
1530 let Some(line) = line else {
1531 return Ok(text.lines().take(limit).collect::<String>());
1532 };
1533
1534 let count = text.lines().count();
1535 if count < line as usize {
1536 anyhow::bail!("There are only {} lines", count);
1537 }
1538 Ok(text
1539 .lines()
1540 .skip(line as usize + 1)
1541 .take(limit)
1542 .collect::<String>())
1543 })?
1544 })
1545 }
1546
1547 pub fn write_text_file(
1548 &self,
1549 path: PathBuf,
1550 content: String,
1551 cx: &mut Context<Self>,
1552 ) -> Task<Result<()>> {
1553 let project = self.project.clone();
1554 let action_log = self.action_log.clone();
1555 cx.spawn(async move |this, cx| {
1556 let load = project.update(cx, |project, cx| {
1557 let path = project
1558 .project_path_for_absolute_path(&path, cx)
1559 .context("invalid path")?;
1560 anyhow::Ok(project.open_buffer(path, cx))
1561 });
1562 let buffer = load??.await?;
1563 let snapshot = this.update(cx, |this, cx| {
1564 this.shared_buffers
1565 .get(&buffer)
1566 .cloned()
1567 .unwrap_or_else(|| buffer.read(cx).snapshot())
1568 })?;
1569 let edits = cx
1570 .background_executor()
1571 .spawn(async move {
1572 let old_text = snapshot.text();
1573 text_diff(old_text.as_str(), &content)
1574 .into_iter()
1575 .map(|(range, replacement)| {
1576 (
1577 snapshot.anchor_after(range.start)
1578 ..snapshot.anchor_before(range.end),
1579 replacement,
1580 )
1581 })
1582 .collect::<Vec<_>>()
1583 })
1584 .await;
1585 cx.update(|cx| {
1586 project.update(cx, |project, cx| {
1587 project.set_agent_location(
1588 Some(AgentLocation {
1589 buffer: buffer.downgrade(),
1590 position: edits
1591 .last()
1592 .map(|(range, _)| range.end)
1593 .unwrap_or(Anchor::MIN),
1594 }),
1595 cx,
1596 );
1597 });
1598
1599 action_log.update(cx, |action_log, cx| {
1600 action_log.buffer_read(buffer.clone(), cx);
1601 });
1602 buffer.update(cx, |buffer, cx| {
1603 buffer.edit(edits, None, cx);
1604 });
1605 action_log.update(cx, |action_log, cx| {
1606 action_log.buffer_edited(buffer.clone(), cx);
1607 });
1608 })?;
1609 project
1610 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1611 .await
1612 })
1613 }
1614
1615 pub fn to_markdown(&self, cx: &App) -> String {
1616 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1617 }
1618
1619 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1620 cx.emit(AcpThreadEvent::ServerExited(status));
1621 }
1622}
1623
1624fn markdown_for_raw_output(
1625 raw_output: &serde_json::Value,
1626 language_registry: &Arc<LanguageRegistry>,
1627 cx: &mut App,
1628) -> Option<Entity<Markdown>> {
1629 match raw_output {
1630 serde_json::Value::Null => None,
1631 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1632 Markdown::new(
1633 value.to_string().into(),
1634 Some(language_registry.clone()),
1635 None,
1636 cx,
1637 )
1638 })),
1639 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1640 Markdown::new(
1641 value.to_string().into(),
1642 Some(language_registry.clone()),
1643 None,
1644 cx,
1645 )
1646 })),
1647 serde_json::Value::String(value) => Some(cx.new(|cx| {
1648 Markdown::new(
1649 value.clone().into(),
1650 Some(language_registry.clone()),
1651 None,
1652 cx,
1653 )
1654 })),
1655 value => Some(cx.new(|cx| {
1656 Markdown::new(
1657 format!("```json\n{}\n```", value).into(),
1658 Some(language_registry.clone()),
1659 None,
1660 cx,
1661 )
1662 })),
1663 }
1664}
1665
1666#[cfg(test)]
1667mod tests {
1668 use super::*;
1669 use anyhow::anyhow;
1670 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1671 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1672 use indoc::indoc;
1673 use project::{FakeFs, Fs};
1674 use rand::Rng as _;
1675 use serde_json::json;
1676 use settings::SettingsStore;
1677 use smol::stream::StreamExt as _;
1678 use std::{
1679 any::Any,
1680 cell::RefCell,
1681 path::Path,
1682 rc::Rc,
1683 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1684 time::Duration,
1685 };
1686 use util::path;
1687
1688 fn init_test(cx: &mut TestAppContext) {
1689 env_logger::try_init().ok();
1690 cx.update(|cx| {
1691 let settings_store = SettingsStore::test(cx);
1692 cx.set_global(settings_store);
1693 Project::init_settings(cx);
1694 language::init(cx);
1695 });
1696 }
1697
1698 #[gpui::test]
1699 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1700 init_test(cx);
1701
1702 let fs = FakeFs::new(cx.executor());
1703 let project = Project::test(fs, [], cx).await;
1704 let connection = Rc::new(FakeAgentConnection::new());
1705 let thread = cx
1706 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1707 .await
1708 .unwrap();
1709
1710 // Test creating a new user message
1711 thread.update(cx, |thread, cx| {
1712 thread.push_user_content_block(
1713 None,
1714 acp::ContentBlock::Text(acp::TextContent {
1715 annotations: None,
1716 text: "Hello, ".to_string(),
1717 }),
1718 cx,
1719 );
1720 });
1721
1722 thread.update(cx, |thread, cx| {
1723 assert_eq!(thread.entries.len(), 1);
1724 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1725 assert_eq!(user_msg.id, None);
1726 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1727 } else {
1728 panic!("Expected UserMessage");
1729 }
1730 });
1731
1732 // Test appending to existing user message
1733 let message_1_id = UserMessageId::new();
1734 thread.update(cx, |thread, cx| {
1735 thread.push_user_content_block(
1736 Some(message_1_id.clone()),
1737 acp::ContentBlock::Text(acp::TextContent {
1738 annotations: None,
1739 text: "world!".to_string(),
1740 }),
1741 cx,
1742 );
1743 });
1744
1745 thread.update(cx, |thread, cx| {
1746 assert_eq!(thread.entries.len(), 1);
1747 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1748 assert_eq!(user_msg.id, Some(message_1_id));
1749 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1750 } else {
1751 panic!("Expected UserMessage");
1752 }
1753 });
1754
1755 // Test creating new user message after assistant message
1756 thread.update(cx, |thread, cx| {
1757 thread.push_assistant_content_block(
1758 acp::ContentBlock::Text(acp::TextContent {
1759 annotations: None,
1760 text: "Assistant response".to_string(),
1761 }),
1762 false,
1763 cx,
1764 );
1765 });
1766
1767 let message_2_id = UserMessageId::new();
1768 thread.update(cx, |thread, cx| {
1769 thread.push_user_content_block(
1770 Some(message_2_id.clone()),
1771 acp::ContentBlock::Text(acp::TextContent {
1772 annotations: None,
1773 text: "New user message".to_string(),
1774 }),
1775 cx,
1776 );
1777 });
1778
1779 thread.update(cx, |thread, cx| {
1780 assert_eq!(thread.entries.len(), 3);
1781 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1782 assert_eq!(user_msg.id, Some(message_2_id));
1783 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1784 } else {
1785 panic!("Expected UserMessage at index 2");
1786 }
1787 });
1788 }
1789
1790 #[gpui::test]
1791 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1792 init_test(cx);
1793
1794 let fs = FakeFs::new(cx.executor());
1795 let project = Project::test(fs, [], cx).await;
1796 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1797 |_, thread, mut cx| {
1798 async move {
1799 thread.update(&mut cx, |thread, cx| {
1800 thread
1801 .handle_session_update(
1802 acp::SessionUpdate::AgentThoughtChunk {
1803 content: "Thinking ".into(),
1804 },
1805 cx,
1806 )
1807 .unwrap();
1808 thread
1809 .handle_session_update(
1810 acp::SessionUpdate::AgentThoughtChunk {
1811 content: "hard!".into(),
1812 },
1813 cx,
1814 )
1815 .unwrap();
1816 })?;
1817 Ok(acp::PromptResponse {
1818 stop_reason: acp::StopReason::EndTurn,
1819 })
1820 }
1821 .boxed_local()
1822 },
1823 ));
1824
1825 let thread = cx
1826 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1827 .await
1828 .unwrap();
1829
1830 thread
1831 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1832 .await
1833 .unwrap();
1834
1835 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1836 assert_eq!(
1837 output,
1838 indoc! {r#"
1839 ## User
1840
1841 Hello from Zed!
1842
1843 ## Assistant
1844
1845 <thinking>
1846 Thinking hard!
1847 </thinking>
1848
1849 "#}
1850 );
1851 }
1852
1853 #[gpui::test]
1854 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1855 init_test(cx);
1856
1857 let fs = FakeFs::new(cx.executor());
1858 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1859 .await;
1860 let project = Project::test(fs.clone(), [], cx).await;
1861 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1862 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1863 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1864 move |_, thread, mut cx| {
1865 let read_file_tx = read_file_tx.clone();
1866 async move {
1867 let content = thread
1868 .update(&mut cx, |thread, cx| {
1869 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1870 })
1871 .unwrap()
1872 .await
1873 .unwrap();
1874 assert_eq!(content, "one\ntwo\nthree\n");
1875 read_file_tx.take().unwrap().send(()).unwrap();
1876 thread
1877 .update(&mut cx, |thread, cx| {
1878 thread.write_text_file(
1879 path!("/tmp/foo").into(),
1880 "one\ntwo\nthree\nfour\nfive\n".to_string(),
1881 cx,
1882 )
1883 })
1884 .unwrap()
1885 .await
1886 .unwrap();
1887 Ok(acp::PromptResponse {
1888 stop_reason: acp::StopReason::EndTurn,
1889 })
1890 }
1891 .boxed_local()
1892 },
1893 ));
1894
1895 let (worktree, pathbuf) = project
1896 .update(cx, |project, cx| {
1897 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1898 })
1899 .await
1900 .unwrap();
1901 let buffer = project
1902 .update(cx, |project, cx| {
1903 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1904 })
1905 .await
1906 .unwrap();
1907
1908 let thread = cx
1909 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
1910 .await
1911 .unwrap();
1912
1913 let request = thread.update(cx, |thread, cx| {
1914 thread.send_raw("Extend the count in /tmp/foo", cx)
1915 });
1916 read_file_rx.await.ok();
1917 buffer.update(cx, |buffer, cx| {
1918 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1919 });
1920 cx.run_until_parked();
1921 assert_eq!(
1922 buffer.read_with(cx, |buffer, _| buffer.text()),
1923 "zero\none\ntwo\nthree\nfour\nfive\n"
1924 );
1925 assert_eq!(
1926 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1927 "zero\none\ntwo\nthree\nfour\nfive\n"
1928 );
1929 request.await.unwrap();
1930 }
1931
1932 #[gpui::test]
1933 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1934 init_test(cx);
1935
1936 let fs = FakeFs::new(cx.executor());
1937 let project = Project::test(fs, [], cx).await;
1938 let id = acp::ToolCallId("test".into());
1939
1940 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1941 let id = id.clone();
1942 move |_, thread, mut cx| {
1943 let id = id.clone();
1944 async move {
1945 thread
1946 .update(&mut cx, |thread, cx| {
1947 thread.handle_session_update(
1948 acp::SessionUpdate::ToolCall(acp::ToolCall {
1949 id: id.clone(),
1950 title: "Label".into(),
1951 kind: acp::ToolKind::Fetch,
1952 status: acp::ToolCallStatus::InProgress,
1953 content: vec![],
1954 locations: vec![],
1955 raw_input: None,
1956 raw_output: None,
1957 }),
1958 cx,
1959 )
1960 })
1961 .unwrap()
1962 .unwrap();
1963 Ok(acp::PromptResponse {
1964 stop_reason: acp::StopReason::EndTurn,
1965 })
1966 }
1967 .boxed_local()
1968 }
1969 }));
1970
1971 let thread = cx
1972 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1973 .await
1974 .unwrap();
1975
1976 let request = thread.update(cx, |thread, cx| {
1977 thread.send_raw("Fetch https://example.com", cx)
1978 });
1979
1980 run_until_first_tool_call(&thread, cx).await;
1981
1982 thread.read_with(cx, |thread, _| {
1983 assert!(matches!(
1984 thread.entries[1],
1985 AgentThreadEntry::ToolCall(ToolCall {
1986 status: ToolCallStatus::InProgress,
1987 ..
1988 })
1989 ));
1990 });
1991
1992 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1993
1994 thread.read_with(cx, |thread, _| {
1995 assert!(matches!(
1996 &thread.entries[1],
1997 AgentThreadEntry::ToolCall(ToolCall {
1998 status: ToolCallStatus::Canceled,
1999 ..
2000 })
2001 ));
2002 });
2003
2004 thread
2005 .update(cx, |thread, cx| {
2006 thread.handle_session_update(
2007 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2008 id,
2009 fields: acp::ToolCallUpdateFields {
2010 status: Some(acp::ToolCallStatus::Completed),
2011 ..Default::default()
2012 },
2013 }),
2014 cx,
2015 )
2016 })
2017 .unwrap();
2018
2019 request.await.unwrap();
2020
2021 thread.read_with(cx, |thread, _| {
2022 assert!(matches!(
2023 thread.entries[1],
2024 AgentThreadEntry::ToolCall(ToolCall {
2025 status: ToolCallStatus::Completed,
2026 ..
2027 })
2028 ));
2029 });
2030 }
2031
2032 #[gpui::test]
2033 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2034 init_test(cx);
2035 let fs = FakeFs::new(cx.background_executor.clone());
2036 fs.insert_tree(path!("/test"), json!({})).await;
2037 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2038
2039 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2040 move |_, thread, mut cx| {
2041 async move {
2042 thread
2043 .update(&mut cx, |thread, cx| {
2044 thread.handle_session_update(
2045 acp::SessionUpdate::ToolCall(acp::ToolCall {
2046 id: acp::ToolCallId("test".into()),
2047 title: "Label".into(),
2048 kind: acp::ToolKind::Edit,
2049 status: acp::ToolCallStatus::Completed,
2050 content: vec![acp::ToolCallContent::Diff {
2051 diff: acp::Diff {
2052 path: "/test/test.txt".into(),
2053 old_text: None,
2054 new_text: "foo".into(),
2055 },
2056 }],
2057 locations: vec![],
2058 raw_input: None,
2059 raw_output: None,
2060 }),
2061 cx,
2062 )
2063 })
2064 .unwrap()
2065 .unwrap();
2066 Ok(acp::PromptResponse {
2067 stop_reason: acp::StopReason::EndTurn,
2068 })
2069 }
2070 .boxed_local()
2071 }
2072 }));
2073
2074 let thread = cx
2075 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2076 .await
2077 .unwrap();
2078
2079 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2080 .await
2081 .unwrap();
2082
2083 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2084 }
2085
2086 #[gpui::test(iterations = 10)]
2087 async fn test_checkpoints(cx: &mut TestAppContext) {
2088 init_test(cx);
2089 let fs = FakeFs::new(cx.background_executor.clone());
2090 fs.insert_tree(
2091 path!("/test"),
2092 json!({
2093 ".git": {}
2094 }),
2095 )
2096 .await;
2097 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2098
2099 let simulate_changes = Arc::new(AtomicBool::new(true));
2100 let next_filename = Arc::new(AtomicUsize::new(0));
2101 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2102 let simulate_changes = simulate_changes.clone();
2103 let next_filename = next_filename.clone();
2104 let fs = fs.clone();
2105 move |request, thread, mut cx| {
2106 let fs = fs.clone();
2107 let simulate_changes = simulate_changes.clone();
2108 let next_filename = next_filename.clone();
2109 async move {
2110 if simulate_changes.load(SeqCst) {
2111 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2112 fs.write(Path::new(&filename), b"").await?;
2113 }
2114
2115 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2116 panic!("expected text content block");
2117 };
2118 thread.update(&mut cx, |thread, cx| {
2119 thread
2120 .handle_session_update(
2121 acp::SessionUpdate::AgentMessageChunk {
2122 content: content.text.to_uppercase().into(),
2123 },
2124 cx,
2125 )
2126 .unwrap();
2127 })?;
2128 Ok(acp::PromptResponse {
2129 stop_reason: acp::StopReason::EndTurn,
2130 })
2131 }
2132 .boxed_local()
2133 }
2134 }));
2135 let thread = cx
2136 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2137 .await
2138 .unwrap();
2139
2140 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2141 .await
2142 .unwrap();
2143 thread.read_with(cx, |thread, cx| {
2144 assert_eq!(
2145 thread.to_markdown(cx),
2146 indoc! {"
2147 ## User (checkpoint)
2148
2149 Lorem
2150
2151 ## Assistant
2152
2153 LOREM
2154
2155 "}
2156 );
2157 });
2158 assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
2159
2160 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2161 .await
2162 .unwrap();
2163 thread.read_with(cx, |thread, cx| {
2164 assert_eq!(
2165 thread.to_markdown(cx),
2166 indoc! {"
2167 ## User (checkpoint)
2168
2169 Lorem
2170
2171 ## Assistant
2172
2173 LOREM
2174
2175 ## User (checkpoint)
2176
2177 ipsum
2178
2179 ## Assistant
2180
2181 IPSUM
2182
2183 "}
2184 );
2185 });
2186 assert_eq!(
2187 fs.files(),
2188 vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
2189 );
2190
2191 // Checkpoint isn't stored when there are no changes.
2192 simulate_changes.store(false, SeqCst);
2193 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2194 .await
2195 .unwrap();
2196 thread.read_with(cx, |thread, cx| {
2197 assert_eq!(
2198 thread.to_markdown(cx),
2199 indoc! {"
2200 ## User (checkpoint)
2201
2202 Lorem
2203
2204 ## Assistant
2205
2206 LOREM
2207
2208 ## User (checkpoint)
2209
2210 ipsum
2211
2212 ## Assistant
2213
2214 IPSUM
2215
2216 ## User
2217
2218 dolor
2219
2220 ## Assistant
2221
2222 DOLOR
2223
2224 "}
2225 );
2226 });
2227 assert_eq!(
2228 fs.files(),
2229 vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
2230 );
2231
2232 // Rewinding the conversation truncates the history and restores the checkpoint.
2233 thread
2234 .update(cx, |thread, cx| {
2235 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2236 panic!("unexpected entries {:?}", thread.entries)
2237 };
2238 thread.rewind(message.id.clone().unwrap(), cx)
2239 })
2240 .await
2241 .unwrap();
2242 thread.read_with(cx, |thread, cx| {
2243 assert_eq!(
2244 thread.to_markdown(cx),
2245 indoc! {"
2246 ## User (checkpoint)
2247
2248 Lorem
2249
2250 ## Assistant
2251
2252 LOREM
2253
2254 "}
2255 );
2256 });
2257 assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
2258 }
2259
2260 async fn run_until_first_tool_call(
2261 thread: &Entity<AcpThread>,
2262 cx: &mut TestAppContext,
2263 ) -> usize {
2264 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2265
2266 let subscription = cx.update(|cx| {
2267 cx.subscribe(thread, move |thread, _, cx| {
2268 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2269 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2270 return tx.try_send(ix).unwrap();
2271 }
2272 }
2273 })
2274 });
2275
2276 select! {
2277 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2278 panic!("Timeout waiting for tool call")
2279 }
2280 ix = rx.next().fuse() => {
2281 drop(subscription);
2282 ix.unwrap()
2283 }
2284 }
2285 }
2286
2287 #[derive(Clone, Default)]
2288 struct FakeAgentConnection {
2289 auth_methods: Vec<acp::AuthMethod>,
2290 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2291 on_user_message: Option<
2292 Rc<
2293 dyn Fn(
2294 acp::PromptRequest,
2295 WeakEntity<AcpThread>,
2296 AsyncApp,
2297 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2298 + 'static,
2299 >,
2300 >,
2301 }
2302
2303 impl FakeAgentConnection {
2304 fn new() -> Self {
2305 Self {
2306 auth_methods: Vec::new(),
2307 on_user_message: None,
2308 sessions: Arc::default(),
2309 }
2310 }
2311
2312 #[expect(unused)]
2313 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2314 self.auth_methods = auth_methods;
2315 self
2316 }
2317
2318 fn on_user_message(
2319 mut self,
2320 handler: impl Fn(
2321 acp::PromptRequest,
2322 WeakEntity<AcpThread>,
2323 AsyncApp,
2324 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2325 + 'static,
2326 ) -> Self {
2327 self.on_user_message.replace(Rc::new(handler));
2328 self
2329 }
2330 }
2331
2332 impl AgentConnection for FakeAgentConnection {
2333 fn auth_methods(&self) -> &[acp::AuthMethod] {
2334 &self.auth_methods
2335 }
2336
2337 fn new_thread(
2338 self: Rc<Self>,
2339 project: Entity<Project>,
2340 _cwd: &Path,
2341 cx: &mut App,
2342 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2343 let session_id = acp::SessionId(
2344 rand::thread_rng()
2345 .sample_iter(&rand::distributions::Alphanumeric)
2346 .take(7)
2347 .map(char::from)
2348 .collect::<String>()
2349 .into(),
2350 );
2351 let thread =
2352 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2353 self.sessions.lock().insert(session_id, thread.downgrade());
2354 Task::ready(Ok(thread))
2355 }
2356
2357 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2358 if self.auth_methods().iter().any(|m| m.id == method) {
2359 Task::ready(Ok(()))
2360 } else {
2361 Task::ready(Err(anyhow!("Invalid Auth Method")))
2362 }
2363 }
2364
2365 fn prompt(
2366 &self,
2367 _id: Option<UserMessageId>,
2368 params: acp::PromptRequest,
2369 cx: &mut App,
2370 ) -> Task<gpui::Result<acp::PromptResponse>> {
2371 let sessions = self.sessions.lock();
2372 let thread = sessions.get(¶ms.session_id).unwrap();
2373 if let Some(handler) = &self.on_user_message {
2374 let handler = handler.clone();
2375 let thread = thread.clone();
2376 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2377 } else {
2378 Task::ready(Ok(acp::PromptResponse {
2379 stop_reason: acp::StopReason::EndTurn,
2380 }))
2381 }
2382 }
2383
2384 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2385 let sessions = self.sessions.lock();
2386 let thread = sessions.get(&session_id).unwrap().clone();
2387
2388 cx.spawn(async move |cx| {
2389 thread
2390 .update(cx, |thread, cx| thread.cancel(cx))
2391 .unwrap()
2392 .await
2393 })
2394 .detach();
2395 }
2396
2397 fn session_editor(
2398 &self,
2399 session_id: &acp::SessionId,
2400 _cx: &mut App,
2401 ) -> Option<Rc<dyn AgentSessionEditor>> {
2402 Some(Rc::new(FakeAgentSessionEditor {
2403 _session_id: session_id.clone(),
2404 }))
2405 }
2406
2407 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2408 self
2409 }
2410 }
2411
2412 struct FakeAgentSessionEditor {
2413 _session_id: acp::SessionId,
2414 }
2415
2416 impl AgentSessionEditor for FakeAgentSessionEditor {
2417 fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2418 Task::ready(Ok(()))
2419 }
2420 }
2421}