acp: Fix issue with mentions when `embedded_context` is set to `false` (#42260)

Bennet Bo Fenner created

Release Notes:

- acp: Fixed an issue where Zed would not respect
`PromptCapabilities::embedded_context`

Change summary

crates/agent_ui/src/acp/message_editor.rs | 305 ++++++++++++++++--------
1 file changed, 199 insertions(+), 106 deletions(-)

Detailed changes

crates/agent_ui/src/acp/message_editor.rs 🔗

@@ -356,7 +356,7 @@ impl MessageEditor {
 
         let task = match mention_uri.clone() {
             MentionUri::Fetch { url } => self.confirm_mention_for_fetch(url, cx),
-            MentionUri::Directory { .. } => Task::ready(Ok(Mention::UriOnly)),
+            MentionUri::Directory { .. } => Task::ready(Ok(Mention::Link)),
             MentionUri::Thread { id, .. } => self.confirm_mention_for_thread(id, cx),
             MentionUri::TextThread { path, .. } => self.confirm_mention_for_text_thread(path, cx),
             MentionUri::File { abs_path } => self.confirm_mention_for_file(abs_path, cx),
@@ -373,7 +373,6 @@ impl MessageEditor {
                 )))
             }
             MentionUri::Selection { .. } => {
-                // Handled elsewhere
                 debug_panic!("unexpected selection URI");
                 Task::ready(Err(anyhow!("unexpected selection URI")))
             }
@@ -704,13 +703,11 @@ impl MessageEditor {
             return Task::ready(Err(err));
         }
 
-        let contents = self.mention_set.contents(
-            &self.prompt_capabilities.borrow(),
-            full_mention_content,
-            self.project.clone(),
-            cx,
-        );
+        let contents = self
+            .mention_set
+            .contents(full_mention_content, self.project.clone(), cx);
         let editor = self.editor.clone();
+        let supports_embedded_context = self.prompt_capabilities.borrow().embedded_context;
 
         cx.spawn(async move |_, cx| {
             let contents = contents.await?;
@@ -741,18 +738,32 @@ impl MessageEditor {
                                 tracked_buffers,
                             } => {
                                 all_tracked_buffers.extend(tracked_buffers.iter().cloned());
-                                acp::ContentBlock::Resource(acp::EmbeddedResource {
-                                    annotations: None,
-                                    resource: acp::EmbeddedResourceResource::TextResourceContents(
-                                        acp::TextResourceContents {
-                                            mime_type: None,
-                                            text: content.clone(),
-                                            uri: uri.to_uri().to_string(),
-                                            meta: None,
-                                        },
-                                    ),
-                                    meta: None,
-                                })
+                                if supports_embedded_context {
+                                    acp::ContentBlock::Resource(acp::EmbeddedResource {
+                                        annotations: None,
+                                        resource:
+                                            acp::EmbeddedResourceResource::TextResourceContents(
+                                                acp::TextResourceContents {
+                                                    mime_type: None,
+                                                    text: content.clone(),
+                                                    uri: uri.to_uri().to_string(),
+                                                    meta: None,
+                                                },
+                                            ),
+                                        meta: None,
+                                    })
+                                } else {
+                                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                                        name: uri.name(),
+                                        uri: uri.to_uri().to_string(),
+                                        annotations: None,
+                                        description: None,
+                                        mime_type: None,
+                                        size: None,
+                                        title: None,
+                                        meta: None,
+                                    })
+                                }
                             }
                             Mention::Image(mention_image) => {
                                 let uri = match uri {
@@ -774,18 +785,16 @@ impl MessageEditor {
                                     meta: None,
                                 })
                             }
-                            Mention::UriOnly => {
-                                acp::ContentBlock::ResourceLink(acp::ResourceLink {
-                                    name: uri.name(),
-                                    uri: uri.to_uri().to_string(),
-                                    annotations: None,
-                                    description: None,
-                                    mime_type: None,
-                                    size: None,
-                                    title: None,
-                                    meta: None,
-                                })
-                            }
+                            Mention::Link => acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                                name: uri.name(),
+                                uri: uri.to_uri().to_string(),
+                                annotations: None,
+                                description: None,
+                                mime_type: None,
+                                size: None,
+                                title: None,
+                                meta: None,
+                            }),
                         };
                         chunks.push(chunk);
                         ix = crease_range.end;
