1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6pub use connection::*;
7pub use diff::*;
8pub use mention::*;
9use serde::{Deserialize, Serialize};
10pub use terminal::*;
11
12use action_log::ActionLog;
13use agent_client_protocol as acp;
14use anyhow::{Context as _, Result, anyhow};
15use chrono::{DateTime, Utc};
16use editor::Bias;
17use futures::{FutureExt, channel::oneshot, future::BoxFuture};
18use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
19use itertools::Itertools;
20use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
21use markdown::Markdown;
22use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
23use std::collections::HashMap;
24use std::error::Error;
25use std::fmt::{Formatter, Write};
26use std::ops::Range;
27use std::process::ExitStatus;
28use std::rc::Rc;
29use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
30use ui::App;
31use util::ResultExt;
32
33#[derive(Debug)]
34pub struct UserMessage {
35 pub id: Option<UserMessageId>,
36 pub content: ContentBlock,
37 pub chunks: Vec<acp::ContentBlock>,
38 pub checkpoint: Option<Checkpoint>,
39}
40
41#[derive(Debug)]
42pub struct Checkpoint {
43 git_checkpoint: GitStoreCheckpoint,
44 pub show: bool,
45}
46
47impl UserMessage {
48 fn to_markdown(&self, cx: &App) -> String {
49 let mut markdown = String::new();
50 if self
51 .checkpoint
52 .as_ref()
53 .map_or(false, |checkpoint| checkpoint.show)
54 {
55 writeln!(markdown, "## User (checkpoint)").unwrap();
56 } else {
57 writeln!(markdown, "## User").unwrap();
58 }
59 writeln!(markdown).unwrap();
60 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
61 writeln!(markdown).unwrap();
62 markdown
63 }
64}
65
66#[derive(Debug, PartialEq)]
67pub struct AssistantMessage {
68 pub chunks: Vec<AssistantMessageChunk>,
69}
70
71impl AssistantMessage {
72 pub fn to_markdown(&self, cx: &App) -> String {
73 format!(
74 "## Assistant\n\n{}\n\n",
75 self.chunks
76 .iter()
77 .map(|chunk| chunk.to_markdown(cx))
78 .join("\n\n")
79 )
80 }
81}
82
83#[derive(Debug, PartialEq)]
84pub enum AssistantMessageChunk {
85 Message { block: ContentBlock },
86 Thought { block: ContentBlock },
87}
88
89impl AssistantMessageChunk {
90 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
91 Self::Message {
92 block: ContentBlock::new(chunk.into(), language_registry, cx),
93 }
94 }
95
96 fn to_markdown(&self, cx: &App) -> String {
97 match self {
98 Self::Message { block } => block.to_markdown(cx).to_string(),
99 Self::Thought { block } => {
100 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
101 }
102 }
103 }
104}
105
106#[derive(Debug)]
107pub enum AgentThreadEntry {
108 UserMessage(UserMessage),
109 AssistantMessage(AssistantMessage),
110 ToolCall(ToolCall),
111}
112
113impl AgentThreadEntry {
114 pub fn to_markdown(&self, cx: &App) -> String {
115 match self {
116 Self::UserMessage(message) => message.to_markdown(cx),
117 Self::AssistantMessage(message) => message.to_markdown(cx),
118 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
119 }
120 }
121
122 pub fn user_message(&self) -> Option<&UserMessage> {
123 if let AgentThreadEntry::UserMessage(message) = self {
124 Some(message)
125 } else {
126 None
127 }
128 }
129
130 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
131 if let AgentThreadEntry::ToolCall(call) = self {
132 itertools::Either::Left(call.diffs())
133 } else {
134 itertools::Either::Right(std::iter::empty())
135 }
136 }
137
138 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
139 if let AgentThreadEntry::ToolCall(call) = self {
140 itertools::Either::Left(call.terminals())
141 } else {
142 itertools::Either::Right(std::iter::empty())
143 }
144 }
145
146 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
147 if let AgentThreadEntry::ToolCall(ToolCall {
148 locations,
149 resolved_locations,
150 ..
151 }) = self
152 {
153 Some((
154 locations.get(ix)?.clone(),
155 resolved_locations.get(ix)?.clone()?,
156 ))
157 } else {
158 None
159 }
160 }
161}
162
163#[derive(Debug)]
164pub struct ToolCall {
165 pub id: acp::ToolCallId,
166 pub label: Entity<Markdown>,
167 pub kind: acp::ToolKind,
168 pub content: Vec<ToolCallContent>,
169 pub status: ToolCallStatus,
170 pub locations: Vec<acp::ToolCallLocation>,
171 pub resolved_locations: Vec<Option<AgentLocation>>,
172 pub raw_input: Option<serde_json::Value>,
173 pub raw_output: Option<serde_json::Value>,
174}
175
176impl ToolCall {
177 fn from_acp(
178 tool_call: acp::ToolCall,
179 status: ToolCallStatus,
180 language_registry: Arc<LanguageRegistry>,
181 cx: &mut App,
182 ) -> Self {
183 Self {
184 id: tool_call.id,
185 label: cx.new(|cx| {
186 Markdown::new(
187 tool_call.title.into(),
188 Some(language_registry.clone()),
189 None,
190 cx,
191 )
192 }),
193 kind: tool_call.kind,
194 content: tool_call
195 .content
196 .into_iter()
197 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
198 .collect(),
199 locations: tool_call.locations,
200 resolved_locations: Vec::default(),
201 status,
202 raw_input: tool_call.raw_input,
203 raw_output: tool_call.raw_output,
204 }
205 }
206
207 fn update_fields(
208 &mut self,
209 fields: acp::ToolCallUpdateFields,
210 language_registry: Arc<LanguageRegistry>,
211 cx: &mut App,
212 ) {
213 let acp::ToolCallUpdateFields {
214 kind,
215 status,
216 title,
217 content,
218 locations,
219 raw_input,
220 raw_output,
221 } = fields;
222
223 if let Some(kind) = kind {
224 self.kind = kind;
225 }
226
227 if let Some(status) = status {
228 self.status = status.into();
229 }
230
231 if let Some(title) = title {
232 self.label.update(cx, |label, cx| {
233 label.replace(title, cx);
234 });
235 }
236
237 if let Some(content) = content {
238 self.content = content
239 .into_iter()
240 .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
241 .collect();
242 }
243
244 if let Some(locations) = locations {
245 self.locations = locations;
246 }
247
248 if let Some(raw_input) = raw_input {
249 self.raw_input = Some(raw_input);
250 }
251
252 if let Some(raw_output) = raw_output {
253 if self.content.is_empty() {
254 if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
255 {
256 self.content
257 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
258 markdown,
259 }));
260 }
261 }
262 self.raw_output = Some(raw_output);
263 }
264 }
265
266 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
267 self.content.iter().filter_map(|content| match content {
268 ToolCallContent::Diff(diff) => Some(diff),
269 ToolCallContent::ContentBlock(_) => None,
270 ToolCallContent::Terminal(_) => None,
271 })
272 }
273
274 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
275 self.content.iter().filter_map(|content| match content {
276 ToolCallContent::Terminal(terminal) => Some(terminal),
277 ToolCallContent::ContentBlock(_) => None,
278 ToolCallContent::Diff(_) => None,
279 })
280 }
281
282 fn to_markdown(&self, cx: &App) -> String {
283 let mut markdown = format!(
284 "**Tool Call: {}**\nStatus: {}\n\n",
285 self.label.read(cx).source(),
286 self.status
287 );
288 for content in &self.content {
289 markdown.push_str(content.to_markdown(cx).as_str());
290 markdown.push_str("\n\n");
291 }
292 markdown
293 }
294
295 async fn resolve_location(
296 location: acp::ToolCallLocation,
297 project: WeakEntity<Project>,
298 cx: &mut AsyncApp,
299 ) -> Option<AgentLocation> {
300 let buffer = project
301 .update(cx, |project, cx| {
302 if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
303 Some(project.open_buffer(path, cx))
304 } else {
305 None
306 }
307 })
308 .ok()??;
309 let buffer = buffer.await.log_err()?;
310 let position = buffer
311 .update(cx, |buffer, _| {
312 if let Some(row) = location.line {
313 let snapshot = buffer.snapshot();
314 let column = snapshot.indent_size_for_line(row).len;
315 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
316 snapshot.anchor_before(point)
317 } else {
318 Anchor::MIN
319 }
320 })
321 .ok()?;
322
323 Some(AgentLocation {
324 buffer: buffer.downgrade(),
325 position,
326 })
327 }
328
329 fn resolve_locations(
330 &self,
331 project: Entity<Project>,
332 cx: &mut App,
333 ) -> Task<Vec<Option<AgentLocation>>> {
334 let locations = self.locations.clone();
335 project.update(cx, |_, cx| {
336 cx.spawn(async move |project, cx| {
337 let mut new_locations = Vec::new();
338 for location in locations {
339 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
340 }
341 new_locations
342 })
343 })
344 }
345}
346
347#[derive(Debug)]
348pub enum ToolCallStatus {
349 /// The tool call hasn't started running yet, but we start showing it to
350 /// the user.
351 Pending,
352 /// The tool call is waiting for confirmation from the user.
353 WaitingForConfirmation {
354 options: Vec<acp::PermissionOption>,
355 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
356 },
357 /// The tool call is currently running.
358 InProgress,
359 /// The tool call completed successfully.
360 Completed,
361 /// The tool call failed.
362 Failed,
363 /// The user rejected the tool call.
364 Rejected,
365 /// The user canceled generation so the tool call was canceled.
366 Canceled,
367}
368
369impl From<acp::ToolCallStatus> for ToolCallStatus {
370 fn from(status: acp::ToolCallStatus) -> Self {
371 match status {
372 acp::ToolCallStatus::Pending => Self::Pending,
373 acp::ToolCallStatus::InProgress => Self::InProgress,
374 acp::ToolCallStatus::Completed => Self::Completed,
375 acp::ToolCallStatus::Failed => Self::Failed,
376 }
377 }
378}
379
380impl Display for ToolCallStatus {
381 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
382 write!(
383 f,
384 "{}",
385 match self {
386 ToolCallStatus::Pending => "Pending",
387 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
388 ToolCallStatus::InProgress => "In Progress",
389 ToolCallStatus::Completed => "Completed",
390 ToolCallStatus::Failed => "Failed",
391 ToolCallStatus::Rejected => "Rejected",
392 ToolCallStatus::Canceled => "Canceled",
393 }
394 )
395 }
396}
397
398#[derive(Debug, PartialEq, Clone)]
399pub enum ContentBlock {
400 Empty,
401 Markdown { markdown: Entity<Markdown> },
402 ResourceLink { resource_link: acp::ResourceLink },
403}
404
405impl ContentBlock {
406 pub fn new(
407 block: acp::ContentBlock,
408 language_registry: &Arc<LanguageRegistry>,
409 cx: &mut App,
410 ) -> Self {
411 let mut this = Self::Empty;
412 this.append(block, language_registry, cx);
413 this
414 }
415
416 pub fn new_combined(
417 blocks: impl IntoIterator<Item = acp::ContentBlock>,
418 language_registry: Arc<LanguageRegistry>,
419 cx: &mut App,
420 ) -> Self {
421 let mut this = Self::Empty;
422 for block in blocks {
423 this.append(block, &language_registry, cx);
424 }
425 this
426 }
427
428 pub fn append(
429 &mut self,
430 block: acp::ContentBlock,
431 language_registry: &Arc<LanguageRegistry>,
432 cx: &mut App,
433 ) {
434 if matches!(self, ContentBlock::Empty) {
435 if let acp::ContentBlock::ResourceLink(resource_link) = block {
436 *self = ContentBlock::ResourceLink { resource_link };
437 return;
438 }
439 }
440
441 let new_content = self.block_string_contents(block);
442
443 match self {
444 ContentBlock::Empty => {
445 *self = Self::create_markdown_block(new_content, language_registry, cx);
446 }
447 ContentBlock::Markdown { markdown } => {
448 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
449 }
450 ContentBlock::ResourceLink { resource_link } => {
451 let existing_content = Self::resource_link_md(&resource_link.uri);
452 let combined = format!("{}\n{}", existing_content, new_content);
453
454 *self = Self::create_markdown_block(combined, language_registry, cx);
455 }
456 }
457 }
458
459 fn create_markdown_block(
460 content: String,
461 language_registry: &Arc<LanguageRegistry>,
462 cx: &mut App,
463 ) -> ContentBlock {
464 ContentBlock::Markdown {
465 markdown: cx
466 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
467 }
468 }
469
470 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
471 match block {
472 acp::ContentBlock::Text(text_content) => text_content.text.clone(),
473 acp::ContentBlock::ResourceLink(resource_link) => {
474 Self::resource_link_md(&resource_link.uri)
475 }
476 acp::ContentBlock::Resource(acp::EmbeddedResource {
477 resource:
478 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
479 uri,
480 ..
481 }),
482 ..
483 }) => Self::resource_link_md(&uri),
484 acp::ContentBlock::Image(image) => Self::image_md(&image),
485 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
486 }
487 }
488
489 fn resource_link_md(uri: &str) -> String {
490 if let Some(uri) = MentionUri::parse(&uri).log_err() {
491 uri.as_link().to_string()
492 } else {
493 uri.to_string()
494 }
495 }
496
497 fn image_md(_image: &acp::ImageContent) -> String {
498 "`Image`".into()
499 }
500
501 fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
502 match self {
503 ContentBlock::Empty => "",
504 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
505 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
506 }
507 }
508
509 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
510 match self {
511 ContentBlock::Empty => None,
512 ContentBlock::Markdown { markdown } => Some(markdown),
513 ContentBlock::ResourceLink { .. } => None,
514 }
515 }
516
517 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
518 match self {
519 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
520 _ => None,
521 }
522 }
523}
524
525#[derive(Debug)]
526pub enum ToolCallContent {
527 ContentBlock(ContentBlock),
528 Diff(Entity<Diff>),
529 Terminal(Entity<Terminal>),
530}
531
532impl ToolCallContent {
533 pub fn from_acp(
534 content: acp::ToolCallContent,
535 language_registry: Arc<LanguageRegistry>,
536 cx: &mut App,
537 ) -> Self {
538 match content {
539 acp::ToolCallContent::Content { content } => {
540 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
541 }
542 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
543 Diff::finalized(
544 diff.path,
545 diff.old_text,
546 diff.new_text,
547 language_registry,
548 cx,
549 )
550 })),
551 }
552 }
553
554 pub fn to_markdown(&self, cx: &App) -> String {
555 match self {
556 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
557 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
558 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
559 }
560 }
561}
562
563#[derive(Debug, PartialEq)]
564pub enum ToolCallUpdate {
565 UpdateFields(acp::ToolCallUpdate),
566 UpdateDiff(ToolCallUpdateDiff),
567 UpdateTerminal(ToolCallUpdateTerminal),
568}
569
570impl ToolCallUpdate {
571 fn id(&self) -> &acp::ToolCallId {
572 match self {
573 Self::UpdateFields(update) => &update.id,
574 Self::UpdateDiff(diff) => &diff.id,
575 Self::UpdateTerminal(terminal) => &terminal.id,
576 }
577 }
578}
579
580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
581 fn from(update: acp::ToolCallUpdate) -> Self {
582 Self::UpdateFields(update)
583 }
584}
585
586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
587 fn from(diff: ToolCallUpdateDiff) -> Self {
588 Self::UpdateDiff(diff)
589 }
590}
591
592#[derive(Debug, PartialEq)]
593pub struct ToolCallUpdateDiff {
594 pub id: acp::ToolCallId,
595 pub diff: Entity<Diff>,
596}
597
598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
599 fn from(terminal: ToolCallUpdateTerminal) -> Self {
600 Self::UpdateTerminal(terminal)
601 }
602}
603
604#[derive(Debug, PartialEq)]
605pub struct ToolCallUpdateTerminal {
606 pub id: acp::ToolCallId,
607 pub terminal: Entity<Terminal>,
608}
609
610#[derive(Debug, Default)]
611pub struct Plan {
612 pub entries: Vec<PlanEntry>,
613}
614
615#[derive(Debug)]
616pub struct PlanStats<'a> {
617 pub in_progress_entry: Option<&'a PlanEntry>,
618 pub pending: u32,
619 pub completed: u32,
620}
621
622impl Plan {
623 pub fn is_empty(&self) -> bool {
624 self.entries.is_empty()
625 }
626
627 pub fn stats(&self) -> PlanStats<'_> {
628 let mut stats = PlanStats {
629 in_progress_entry: None,
630 pending: 0,
631 completed: 0,
632 };
633
634 for entry in &self.entries {
635 match &entry.status {
636 acp::PlanEntryStatus::Pending => {
637 stats.pending += 1;
638 }
639 acp::PlanEntryStatus::InProgress => {
640 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
641 }
642 acp::PlanEntryStatus::Completed => {
643 stats.completed += 1;
644 }
645 }
646 }
647
648 stats
649 }
650}
651
652#[derive(Debug)]
653pub struct PlanEntry {
654 pub content: Entity<Markdown>,
655 pub priority: acp::PlanEntryPriority,
656 pub status: acp::PlanEntryStatus,
657}
658
659impl PlanEntry {
660 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
661 Self {
662 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
663 priority: entry.priority,
664 status: entry.status,
665 }
666 }
667}
668
669#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
670pub struct AgentServerName(pub SharedString);
671
672#[derive(Debug, Clone, Serialize, Deserialize)]
673pub struct AcpThreadMetadata {
674 pub agent: AgentServerName,
675 pub id: acp::SessionId,
676 pub title: SharedString,
677 pub updated_at: DateTime<Utc>,
678}
679
680pub struct AcpThread {
681 title: SharedString,
682 entries: Vec<AgentThreadEntry>,
683 plan: Plan,
684 project: Entity<Project>,
685 action_log: Entity<ActionLog>,
686 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
687 send_task: Option<Task<()>>,
688 connection: Rc<dyn AgentConnection>,
689 session_id: acp::SessionId,
690}
691
692pub enum AcpThreadEvent {
693 NewEntry,
694 EntryUpdated(usize),
695 EntriesRemoved(Range<usize>),
696 ToolAuthorizationRequired,
697 Stopped,
698 Error,
699 ServerExited(ExitStatus),
700}
701
702impl EventEmitter<AcpThreadEvent> for AcpThread {}
703
/// Coarse activity state of a thread.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No turn is running.
    Idle,
    /// A turn is running and a tool call awaits user confirmation.
    WaitingForToolConfirmation,
    /// A turn is running.
    Generating,
}
710
711#[derive(Debug, Clone)]
712pub enum LoadError {
713 Unsupported {
714 error_message: SharedString,
715 upgrade_message: SharedString,
716 upgrade_command: String,
717 },
718 Exited(i32),
719 Other(SharedString),
720}
721
722impl Display for LoadError {
723 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
724 match self {
725 LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
726 LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
727 LoadError::Other(msg) => write!(f, "{}", msg),
728 }
729 }
730}
731
732impl Error for LoadError {}
733
734impl AcpThread {
735 pub fn new(
736 title: impl Into<SharedString>,
737 connection: Rc<dyn AgentConnection>,
738 project: Entity<Project>,
739 session_id: acp::SessionId,
740 cx: &mut Context<Self>,
741 ) -> Self {
742 let action_log = cx.new(|_| ActionLog::new(project.clone()));
743
744 Self {
745 action_log,
746 shared_buffers: Default::default(),
747 entries: Default::default(),
748 plan: Default::default(),
749 title: title.into(),
750 project,
751 send_task: None,
752 connection,
753 session_id,
754 }
755 }
756
757 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
758 &self.connection
759 }
760
761 pub fn action_log(&self) -> &Entity<ActionLog> {
762 &self.action_log
763 }
764
765 pub fn project(&self) -> &Entity<Project> {
766 &self.project
767 }
768
769 pub fn title(&self) -> SharedString {
770 self.title.clone()
771 }
772
773 pub fn entries(&self) -> &[AgentThreadEntry] {
774 &self.entries
775 }
776
777 pub fn session_id(&self) -> &acp::SessionId {
778 &self.session_id
779 }
780
781 pub fn status(&self) -> ThreadStatus {
782 if self.send_task.is_some() {
783 if self.waiting_for_tool_confirmation() {
784 ThreadStatus::WaitingForToolConfirmation
785 } else {
786 ThreadStatus::Generating
787 }
788 } else {
789 ThreadStatus::Idle
790 }
791 }
792
793 pub fn has_pending_edit_tool_calls(&self) -> bool {
794 for entry in self.entries.iter().rev() {
795 match entry {
796 AgentThreadEntry::UserMessage(_) => return false,
797 AgentThreadEntry::ToolCall(
798 call @ ToolCall {
799 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
800 ..
801 },
802 ) if call.diffs().next().is_some() => {
803 return true;
804 }
805 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
806 }
807 }
808
809 false
810 }
811
812 pub fn used_tools_since_last_user_message(&self) -> bool {
813 for entry in self.entries.iter().rev() {
814 match entry {
815 AgentThreadEntry::UserMessage(..) => return false,
816 AgentThreadEntry::AssistantMessage(..) => continue,
817 AgentThreadEntry::ToolCall(..) => return true,
818 }
819 }
820
821 false
822 }
823
824 pub fn handle_session_update(
825 &mut self,
826 update: acp::SessionUpdate,
827 cx: &mut Context<Self>,
828 ) -> Result<(), acp::Error> {
829 match update {
830 acp::SessionUpdate::UserMessageChunk { content } => {
831 self.push_user_content_block(None, content, cx);
832 }
833 acp::SessionUpdate::AgentMessageChunk { content } => {
834 self.push_assistant_content_block(content, false, cx);
835 }
836 acp::SessionUpdate::AgentThoughtChunk { content } => {
837 self.push_assistant_content_block(content, true, cx);
838 }
839 acp::SessionUpdate::ToolCall(tool_call) => {
840 self.upsert_tool_call(tool_call, cx)?;
841 }
842 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
843 self.update_tool_call(tool_call_update, cx)?;
844 }
845 acp::SessionUpdate::Plan(plan) => {
846 self.update_plan(plan, cx);
847 }
848 }
849 Ok(())
850 }
851
852 pub fn push_user_content_block(
853 &mut self,
854 message_id: Option<UserMessageId>,
855 chunk: acp::ContentBlock,
856 cx: &mut Context<Self>,
857 ) {
858 let language_registry = self.project.read(cx).languages().clone();
859 let entries_len = self.entries.len();
860
861 if let Some(last_entry) = self.entries.last_mut()
862 && let AgentThreadEntry::UserMessage(UserMessage {
863 id,
864 content,
865 chunks,
866 ..
867 }) = last_entry
868 {
869 *id = message_id.or(id.take());
870 content.append(chunk.clone(), &language_registry, cx);
871 chunks.push(chunk);
872 let idx = entries_len - 1;
873 cx.emit(AcpThreadEvent::EntryUpdated(idx));
874 } else {
875 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
876 self.push_entry(
877 AgentThreadEntry::UserMessage(UserMessage {
878 id: message_id,
879 content,
880 chunks: vec![chunk],
881 checkpoint: None,
882 }),
883 cx,
884 );
885 }
886 }
887
888 pub fn push_assistant_content_block(
889 &mut self,
890 chunk: acp::ContentBlock,
891 is_thought: bool,
892 cx: &mut Context<Self>,
893 ) {
894 let language_registry = self.project.read(cx).languages().clone();
895 let entries_len = self.entries.len();
896 if let Some(last_entry) = self.entries.last_mut()
897 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
898 {
899 let idx = entries_len - 1;
900 cx.emit(AcpThreadEvent::EntryUpdated(idx));
901 match (chunks.last_mut(), is_thought) {
902 (Some(AssistantMessageChunk::Message { block }), false)
903 | (Some(AssistantMessageChunk::Thought { block }), true) => {
904 block.append(chunk, &language_registry, cx)
905 }
906 _ => {
907 let block = ContentBlock::new(chunk, &language_registry, cx);
908 if is_thought {
909 chunks.push(AssistantMessageChunk::Thought { block })
910 } else {
911 chunks.push(AssistantMessageChunk::Message { block })
912 }
913 }
914 }
915 } else {
916 let block = ContentBlock::new(chunk, &language_registry, cx);
917 let chunk = if is_thought {
918 AssistantMessageChunk::Thought { block }
919 } else {
920 AssistantMessageChunk::Message { block }
921 };
922
923 self.push_entry(
924 AgentThreadEntry::AssistantMessage(AssistantMessage {
925 chunks: vec![chunk],
926 }),
927 cx,
928 );
929 }
930 }
931
932 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
933 self.entries.push(entry);
934 cx.emit(AcpThreadEvent::NewEntry);
935 }
936
937 pub fn update_tool_call(
938 &mut self,
939 update: impl Into<ToolCallUpdate>,
940 cx: &mut Context<Self>,
941 ) -> Result<()> {
942 let update = update.into();
943 let languages = self.project.read(cx).languages().clone();
944
945 let (ix, current_call) = self
946 .tool_call_mut(update.id())
947 .context("Tool call not found")?;
948 match update {
949 ToolCallUpdate::UpdateFields(update) => {
950 let location_updated = update.fields.locations.is_some();
951 current_call.update_fields(update.fields, languages, cx);
952 if location_updated {
953 self.resolve_locations(update.id.clone(), cx);
954 }
955 }
956 ToolCallUpdate::UpdateDiff(update) => {
957 current_call.content.clear();
958 current_call
959 .content
960 .push(ToolCallContent::Diff(update.diff));
961 }
962 ToolCallUpdate::UpdateTerminal(update) => {
963 current_call.content.clear();
964 current_call
965 .content
966 .push(ToolCallContent::Terminal(update.terminal));
967 }
968 }
969
970 cx.emit(AcpThreadEvent::EntryUpdated(ix));
971
972 Ok(())
973 }
974
975 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
976 pub fn upsert_tool_call(
977 &mut self,
978 tool_call: acp::ToolCall,
979 cx: &mut Context<Self>,
980 ) -> Result<(), acp::Error> {
981 let status = tool_call.status.into();
982 self.upsert_tool_call_inner(tool_call.into(), status, cx)
983 }
984
985 /// Fails if id does not match an existing entry.
986 pub fn upsert_tool_call_inner(
987 &mut self,
988 tool_call_update: acp::ToolCallUpdate,
989 status: ToolCallStatus,
990 cx: &mut Context<Self>,
991 ) -> Result<(), acp::Error> {
992 let language_registry = self.project.read(cx).languages().clone();
993 let id = tool_call_update.id.clone();
994
995 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
996 current_call.update_fields(tool_call_update.fields, language_registry, cx);
997 current_call.status = status;
998
999 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1000 } else {
1001 let call =
1002 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1003 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1004 };
1005
1006 self.resolve_locations(id, cx);
1007 Ok(())
1008 }
1009
1010 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1011 // The tool call we are looking for is typically the last one, or very close to the end.
1012 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1013 self.entries
1014 .iter_mut()
1015 .enumerate()
1016 .rev()
1017 .find_map(|(index, tool_call)| {
1018 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1019 && &tool_call.id == id
1020 {
1021 Some((index, tool_call))
1022 } else {
1023 None
1024 }
1025 })
1026 }
1027
1028 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1029 let project = self.project.clone();
1030 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1031 return;
1032 };
1033 let task = tool_call.resolve_locations(project, cx);
1034 cx.spawn(async move |this, cx| {
1035 let resolved_locations = task.await;
1036 this.update(cx, |this, cx| {
1037 let project = this.project.clone();
1038 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1039 return;
1040 };
1041 if let Some(Some(location)) = resolved_locations.last() {
1042 project.update(cx, |project, cx| {
1043 if let Some(agent_location) = project.agent_location() {
1044 let should_ignore = agent_location.buffer == location.buffer
1045 && location
1046 .buffer
1047 .update(cx, |buffer, _| {
1048 let snapshot = buffer.snapshot();
1049 let old_position =
1050 agent_location.position.to_point(&snapshot);
1051 let new_position = location.position.to_point(&snapshot);
1052 // ignore this so that when we get updates from the edit tool
1053 // the position doesn't reset to the startof line
1054 old_position.row == new_position.row
1055 && old_position.column > new_position.column
1056 })
1057 .ok()
1058 .unwrap_or_default();
1059 if !should_ignore {
1060 project.set_agent_location(Some(location.clone()), cx);
1061 }
1062 }
1063 });
1064 }
1065 if tool_call.resolved_locations != resolved_locations {
1066 tool_call.resolved_locations = resolved_locations;
1067 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1068 }
1069 })
1070 })
1071 .detach();
1072 }
1073
1074 pub fn request_tool_call_authorization(
1075 &mut self,
1076 tool_call: acp::ToolCallUpdate,
1077 options: Vec<acp::PermissionOption>,
1078 cx: &mut Context<Self>,
1079 ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1080 let (tx, rx) = oneshot::channel();
1081
1082 let status = ToolCallStatus::WaitingForConfirmation {
1083 options,
1084 respond_tx: tx,
1085 };
1086
1087 self.upsert_tool_call_inner(tool_call, status, cx)?;
1088 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1089 Ok(rx)
1090 }
1091
1092 pub fn authorize_tool_call(
1093 &mut self,
1094 id: acp::ToolCallId,
1095 option_id: acp::PermissionOptionId,
1096 option_kind: acp::PermissionOptionKind,
1097 cx: &mut Context<Self>,
1098 ) {
1099 let Some((ix, call)) = self.tool_call_mut(&id) else {
1100 return;
1101 };
1102
1103 let new_status = match option_kind {
1104 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1105 ToolCallStatus::Rejected
1106 }
1107 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1108 ToolCallStatus::InProgress
1109 }
1110 };
1111
1112 let curr_status = mem::replace(&mut call.status, new_status);
1113
1114 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1115 respond_tx.send(option_id).log_err();
1116 } else if cfg!(debug_assertions) {
1117 panic!("tried to authorize an already authorized tool call");
1118 }
1119
1120 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1121 }
1122
1123 /// Returns true if the last turn is awaiting tool authorization
1124 pub fn waiting_for_tool_confirmation(&self) -> bool {
1125 for entry in self.entries.iter().rev() {
1126 match &entry {
1127 AgentThreadEntry::ToolCall(call) => match call.status {
1128 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1129 ToolCallStatus::Pending
1130 | ToolCallStatus::InProgress
1131 | ToolCallStatus::Completed
1132 | ToolCallStatus::Failed
1133 | ToolCallStatus::Rejected
1134 | ToolCallStatus::Canceled => continue,
1135 },
1136 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1137 // Reached the beginning of the turn
1138 return false;
1139 }
1140 }
1141 }
1142 false
1143 }
1144
1145 pub fn plan(&self) -> &Plan {
1146 &self.plan
1147 }
1148
1149 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1150 let new_entries_len = request.entries.len();
1151 let mut new_entries = request.entries.into_iter();
1152
1153 // Reuse existing markdown to prevent flickering
1154 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1155 let PlanEntry {
1156 content,
1157 priority,
1158 status,
1159 } = old;
1160 content.update(cx, |old, cx| {
1161 old.replace(new.content, cx);
1162 });
1163 *priority = new.priority;
1164 *status = new.status;
1165 }
1166 for new in new_entries {
1167 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1168 }
1169 self.plan.entries.truncate(new_entries_len);
1170
1171 cx.notify();
1172 }
1173
1174 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1175 self.plan
1176 .entries
1177 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1178 cx.notify();
1179 }
1180
/// Test helper: sends `message` as a single text content block without
/// annotations.
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
    &mut self,
    message: &str,
    cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
    let text_block = acp::ContentBlock::Text(acp::TextContent {
        text: message.to_string(),
        annotations: None,
    });
    self.send(vec![text_block], cx)
}
1195
1196 pub fn send(
1197 &mut self,
1198 message: Vec<acp::ContentBlock>,
1199 cx: &mut Context<Self>,
1200 ) -> BoxFuture<'static, Result<()>> {
1201 let block = ContentBlock::new_combined(
1202 message.clone(),
1203 self.project.read(cx).languages().clone(),
1204 cx,
1205 );
1206 let request = acp::PromptRequest {
1207 prompt: message.clone(),
1208 session_id: self.session_id.clone(),
1209 };
1210 let git_store = self.project.read(cx).git_store().clone();
1211
1212 let message_id = if self
1213 .connection
1214 .session_editor(&self.session_id, cx)
1215 .is_some()
1216 {
1217 Some(UserMessageId::new())
1218 } else {
1219 None
1220 };
1221 self.push_entry(
1222 AgentThreadEntry::UserMessage(UserMessage {
1223 id: message_id.clone(),
1224 content: block,
1225 chunks: message,
1226 checkpoint: None,
1227 }),
1228 cx,
1229 );
1230
1231 self.run_turn(cx, async move |this, cx| {
1232 let old_checkpoint = git_store
1233 .update(cx, |git, cx| git.checkpoint(cx))?
1234 .await
1235 .context("failed to get old checkpoint")
1236 .log_err();
1237 this.update(cx, |this, cx| {
1238 if let Some((_ix, message)) = this.last_user_message() {
1239 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1240 git_checkpoint,
1241 show: false,
1242 });
1243 }
1244 this.connection.prompt(message_id, request, cx)
1245 })?
1246 .await
1247 })
1248 }
1249
1250 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1251 self.run_turn(cx, async move |this, cx| {
1252 this.update(cx, |this, cx| {
1253 this.connection
1254 .resume(&this.session_id, cx)
1255 .map(|resume| resume.run(cx))
1256 })?
1257 .context("resuming a session is not supported")?
1258 .await
1259 })
1260 }
1261
    /// Runs `f` as the thread's active turn and returns a future that resolves
    /// once the turn has finished.
    ///
    /// Completed plan entries are cleared first and any turn already in flight
    /// is canceled; `f` only starts after that cancellation has finished. The
    /// task driving `f` is stored in `send_task`.
    ///
    /// The returned future resolves to `Err` if refreshing the last checkpoint
    /// fails, or if `f` returned an error (in which case
    /// [`AcpThreadEvent::Error`] is emitted). In every other case, including
    /// the result sender having been dropped, [`AcpThreadEvent::Stopped`] is
    /// emitted and the future resolves to `Ok(())`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `tx` carries the outcome of `f` from the send task to the future
        // returned to the caller.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn wind down before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err` here means the sender was dropped without sending.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1315
