message editor: Only allow types of content the agent can handle (#36616)

Agus Zubiaga created

Uses the new
[`acp::PromptCapabilities`](https://github.com/zed-industries/agent-client-protocol/blob/a39b7f635d67528f0a4e05e086ab283b9fc5cb93/rust/agent.rs#L194-L215)
to disable non-file mentions and images for agents that don't support
them.

Release Notes:

- N/A

Change summary

Cargo.lock                                     |   4 
Cargo.toml                                     |   2 
crates/acp_thread/src/acp_thread.rs            |   8 +
crates/acp_thread/src/connection.rs            |  10 +
crates/agent2/src/agent.rs                     |   8 +
crates/agent_servers/src/acp/v0.rs             |   8 +
crates/agent_servers/src/acp/v1.rs             |   6 
crates/agent_servers/src/claude.rs             |   8 +
crates/agent_ui/src/acp/completion_provider.rs | 124 +++++++++++++------
crates/agent_ui/src/acp/message_editor.rs      |  93 +++++++++++++-
crates/agent_ui/src/acp/thread_view.rs         |  13 ++
11 files changed, 234 insertions(+), 50 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -171,9 +171,9 @@ dependencies = [
 
 [[package]]
 name = "agent-client-protocol"
-version = "0.0.26"
+version = "0.0.28"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "160971bb53ca0b2e70ebc857c21e24eb448745f1396371015f4c59e9a9e51ed0"
+checksum = "4c887e795097665ab95119580534e7cc1335b59e1a7fec296943e534b970f4ed"
 dependencies = [
  "anyhow",
  "futures 0.3.31",

Cargo.toml 🔗

@@ -423,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
 #
 
 agentic-coding-protocol = "0.0.10"
-agent-client-protocol = "0.0.26"
+agent-client-protocol = "0.0.28"
 aho-corasick = "1.1"
 alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
 any_vec = "0.14"

crates/acp_thread/src/acp_thread.rs 🔗

@@ -2598,6 +2598,14 @@ mod tests {
             }
         }
 
+        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+            acp::PromptCapabilities {
+                image: true,
+                audio: true,
+                embedded_context: true,
+            }
+        }
+
         fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
             let sessions = self.sessions.lock();
             let thread = sessions.get(session_id).unwrap().clone();

crates/acp_thread/src/connection.rs 🔗

@@ -38,6 +38,8 @@ pub trait AgentConnection {
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>>;
 
+    fn prompt_capabilities(&self) -> acp::PromptCapabilities;
+
     fn resume(
         &self,
         _session_id: &acp::SessionId,
@@ -334,6 +336,14 @@ mod test_support {
             Task::ready(Ok(thread))
         }
 
+        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+            acp::PromptCapabilities {
+                image: true,
+                audio: true,
+                embedded_context: true,
+            }
+        }
+
         fn authenticate(
             &self,
             _method_id: acp::AuthMethodId,

crates/agent2/src/agent.rs 🔗

@@ -913,6 +913,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         })
     }
 
+    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+        acp::PromptCapabilities {
+            image: true,
+            audio: false,
+            embedded_context: true,
+        }
+    }
+
     fn resume(
         &self,
         session_id: &acp::SessionId,

crates/agent_servers/src/acp/v0.rs 🔗

@@ -498,6 +498,14 @@ impl AgentConnection for AcpConnection {
         })
     }
 
+    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+        acp::PromptCapabilities {
+            image: false,
+            audio: false,
+            embedded_context: false,
+        }
+    }
+
     fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
         let task = self
             .connection

crates/agent_servers/src/acp/v1.rs 🔗

@@ -21,6 +21,7 @@ pub struct AcpConnection {
     connection: Rc<acp::ClientSideConnection>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
     auth_methods: Vec<acp::AuthMethod>,
+    prompt_capabilities: acp::PromptCapabilities,
     _io_task: Task<Result<()>>,
 }
 
@@ -119,6 +120,7 @@ impl AcpConnection {
             connection: connection.into(),
             server_name,
             sessions,
+            prompt_capabilities: response.agent_capabilities.prompt_capabilities,
             _io_task: io_task,
         })
     }