@@ -1114,7 +1123,7 @@ impl MessageEditor {
                         let start = text.len();
                         write!(&mut text, "{}", mention_uri.as_link()).ok();
                         let end = text.len();
-                        mentions.push((start..end, mention_uri, Mention::UriOnly));
+                        mentions.push((start..end, mention_uri, Mention::Link));
                     }
                 }
                 acp::ContentBlock::Image(acp::ImageContent {
@@ -1520,7 +1529,7 @@ pub enum Mention {
         tracked_buffers: Vec<Entity<Buffer>>,
     },
     Image(MentionImage),
-    UriOnly,
+    Link,
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
@@ -1537,21 +1546,10 @@ pub struct MentionSet {
 impl MentionSet {
     fn contents(
         &self,
-        prompt_capabilities: &acp::PromptCapabilities,
         full_mention_content: bool,
         project: Entity<Project>,
         cx: &mut App,
     ) -> Task<Result<HashMap<CreaseId, (MentionUri, Mention)>>> {
-        if !prompt_capabilities.embedded_context {
-            let mentions = self
-                .mentions
-                .iter()
-                .map(|(crease_id, (uri, _))| (*crease_id, (uri.clone(), Mention::UriOnly)))
-                .collect();
-
-            return Task::ready(Ok(mentions));
-        }
-
         let mentions = self.mentions.clone();
         cx.spawn(async move |cx| {
             let mut contents = HashMap::default();
@@ -2285,21 +2283,11 @@ mod tests {
             assert_eq!(fold_ranges(editor, cx).len(), 1);
         });
 
-        let all_prompt_capabilities = acp::PromptCapabilities {
-            image: true,
-            audio: true,
-            embedded_context: true,
-            meta: None,
-        };
-
         let contents = message_editor
             .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &all_prompt_capabilities,
-                    false,
-                    project.clone(),
-                    cx,
-                )
+                message_editor
+                    .mention_set()
+                    .contents(false, project.clone(), cx)
             })
             .await
             .unwrap()
@@ -2317,30 +2305,6 @@ mod tests {
             );
         }
 
-        let contents = message_editor
-            .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &acp::PromptCapabilities::default(),
-                    false,
-                    project.clone(),
-                    cx,
-                )
-            })
-            .await
-            .unwrap()
-            .into_values()
-            .collect::<Vec<_>>();
-
-        {
-            let [(uri, Mention::UriOnly)] = contents.as_slice() else {
-                panic!("Unexpected mentions");
-            };
-            pretty_assertions::assert_eq!(
-                uri,
-                &MentionUri::parse(&url_one, PathStyle::local()).unwrap()
-            );
-        }
-
         cx.simulate_input(" ");
 
         editor.update(&mut cx, |editor, cx| {
@@ -2376,12 +2340,9 @@ mod tests {
 
         let contents = message_editor
             .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &all_prompt_capabilities,
-                    false,
-                    project.clone(),
-                    cx,
-                )
+                message_editor
+                    .mention_set()
+                    .contents(false, project.clone(), cx)
             })
             .await
             .unwrap()
@@ -2502,12 +2463,9 @@ mod tests {
 
         let contents = message_editor
             .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &all_prompt_capabilities,
-                    false,
-                    project.clone(),
-                    cx,
-                )
+                message_editor
+                    .mention_set()
+                    .contents(false, project.clone(), cx)
             })
             .await
             .unwrap()
@@ -2553,12 +2511,9 @@ mod tests {
         // Getting the message contents fails
         message_editor
             .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &all_prompt_capabilities,
-                    false,
-                    project.clone(),
-                    cx,
-                )
+                message_editor
+                    .mention_set()
+                    .contents(false, project.clone(), cx)
             })
             .await
             .expect_err("Should fail to load x.png");