1316 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1317 let Some(send_task) = self.send_task.take() else {
1318 return Task::ready(());
1319 };
1320
1321 for entry in self.entries.iter_mut() {
1322 if let AgentThreadEntry::ToolCall(call) = entry {
1323 let cancel = matches!(
1324 call.status,
1325 ToolCallStatus::Pending
1326 | ToolCallStatus::WaitingForConfirmation { .. }
1327 | ToolCallStatus::InProgress
1328 );
1329
1330 if cancel {
1331 call.status = ToolCallStatus::Canceled;
1332 }
1333 }
1334 }
1335
1336 self.connection.cancel(&self.session_id, cx);
1337
1338 // Wait for the send task to complete
1339 cx.foreground_executor().spawn(send_task)
1340 }
1341
1342 /// Rewinds this thread to before the entry at `index`, removing it and all
1343 /// subsequent entries while reverting any changes made from that point.
1344 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1345 let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1346 return Task::ready(Err(anyhow!("not supported")));
1347 };
1348 let Some(message) = self.user_message(&id) else {
1349 return Task::ready(Err(anyhow!("message not found")));
1350 };
1351
1352 let checkpoint = message
1353 .checkpoint
1354 .as_ref()
1355 .map(|c| c.git_checkpoint.clone());
1356
1357 let git_store = self.project.read(cx).git_store().clone();
1358 cx.spawn(async move |this, cx| {
1359 if let Some(checkpoint) = checkpoint {
1360 git_store
1361 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1362 .await?;
1363 }
1364
1365 cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1366 .await?;
1367 this.update(cx, |this, cx| {
1368 if let Some((ix, _)) = this.user_message_mut(&id) {
1369 let range = ix..this.entries.len();
1370 this.entries.truncate(ix);
1371 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1372 }
1373 })
1374 })
1375 }
1376
1377 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1378 let git_store = self.project.read(cx).git_store().clone();
1379
1380 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1381 if let Some(checkpoint) = message.checkpoint.as_ref() {
1382 checkpoint.git_checkpoint.clone()
1383 } else {
1384 return Task::ready(Ok(()));
1385 }
1386 } else {
1387 return Task::ready(Ok(()));
1388 };
1389
1390 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1391 cx.spawn(async move |this, cx| {
1392 let new_checkpoint = new_checkpoint
1393 .await
1394 .context("failed to get new checkpoint")
1395 .log_err();
1396 if let Some(new_checkpoint) = new_checkpoint {
1397 let equal = git_store
1398 .update(cx, |git, cx| {
1399 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1400 })?
1401 .await
1402 .unwrap_or(true);
1403 this.update(cx, |this, cx| {
1404 let (ix, message) = this.last_user_message().context("no user message")?;
1405 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1406 checkpoint.show = !equal;
1407 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1408 anyhow::Ok(())
1409 })??;
1410 }
1411
1412 Ok(())
1413 })
1414 }
1415
1416 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1417 self.entries
1418 .iter_mut()
1419 .enumerate()
1420 .rev()
1421 .find_map(|(ix, entry)| {
1422 if let AgentThreadEntry::UserMessage(message) = entry {
1423 Some((ix, message))
1424 } else {
1425 None
1426 }
1427 })
1428 }
1429
1430 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1431 self.entries.iter().find_map(|entry| {
1432 if let AgentThreadEntry::UserMessage(message) = entry {
1433 if message.id.as_ref() == Some(&id) {
1434 Some(message)
1435 } else {
1436 None
1437 }
1438 } else {
1439 None
1440 }
1441 })
1442 }
1443
1444 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1445 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1446 if let AgentThreadEntry::UserMessage(message) = entry {
1447 if message.id.as_ref() == Some(&id) {
1448 Some((ix, message))
1449 } else {
1450 None
1451 }
1452 } else {
1453 None
1454 }
1455 })
1456 }
1457
1458 pub fn read_text_file(
1459 &self,
1460 path: PathBuf,
1461 line: Option<u32>,
1462 limit: Option<u32>,
1463 reuse_shared_snapshot: bool,
1464 cx: &mut Context<Self>,
1465 ) -> Task<Result<String>> {
1466 let project = self.project.clone();
1467 let action_log = self.action_log.clone();
1468 cx.spawn(async move |this, cx| {
1469 let load = project.update(cx, |project, cx| {
1470 let path = project
1471 .project_path_for_absolute_path(&path, cx)
1472 .context("invalid path")?;
1473 anyhow::Ok(project.open_buffer(path, cx))
1474 });
1475 let buffer = load??.await?;
1476
1477 let snapshot = if reuse_shared_snapshot {
1478 this.read_with(cx, |this, _| {
1479 this.shared_buffers.get(&buffer.clone()).cloned()
1480 })
1481 .log_err()
1482 .flatten()
1483 } else {
1484 None
1485 };
1486
1487 let snapshot = if let Some(snapshot) = snapshot {
1488 snapshot
1489 } else {
1490 action_log.update(cx, |action_log, cx| {
1491 action_log.buffer_read(buffer.clone(), cx);
1492 })?;
1493 project.update(cx, |project, cx| {
1494 let position = buffer
1495 .read(cx)
1496 .snapshot()
1497 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1498 project.set_agent_location(
1499 Some(AgentLocation {
1500 buffer: buffer.downgrade(),
1501 position,
1502 }),
1503 cx,
1504 );
1505 })?;
1506
1507 buffer.update(cx, |buffer, _| buffer.snapshot())?
1508 };
1509
1510 this.update(cx, |this, _| {
1511 let text = snapshot.text();
1512 this.shared_buffers.insert(buffer.clone(), snapshot);
1513 if line.is_none() && limit.is_none() {
1514 return Ok(text);
1515 }
1516 let limit = limit.unwrap_or(u32::MAX) as usize;
1517 let Some(line) = line else {
1518 return Ok(text.lines().take(limit).collect::<String>());
1519 };
1520
1521 let count = text.lines().count();
1522 if count < line as usize {
1523 anyhow::bail!("There are only {} lines", count);
1524 }
1525 Ok(text
1526 .lines()
1527 .skip(line as usize + 1)
1528 .take(limit)
1529 .collect::<String>())
1530 })?
1531 })
1532 }
1533
1534 pub fn write_text_file(
1535 &self,
1536 path: PathBuf,
1537 content: String,
1538 cx: &mut Context<Self>,
1539 ) -> Task<Result<()>> {
1540 let project = self.project.clone();
1541 let action_log = self.action_log.clone();
1542 cx.spawn(async move |this, cx| {
1543 let load = project.update(cx, |project, cx| {
1544 let path = project
1545 .project_path_for_absolute_path(&path, cx)
1546 .context("invalid path")?;
1547 anyhow::Ok(project.open_buffer(path, cx))
1548 });
1549 let buffer = load??.await?;
1550 let snapshot = this.update(cx, |this, cx| {
1551 this.shared_buffers
1552 .get(&buffer)
1553 .cloned()
1554 .unwrap_or_else(|| buffer.read(cx).snapshot())
1555 })?;
1556 let edits = cx
1557 .background_executor()
1558 .spawn(async move {
1559 let old_text = snapshot.text();
1560 text_diff(old_text.as_str(), &content)
1561 .into_iter()
1562 .map(|(range, replacement)| {
1563 (
1564 snapshot.anchor_after(range.start)
1565 ..snapshot.anchor_before(range.end),
1566 replacement,
1567 )
1568 })
1569 .collect::<Vec<_>>()
1570 })
1571 .await;
1572 cx.update(|cx| {
1573 project.update(cx, |project, cx| {
1574 project.set_agent_location(
1575 Some(AgentLocation {
1576 buffer: buffer.downgrade(),
1577 position: edits
1578 .last()
1579 .map(|(range, _)| range.end)
1580 .unwrap_or(Anchor::MIN),
1581 }),
1582 cx,
1583 );
1584 });
1585
1586 action_log.update(cx, |action_log, cx| {
1587 action_log.buffer_read(buffer.clone(), cx);
1588 });
1589 buffer.update(cx, |buffer, cx| {
1590 buffer.edit(edits, None, cx);
1591 });
1592 action_log.update(cx, |action_log, cx| {
1593 action_log.buffer_edited(buffer.clone(), cx);
1594 });
1595 })?;
1596 project
1597 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1598 .await
1599 })
1600 }
1601
1602 pub fn to_markdown(&self, cx: &App) -> String {
1603 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1604 }
1605
1606 pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1607 cx.emit(AcpThreadEvent::ServerExited(status));
1608 }
1609}
1610
1611fn markdown_for_raw_output(
1612 raw_output: &serde_json::Value,
1613 language_registry: &Arc<LanguageRegistry>,
1614 cx: &mut App,
1615) -> Option<Entity<Markdown>> {
1616 match raw_output {
1617 serde_json::Value::Null => None,
1618 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1619 Markdown::new(
1620 value.to_string().into(),
1621 Some(language_registry.clone()),
1622 None,
1623 cx,
1624 )
1625 })),
1626 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1627 Markdown::new(
1628 value.to_string().into(),
1629 Some(language_registry.clone()),
1630 None,
1631 cx,
1632 )
1633 })),
1634 serde_json::Value::String(value) => Some(cx.new(|cx| {
1635 Markdown::new(
1636 value.clone().into(),
1637 Some(language_registry.clone()),
1638 None,
1639 cx,
1640 )
1641 })),
1642 value => Some(cx.new(|cx| {
1643 Markdown::new(
1644 format!("```json\n{}\n```", value).into(),
1645 Some(language_registry.clone()),
1646 None,
1647 cx,
1648 )
1649 })),
1650 }
1651}
1652
1653#[cfg(test)]
1654mod tests {
1655 use super::*;
1656 use anyhow::anyhow;
1657 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1658 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1659 use indoc::indoc;
1660 use project::{FakeFs, Fs};
1661 use rand::Rng as _;
1662 use serde_json::json;
1663 use settings::SettingsStore;
1664 use smol::stream::StreamExt as _;
1665 use std::{
1666 any::Any,
1667 cell::RefCell,
1668 path::Path,
1669 rc::Rc,
1670 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1671 time::Duration,
1672 };
1673 use util::path;
1674
1675 fn init_test(cx: &mut TestAppContext) {
1676 env_logger::try_init().ok();
1677 cx.update(|cx| {
1678 let settings_store = SettingsStore::test(cx);
1679 cx.set_global(settings_store);
1680 Project::init_settings(cx);
1681 language::init(cx);
1682 });
1683 }
1684
    /// Checks how `push_user_content_block` groups chunks into entries:
    /// consecutive user chunks are merged into one message (which adopts the
    /// id of the later chunk), while a user chunk that follows an assistant
    /// message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the chunk was appended and the id adopted.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1776
    /// Checks that consecutive agent thought chunks are concatenated into a
    /// single `<thinking>` section of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1839
    /// Checks that a file write by the agent is merged with an edit the user
    /// made after the agent read the file: both the user's insertion and the
    /// agent's additions end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the fake agent has read the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then edit it as the user
        // before the agent's write is applied.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1918
    /// Checks that canceling a turn marks an in-progress tool call as
    /// canceled, and that a later update from the agent can still move that
    /// tool call to completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2018
    /// Checks that an edit tool call that arrives already completed is not
    /// counted as a pending edit tool call.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports a completed edit tool call carrying a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2072
    /// Exercises checkpoints end to end: a user message is rendered with a
    /// checkpoint only when the turn changed files, and rewinding to a message
    /// truncates the history and restores the files to that checkpoint.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, the fake agent creates a new file on
        // every prompt; it always echoes the prompt back in upper case.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn creates a file, so its checkpoint is shown.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        // Second turn creates another file, so its checkpoint is shown too.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                // Entry 2 is the second user message ("ipsum").
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2246
2247 async fn run_until_first_tool_call(
2248 thread: &Entity<AcpThread>,
2249 cx: &mut TestAppContext,
2250 ) -> usize {
2251 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2252
2253 let subscription = cx.update(|cx| {
2254 cx.subscribe(thread, move |thread, _, cx| {
2255 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2256 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2257 return tx.try_send(ix).unwrap();
2258 }
2259 }
2260 })
2261 });
2262
2263 select! {
2264 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2265 panic!("Timeout waiting for tool call")
2266 }
2267 ix = rx.next().fuse() => {
2268 drop(subscription);
2269 ix.unwrap()
2270 }
2271 }
2272 }
2273
    /// In-memory `AgentConnection` used by the tests in this module.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        // Authentication methods this connection accepts.
        auth_methods: Vec<acp::AuthMethod>,
        // Threads created through `new_thread`, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        // Optional handler that produces the response to each prompt.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2289