@@ -206,6 +208,10 @@ impl AgentConnection for AcpConnection {
         })
     }
 
+    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+        self.prompt_capabilities
+    }
+
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
         let conn = self.connection.clone();
         let params = acp::CancelNotification {

crates/agent_servers/src/claude.rs 🔗

@@ -319,6 +319,14 @@ impl AgentConnection for ClaudeAgentConnection {
         cx.foreground_executor().spawn(async move { end_rx.await? })
     }
 
+    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+        acp::PromptCapabilities {
+            image: true,
+            audio: false,
+            embedded_context: true,
+        }
+    }
+
     fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
         let sessions = self.sessions.borrow();
         let Some(session) = sessions.get(session_id) else {

crates/agent_ui/src/acp/completion_provider.rs 🔗

@@ -1,8 +1,11 @@
+use std::cell::Cell;
 use std::ops::Range;
+use std::rc::Rc;
 use std::sync::Arc;
 use std::sync::atomic::AtomicBool;
 
 use acp_thread::MentionUri;
+use agent_client_protocol as acp;
 use agent2::{HistoryEntry, HistoryStore};
 use anyhow::Result;
 use editor::{CompletionProvider, Editor, ExcerptId};
@@ -63,6 +66,7 @@ pub struct ContextPickerCompletionProvider {
     workspace: WeakEntity<Workspace>,
     history_store: Entity<HistoryStore>,
     prompt_store: Option<Entity<PromptStore>>,
+    prompt_capabilities: Rc<Cell<acp::PromptCapabilities>>,
 }
 
 impl ContextPickerCompletionProvider {
@@ -71,12 +75,14 @@ impl ContextPickerCompletionProvider {
         workspace: WeakEntity<Workspace>,
         history_store: Entity<HistoryStore>,
         prompt_store: Option<Entity<PromptStore>>,
+        prompt_capabilities: Rc<Cell<acp::PromptCapabilities>>,
     ) -> Self {
         Self {
             message_editor,
             workspace,
             history_store,
             prompt_store,
+            prompt_capabilities,
         }
     }
 
@@ -544,17 +550,19 @@ impl ContextPickerCompletionProvider {
                 }),
         );
 
-        const RECENT_COUNT: usize = 2;
-        let threads = self
-            .history_store
-            .read(cx)
-            .recently_opened_entries(cx)
-            .into_iter()
-            .filter(|thread| !mentions.contains(&thread.mention_uri()))
-            .take(RECENT_COUNT)
-            .collect::<Vec<_>>();
-
-        recent.extend(threads.into_iter().map(Match::RecentThread));
+        if self.prompt_capabilities.get().embedded_context {
+            const RECENT_COUNT: usize = 2;
+            let threads = self
+                .history_store
+                .read(cx)
+                .recently_opened_entries(cx)
+                .into_iter()
+                .filter(|thread| !mentions.contains(&thread.mention_uri()))
+                .take(RECENT_COUNT)
+                .collect::<Vec<_>>();
+
+            recent.extend(threads.into_iter().map(Match::RecentThread));
+        }
 
         recent
     }
@@ -564,11 +572,17 @@ impl ContextPickerCompletionProvider {
         workspace: &Entity<Workspace>,
         cx: &mut App,
     ) -> Vec<ContextPickerEntry> {
-        let mut entries = vec![
-            ContextPickerEntry::Mode(ContextPickerMode::File),
-            ContextPickerEntry::Mode(ContextPickerMode::Symbol),
-            ContextPickerEntry::Mode(ContextPickerMode::Thread),
-        ];
+        let embedded_context = self.prompt_capabilities.get().embedded_context;
+        let mut entries = if embedded_context {
+            vec![
+                ContextPickerEntry::Mode(ContextPickerMode::File),
+                ContextPickerEntry::Mode(ContextPickerMode::Symbol),
+                ContextPickerEntry::Mode(ContextPickerMode::Thread),
+            ]
+        } else {
+            // File is always available, but we don't need a mode entry
+            vec![]
+        };
 
         let has_selection = workspace
             .read(cx)
@@ -583,11 +597,13 @@ impl ContextPickerCompletionProvider {
             ));
         }
 