@@ -2609,12 +2564,9 @@ mod tests {
         // Now getting the contents succeeds, because the invalid mention was removed
         let contents = message_editor
             .update(&mut cx, |message_editor, cx| {
-                message_editor.mention_set().contents(
-                    &all_prompt_capabilities,
-                    false,
-                    project.clone(),
-                    cx,
-                )
+                message_editor
+                    .mention_set()
+                    .contents(false, project.clone(), cx)
             })
             .await
             .unwrap();
@@ -2896,6 +2848,147 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_editor_respects_embedded_context_capability(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+
+        let file_content = "fn main() { println!(\"Hello, world!\"); }\n";
+
+        fs.insert_tree(
+            "/project",
+            json!({
+                "src": {
+                    "main.rs": file_content,
+                }
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
+
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
+        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
+
+        let (message_editor, editor) = workspace.update_in(cx, |workspace, window, cx| {
+            let workspace_handle = cx.weak_entity();
+            let message_editor = cx.new(|cx| {
+                MessageEditor::new(
+                    workspace_handle,
+                    project.clone(),
+                    history_store.clone(),
+                    None,
+                    Default::default(),
+                    Default::default(),
+                    "Test Agent".into(),
+                    "Test",
+                    EditorMode::AutoHeight {
+                        max_lines: None,
+                        min_lines: 1,
+                    },
+                    window,
+                    cx,
+                )
+            });
+            workspace.active_pane().update(cx, |pane, cx| {
+                pane.add_item(
+                    Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))),
+                    true,
+                    true,
+                    None,
+                    window,
+                    cx,
+                );
+            });
+            message_editor.read(cx).focus_handle(cx).focus(window);
+            let editor = message_editor.read(cx).editor().clone();
+            (message_editor, editor)
+        });
+
+        cx.simulate_input("What is in @file main");
+
+        editor.update_in(cx, |editor, window, cx| {
+            assert!(editor.has_visible_completions_menu());
+            assert_eq!(editor.text(cx), "What is in @file main");
+            editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx);
+        });
+
+        let content = message_editor
+            .update(cx, |editor, cx| editor.contents(false, cx))
+            .await
+            .unwrap()
+            .0;
+
+        let main_rs_uri = if cfg!(windows) {
+            "file:///C:/project/src/main.rs".to_string()
+        } else {
+            "file:///project/src/main.rs".to_string()
+        };
+
+        // When embedded context is `false` we should get a resource link
+        pretty_assertions::assert_eq!(
+            content,
+            vec![
+                acp::ContentBlock::Text(acp::TextContent {
+                    text: "What is in ".to_string(),
+                    annotations: None,
+                    meta: None
+                }),
+                acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                    uri: main_rs_uri.clone(),
+                    name: "main.rs".to_string(),
+                    annotations: None,
+                    meta: None,
+                    description: None,
+                    mime_type: None,
+                    size: None,
+                    title: None,
+                })
+            ]
+        );
+
+        message_editor.update(cx, |editor, _cx| {
+            editor.prompt_capabilities.replace(acp::PromptCapabilities {
+                embedded_context: true,
+                ..Default::default()
+            })
+        });
+
+        let content = message_editor
+            .update(cx, |editor, cx| editor.contents(false, cx))
+            .await
+            .unwrap()
+            .0;
+
+        // When embedded context is `true` we should get a resource
+        pretty_assertions::assert_eq!(
+            content,
+            vec![
+                acp::ContentBlock::Text(acp::TextContent {
+                    text: "What is in ".to_string(),
+                    annotations: None,
+                    meta: None
+                }),
+                acp::ContentBlock::Resource(acp::EmbeddedResource {
+                    resource: acp::EmbeddedResourceResource::TextResourceContents(
+                        acp::TextResourceContents {
+                            text: file_content.to_string(),
+                            uri: main_rs_uri,
+                            mime_type: None,
+                            meta: None
+                        }
+                    ),
+                    annotations: None,
+                    meta: None
+                })
+            ]
+        );
+    }
+
     #[gpui::test]
     async fn test_autoscroll_after_insert_selections(cx: &mut TestAppContext) {
         init_test(cx);