2290 impl FakeAgentConnection {
2291 fn new() -> Self {
2292 Self {
2293 auth_methods: Vec::new(),
2294 on_user_message: None,
2295 sessions: Arc::default(),
2296 }
2297 }
2298
2299 #[expect(unused)]
2300 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2301 self.auth_methods = auth_methods;
2302 self
2303 }
2304
2305 fn on_user_message(
2306 mut self,
2307 handler: impl Fn(
2308 acp::PromptRequest,
2309 WeakEntity<AcpThread>,
2310 AsyncApp,
2311 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2312 + 'static,
2313 ) -> Self {
2314 self.on_user_message.replace(Rc::new(handler));
2315 self
2316 }
2317 }
2318
2319 impl AgentConnection for FakeAgentConnection {
2320 fn auth_methods(&self) -> &[acp::AuthMethod] {
2321 &self.auth_methods
2322 }
2323
2324 fn new_thread(
2325 self: Rc<Self>,
2326 project: Entity<Project>,
2327 _cwd: &Path,
2328 cx: &mut App,
2329 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2330 let session_id = acp::SessionId(
2331 rand::thread_rng()
2332 .sample_iter(&rand::distributions::Alphanumeric)
2333 .take(7)
2334 .map(char::from)
2335 .collect::<String>()
2336 .into(),
2337 );
2338 let thread =
2339 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2340 self.sessions.lock().insert(session_id, thread.downgrade());
2341 Task::ready(Ok(thread))
2342 }
2343
2344 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2345 if self.auth_methods().iter().any(|m| m.id == method) {
2346 Task::ready(Ok(()))
2347 } else {
2348 Task::ready(Err(anyhow!("Invalid Auth Method")))
2349 }
2350 }
2351
2352 fn prompt(
2353 &self,
2354 _id: Option<UserMessageId>,
2355 params: acp::PromptRequest,
2356 cx: &mut App,
2357 ) -> Task<gpui::Result<acp::PromptResponse>> {
2358 let sessions = self.sessions.lock();
2359 let thread = sessions.get(¶ms.session_id).unwrap();
2360 if let Some(handler) = &self.on_user_message {
2361 let handler = handler.clone();
2362 let thread = thread.clone();
2363 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2364 } else {
2365 Task::ready(Ok(acp::PromptResponse {
2366 stop_reason: acp::StopReason::EndTurn,
2367 }))
2368 }
2369 }
2370
2371 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2372 let sessions = self.sessions.lock();
2373 let thread = sessions.get(&session_id).unwrap().clone();
2374
2375 cx.spawn(async move |cx| {
2376 thread
2377 .update(cx, |thread, cx| thread.cancel(cx))
2378 .unwrap()
2379 .await
2380 })
2381 .detach();
2382 }
2383
2384 fn session_editor(
2385 &self,
2386 session_id: &acp::SessionId,
2387 _cx: &mut App,
2388 ) -> Option<Rc<dyn AgentSessionEditor>> {
2389 Some(Rc::new(FakeAgentSessionEditor {
2390 _session_id: session_id.clone(),
2391 }))
2392 }
2393
2394 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2395 self
2396 }
2397 }
2398
    /// Session editor handed out by `FakeAgentConnection`; it performs no work.
    struct FakeAgentSessionEditor {
        // Kept only to mirror a real editor; never read.
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// Reports success immediately without touching any state.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2408}