-        if self.prompt_store.is_some() {
-            entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
-        }
+        if embedded_context {
+            if self.prompt_store.is_some() {
+                entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
+            }
 
-        entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch));
+            entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch));
+        }
 
         entries
     }
@@ -625,7 +641,11 @@ impl CompletionProvider for ContextPickerCompletionProvider {
             let offset_to_line = buffer.point_to_offset(line_start);
             let mut lines = buffer.text_for_range(line_start..position).lines();
             let line = lines.next()?;
-            MentionCompletion::try_parse(line, offset_to_line)
+            MentionCompletion::try_parse(
+                self.prompt_capabilities.get().embedded_context,
+                line,
+                offset_to_line,
+            )
         });
         let Some(state) = state else {
             return Task::ready(Ok(Vec::new()));
@@ -745,12 +765,16 @@ impl CompletionProvider for ContextPickerCompletionProvider {
         let offset_to_line = buffer.point_to_offset(line_start);
         let mut lines = buffer.text_for_range(line_start..position).lines();
         if let Some(line) = lines.next() {
-            MentionCompletion::try_parse(line, offset_to_line)
-                .map(|completion| {
-                    completion.source_range.start <= offset_to_line + position.column as usize
-                        && completion.source_range.end >= offset_to_line + position.column as usize
-                })
-                .unwrap_or(false)
+            MentionCompletion::try_parse(
+                self.prompt_capabilities.get().embedded_context,
+                line,
+                offset_to_line,
+            )
+            .map(|completion| {
+                completion.source_range.start <= offset_to_line + position.column as usize
+                    && completion.source_range.end >= offset_to_line + position.column as usize
+            })
+            .unwrap_or(false)
         } else {
             false
         }
@@ -841,7 +865,7 @@ struct MentionCompletion {
 }
 
 impl MentionCompletion {
-    fn try_parse(line: &str, offset_to_line: usize) -> Option<Self> {
+    fn try_parse(allow_non_file_mentions: bool, line: &str, offset_to_line: usize) -> Option<Self> {
         let last_mention_start = line.rfind('@')?;
         if last_mention_start >= line.len() {
             return Some(Self::default());
@@ -865,7 +889,9 @@ impl MentionCompletion {
         if let Some(mode_text) = parts.next() {
             end += mode_text.len();
 
-            if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() {
+            if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok()
+                && (allow_non_file_mentions || matches!(parsed_mode, ContextPickerMode::File))
+            {
                 mode = Some(parsed_mode);
             } else {
                 argument = Some(mode_text.to_string());
@@ -898,10 +924,10 @@ mod tests {
 
     #[test]
     fn test_mention_completion_parse() {
-        assert_eq!(MentionCompletion::try_parse("Lorem Ipsum", 0), None);
+        assert_eq!(MentionCompletion::try_parse(true, "Lorem Ipsum", 0), None);
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @", 0),
+            MentionCompletion::try_parse(true, "Lorem @", 0),
             Some(MentionCompletion {
                 source_range: 6..7,
                 mode: None,
@@ -910,7 +936,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @file", 0),
+            MentionCompletion::try_parse(true, "Lorem @file", 0),
             Some(MentionCompletion {
                 source_range: 6..11,
                 mode: Some(ContextPickerMode::File),
@@ -919,7 +945,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @file ", 0),
+            MentionCompletion::try_parse(true, "Lorem @file ", 0),
             Some(MentionCompletion {
                 source_range: 6..12,
                 mode: Some(ContextPickerMode::File),
@@ -928,7 +954,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @file main.rs", 0),
+            MentionCompletion::try_parse(true, "Lorem @file main.rs", 0),
             Some(MentionCompletion {
                 source_range: 6..19,
                 mode: Some(ContextPickerMode::File),
@@ -937,7 +963,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @file main.rs ", 0),
+            MentionCompletion::try_parse(true, "Lorem @file main.rs ", 0),
             Some(MentionCompletion {
                 source_range: 6..19,
                 mode: Some(ContextPickerMode::File),
@@ -946,7 +972,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @file main.rs Ipsum", 0),
+            MentionCompletion::try_parse(true, "Lorem @file main.rs Ipsum", 0),
             Some(MentionCompletion {
                 source_range: 6..19,
                 mode: Some(ContextPickerMode::File),
@@ -955,7 +981,7 @@ mod tests {
         );
 
         assert_eq!(
-            MentionCompletion::try_parse("Lorem @main", 0),
+            MentionCompletion::try_parse(true, "Lorem @main", 0),
             Some(MentionCompletion {
                 source_range: 6..11,
                 mode: None,
@@ -963,6 +989,28 @@ mod tests {
             })
         );
 
-        assert_eq!(MentionCompletion::try_parse("test@", 0), None);
+        assert_eq!(MentionCompletion::try_parse(true, "test@", 0), None);
+
+        // Allowed non-file mentions
+
+        assert_eq!(
+            MentionCompletion::try_parse(true, "Lorem @symbol main", 0),
+            Some(MentionCompletion {
+                source_range: 6..18,
+                mode: Some(ContextPickerMode::Symbol),
+                argument: Some("main".to_string()),
+            })
+        );
+
+        // Disallowed non-file mentions
+
+        assert_eq!(
+            MentionCompletion::try_parse(false, "Lorem @symbol main", 0),
+            Some(MentionCompletion {
+                source_range: 6..18,
+                mode: None,
+                argument: Some("main".to_string()),
+            })
+        );
     }
 }

crates/agent_ui/src/acp/message_editor.rs 🔗

@@ -51,7 +51,10 @@ use ui::{
 };
 use url::Url;
 use util::ResultExt;
-use workspace::{Workspace, notifications::NotifyResultExt as _};
+use workspace::{
+    Toast, Workspace,
+    notifications::{NotificationId, NotifyResultExt as _},
+};
 use zed_actions::agent::Chat;
 
 const PARSE_SLASH_COMMAND_DEBOUNCE: Duration = Duration::from_millis(50);
@@ -64,6 +67,7 @@ pub struct MessageEditor {
     history_store: Entity<HistoryStore>,
     prompt_store: Option<Entity<PromptStore>>,
     prevent_slash_commands: bool,
+    prompt_capabilities: Rc<Cell<acp::PromptCapabilities>>,
     _subscriptions: Vec<Subscription>,
     _parse_slash_command_task: Task<()>,
 }
@@ -96,11 +100,13 @@ impl MessageEditor {
             },
             None,
         );
+        let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default()));
         let completion_provider = ContextPickerCompletionProvider::new(
             cx.weak_entity(),
             workspace.clone(),
             history_store.clone(),
             prompt_store.clone(),
+            prompt_capabilities.clone(),
         );
         let semantics_provider = Rc::new(SlashCommandSemanticsProvider {
             range: Cell::new(None),
@@ -158,6 +164,7 @@ impl MessageEditor {
             history_store,
             prompt_store,
             prevent_slash_commands,
+            prompt_capabilities,
             _subscriptions: subscriptions,
             _parse_slash_command_task: Task::ready(()),
         }
@@ -193,6 +200,10 @@ impl MessageEditor {
         .detach();
     }
 
+    pub fn set_prompt_capabilities(&mut self, capabilities: acp::PromptCapabilities) {
+        self.prompt_capabilities.set(capabilities);
+    }
+
     #[cfg(test)]
     pub(crate) fn editor(&self) -> &Entity<Editor> {
         &self.editor
@@ -230,7 +241,7 @@ impl MessageEditor {
         let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else {
             return Task::ready(());
         };
-        let Some(anchor) = snapshot
+        let Some(start_anchor) = snapshot
             .buffer_snapshot
             .anchor_in_excerpt(*excerpt_id, start)
         else {
@@ -244,6 +255,33 @@ impl MessageEditor {
                 .unwrap_or_default();
 
             if Img::extensions().contains(&extension) && !extension.contains("svg") {
+                if !self.prompt_capabilities.get().image {
+                    struct ImagesNotAllowed;
+
+                    let end_anchor = snapshot.buffer_snapshot.anchor_before(
+                        start_anchor.to_offset(&snapshot.buffer_snapshot) + content_len + 1,
+                    );
+
+                    self.editor.update(cx, |editor, cx| {
+                        // Remove mention
+                        editor.edit([((start_anchor..end_anchor), "")], cx);
+                    });
+
+                    self.workspace
+                        .update(cx, |workspace, cx| {
+                            workspace.show_toast(
+                                Toast::new(
+                                    NotificationId::unique::<ImagesNotAllowed>(),
+                                    "This agent does not support images yet",
+                                )
+                                .autohide(),
+                                cx,
+                            );
+                        })
+                        .ok();
+                    return Task::ready(());
+                }
+
                 let project = self.project.clone();
                 let Some(project_path) = project
                     .read(cx)
@@ -277,7 +315,7 @@ impl MessageEditor {
                 };
                 return self.confirm_mention_for_image(
                     crease_id,
-                    anchor,
+                    start_anchor,
                     Some(abs_path.clone()),
                     image,
                     window,
@@ -301,17 +339,22 @@ impl MessageEditor {
 
         match mention_uri {
             MentionUri::Fetch { url } => {
-                self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx)
+                self.confirm_mention_for_fetch(crease_id, start_anchor, url, window, cx)
             }
             MentionUri::Directory { abs_path } => {
-                self.confirm_mention_for_directory(crease_id, anchor, abs_path, window, cx)
+                self.confirm_mention_for_directory(crease_id, start_anchor, abs_path, window, cx)
             }
             MentionUri::Thread { id, name } => {
-                self.confirm_mention_for_thread(crease_id, anchor, id, name, window, cx)
-            }
-            MentionUri::TextThread { path, name } => {
-                self.confirm_mention_for_text_thread(crease_id, anchor, path, name, window, cx)
+                self.confirm_mention_for_thread(crease_id, start_anchor, id, name, window, cx)
             }
+            MentionUri::TextThread { path, name } => self.confirm_mention_for_text_thread(
+                crease_id,
+                start_anchor,
+                path,
+                name,
+                window,
+                cx,
+            ),
             MentionUri::File { .. }
             | MentionUri::Symbol { .. }
             | MentionUri::Rule { .. }
@@ -778,6 +821,10 @@ impl MessageEditor {
     }
 
     fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) {
+        if !self.prompt_capabilities.get().image {
+            return;
+        }
+
         let images = cx
             .read_from_clipboard()
             .map(|item| {
@@ -2009,6 +2056,34 @@ mod tests {
             (message_editor, editor)
         });
 
+        cx.simulate_input("Lorem @");
+
+        editor.update_in(&mut cx, |editor, window, cx| {
+            assert_eq!(editor.text(cx), "Lorem @");
+            assert!(editor.has_visible_completions_menu());
+
+            // Only files since we have default capabilities
+            assert_eq!(
+                current_completion_labels(editor),
+                &[
+                    "eight.txt dir/b/",
+                    "seven.txt dir/b/",
+                    "six.txt dir/b/",
+                    "five.txt dir/b/",
+                ]
+            );
+            editor.set_text("", window, cx);
+        });
+
+        message_editor.update(&mut cx, |editor, _cx| {
+            // Enable all prompt capabilities
+            editor.set_prompt_capabilities(acp::PromptCapabilities {
+                image: true,
+                audio: true,
+                embedded_context: true,
+            });
+        });
+
         cx.simulate_input("Lorem ");
 
         editor.update(&mut cx, |editor, cx| {

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -492,6 +492,11 @@ impl AcpThreadView {
                             })
                         });
 
+                        this.message_editor.update(cx, |message_editor, _cx| {
+                            message_editor
+                                .set_prompt_capabilities(connection.prompt_capabilities());
+                        });
+
                         cx.notify();
                     }
                     Err(err) => {
@@ -4762,6 +4767,14 @@ pub(crate) mod tests {
             &[]
         }
 
+        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+            acp::PromptCapabilities {
+                image: true,
+                audio: true,
+                embedded_context: true,
+            }
+        }
+
         fn authenticate(
             &self,
             _method_id: acp::AuthMethodId,