acp: Model-specific prompt capabilities for 1PA (#36879)

Cole Miller created

Adds support for per-session prompt capabilities and capability changes
on the Zed side (ACP itself still only has per-connection static
capabilities for now), and uses it to reflect image support accurately
in 1PA threads based on the currently-selected model.

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs       | 38 +++++++++++++++++++-----
crates/acp_thread/src/connection.rs       | 18 ++++-------
crates/agent2/src/agent.rs                | 13 ++-----
crates/agent2/src/thread.rs               | 21 +++++++++++++
crates/agent_servers/src/acp.rs           |  9 ++---
crates/agent_servers/src/claude.rs        | 16 ++++-----
crates/agent_ui/src/acp/message_editor.rs |  2 
crates/agent_ui/src/acp/thread_view.rs    | 20 +++++++-----
crates/agent_ui/src/agent_diff.rs         |  1 
crates/watch/src/watch.rs                 | 13 ++++++++
10 files changed, 98 insertions(+), 53 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -756,6 +756,8 @@ pub struct AcpThread {
     connection: Rc<dyn AgentConnection>,
     session_id: acp::SessionId,
     token_usage: Option<TokenUsage>,
+    prompt_capabilities: acp::PromptCapabilities,
+    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 }
 
 #[derive(Debug)]
@@ -770,6 +772,7 @@ pub enum AcpThreadEvent {
     Stopped,
     Error,
     LoadError(LoadError),
+    PromptCapabilitiesUpdated,
 }
 
 impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -821,7 +824,20 @@ impl AcpThread {
         project: Entity<Project>,
         action_log: Entity<ActionLog>,
         session_id: acp::SessionId,
+        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
+        cx: &mut Context<Self>,
     ) -> Self {
+        let prompt_capabilities = *prompt_capabilities_rx.borrow();
+        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
+            loop {
+                let caps = prompt_capabilities_rx.recv().await?;
+                this.update(cx, |this, cx| {
+                    this.prompt_capabilities = caps;
+                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
+                })?;
+            }
+        });
+
         Self {
             action_log,
             shared_buffers: Default::default(),
@@ -833,9 +849,15 @@ impl AcpThread {
             connection,
             session_id,
             token_usage: None,
+            prompt_capabilities,
+            _observe_prompt_capabilities: task,
         }
     }
 
+    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
+        self.prompt_capabilities
+    }
+
     pub fn connection(&self) -> &Rc<dyn AgentConnection> {
         &self.connection
     }
@@ -2599,13 +2621,19 @@ mod tests {
                     .into(),
             );
             let action_log = cx.new(|_| ActionLog::new(project.clone()));
-            let thread = cx.new(|_cx| {
+            let thread = cx.new(|cx| {
                 AcpThread::new(
                     "Test",
                     self.clone(),
                     project,
                     action_log,
                     session_id.clone(),
+                    watch::Receiver::constant(acp::PromptCapabilities {
+                        image: true,
+                        audio: true,
+                        embedded_context: true,
+                    }),
+                    cx,
                 )
             });
             self.sessions.lock().insert(session_id, thread.downgrade());
@@ -2639,14 +2667,6 @@ mod tests {
             }
         }
 
-        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-            acp::PromptCapabilities {
-                image: true,
-                audio: true,
-                embedded_context: true,
-            }
-        }
-
         fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
             let sessions = self.sessions.lock();
             let thread = sessions.get(session_id).unwrap().clone();

crates/acp_thread/src/connection.rs 🔗

@@ -38,8 +38,6 @@ pub trait AgentConnection {
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>>;
 
-    fn prompt_capabilities(&self) -> acp::PromptCapabilities;
-
     fn resume(
         &self,
         _session_id: &acp::SessionId,
@@ -329,13 +327,19 @@ mod test_support {
         ) -> Task<gpui::Result<Entity<AcpThread>>> {
             let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
             let action_log = cx.new(|_| ActionLog::new(project.clone()));
-            let thread = cx.new(|_cx| {
+            let thread = cx.new(|cx| {
                 AcpThread::new(
                     "Test",
                     self.clone(),
                     project,
                     action_log,
                     session_id.clone(),
+                    watch::Receiver::constant(acp::PromptCapabilities {
+                        image: true,
+                        audio: true,
+                        embedded_context: true,
+                    }),
+                    cx,
                 )
             });
             self.sessions.lock().insert(
@@ -348,14 +352,6 @@ mod test_support {
             Task::ready(Ok(thread))
         }
 
-        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-            acp::PromptCapabilities {
-                image: true,
-                audio: true,
-                embedded_context: true,
-            }
-        }
-
         fn authenticate(
             &self,
             _method_id: acp::AuthMethodId,

crates/agent2/src/agent.rs 🔗

@@ -240,13 +240,16 @@ impl NativeAgent {
         let title = thread.title();
         let project = thread.project.clone();
         let action_log = thread.action_log.clone();
-        let acp_thread = cx.new(|_cx| {
+        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
+        let acp_thread = cx.new(|cx| {
             acp_thread::AcpThread::new(
                 title,
                 connection,
                 project.clone(),
                 action_log.clone(),
                 session_id.clone(),
+                prompt_capabilities_rx,
+                cx,
             )
         });
         let subscriptions = vec![
@@ -925,14 +928,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         })
     }
 
-    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-        acp::PromptCapabilities {
-            image: true,
-            audio: false,
-            embedded_context: true,
-        }
-    }
-
     fn resume(
         &self,
         session_id: &acp::SessionId,

crates/agent2/src/thread.rs 🔗

@@ -575,11 +575,22 @@ pub struct Thread {
     templates: Arc<Templates>,
     model: Option<Arc<dyn LanguageModel>>,
     summarization_model: Option<Arc<dyn LanguageModel>>,
+    prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
+    pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
     pub(crate) project: Entity<Project>,
     pub(crate) action_log: Entity<ActionLog>,
 }
 
 impl Thread {
+    fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
+        let image = model.map_or(true, |model| model.supports_images());
+        acp::PromptCapabilities {
+            image,
+            audio: false,
+            embedded_context: true,
+        }
+    }
+
     pub fn new(
         project: Entity<Project>,
         project_context: Entity<ProjectContext>,
@@ -590,6 +601,8 @@ impl Thread {
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
         let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
+        let (prompt_capabilities_tx, prompt_capabilities_rx) =
+            watch::channel(Self::prompt_capabilities(model.as_deref()));
         Self {
             id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
             prompt_id: PromptId::new(),
@@ -617,6 +630,8 @@ impl Thread {
             templates,
             model,
             summarization_model: None,
+            prompt_capabilities_tx,
+            prompt_capabilities_rx,
             project,
             action_log,
         }
@@ -750,6 +765,8 @@ impl Thread {
                 .or_else(|| registry.default_model())
                 .map(|model| model.model)
         });
+        let (prompt_capabilities_tx, prompt_capabilities_rx) =
+            watch::channel(Self::prompt_capabilities(model.as_deref()));
 
         Self {
             id,
@@ -779,6 +796,8 @@ impl Thread {
             project,
             action_log,
             updated_at: db_thread.updated_at,
+            prompt_capabilities_tx,
+            prompt_capabilities_rx,
         }
     }
 
@@ -946,10 +965,12 @@ impl Thread {
     pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
         let old_usage = self.latest_token_usage();
         self.model = Some(model);
+        let new_caps = Self::prompt_capabilities(self.model.as_deref());
         let new_usage = self.latest_token_usage();
         if old_usage != new_usage {
             cx.emit(TokenUsageUpdated(new_usage));
         }
+        self.prompt_capabilities_tx.send(new_caps).log_err();
         cx.notify()
     }
 

crates/agent_servers/src/acp.rs 🔗

@@ -185,13 +185,16 @@ impl AgentConnection for AcpConnection {
 
             let session_id = response.session_id;
             let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
-            let thread = cx.new(|_cx| {
+            let thread = cx.new(|cx| {
                 AcpThread::new(
                     self.server_name.clone(),
                     self.clone(),
                     project,
                     action_log,
                     session_id.clone(),
+                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
+                    watch::Receiver::constant(self.prompt_capabilities),
+                    cx,
                 )
             })?;
 
@@ -279,10 +282,6 @@ impl AgentConnection for AcpConnection {
         })
     }
 
-    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-        self.prompt_capabilities
-    }
-
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
         if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
             session.suppress_abort_err = true;

crates/agent_servers/src/claude.rs 🔗

@@ -249,13 +249,19 @@ impl AgentConnection for ClaudeAgentConnection {
             });
 
             let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
-            let thread = cx.new(|_cx| {
+            let thread = cx.new(|cx| {
                 AcpThread::new(
                     "Claude Code",
                     self.clone(),
                     project,
                     action_log,
                     session_id.clone(),
+                    watch::Receiver::constant(acp::PromptCapabilities {
+                        image: true,
+                        audio: false,
+                        embedded_context: true,
+                    }),
+                    cx,
                 )
             })?;
 
@@ -319,14 +325,6 @@ impl AgentConnection for ClaudeAgentConnection {
         cx.foreground_executor().spawn(async move { end_rx.await? })
     }
 
-    fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-        acp::PromptCapabilities {
-            image: true,
-            audio: false,
-            embedded_context: true,
-        }
-    }
-
     fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
         let sessions = self.sessions.borrow();
         let Some(session) = sessions.get(session_id) else {

crates/agent_ui/src/acp/message_editor.rs 🔗

@@ -373,7 +373,7 @@ impl MessageEditor {
 
         if Img::extensions().contains(&extension) && !extension.contains("svg") {
             if !self.prompt_capabilities.get().image {
-                return Task::ready(Err(anyhow!("This agent does not support images yet")));
+                return Task::ready(Err(anyhow!("This model does not support images yet")));
             }
             let task = self
                 .project

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -474,7 +474,7 @@ impl AcpThreadView {
                         let action_log = thread.read(cx).action_log().clone();
 
                         this.prompt_capabilities
-                            .set(connection.prompt_capabilities());
+                            .set(thread.read(cx).prompt_capabilities());
 
                         let count = thread.read(cx).entries().len();
                         this.list_state.splice(0..0, count);
@@ -1163,6 +1163,10 @@ impl AcpThreadView {
                     });
                 }
             }
+            AcpThreadEvent::PromptCapabilitiesUpdated => {
+                self.prompt_capabilities
+                    .set(thread.read(cx).prompt_capabilities());
+            }
             AcpThreadEvent::TokenUsageUpdated => {}
         }
         cx.notify();
@@ -5367,6 +5371,12 @@ pub(crate) mod tests {
                     project,
                     action_log,
                     SessionId("test".into()),
+                    watch::Receiver::constant(acp::PromptCapabilities {
+                        image: true,
+                        audio: true,
+                        embedded_context: true,
+                    }),
+                    cx,
                 )
             })))
         }
@@ -5375,14 +5385,6 @@ pub(crate) mod tests {
             &[]
         }
 
-        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
-            acp::PromptCapabilities {
-                image: true,
-                audio: true,
-                embedded_context: true,
-            }
-        }
-
         fn authenticate(
             &self,
             _method_id: acp::AuthMethodId,

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1529,6 +1529,7 @@ impl AgentDiff {
             | AcpThreadEvent::TokenUsageUpdated
             | AcpThreadEvent::EntriesRemoved(_)
             | AcpThreadEvent::ToolAuthorizationRequired
+            | AcpThreadEvent::PromptCapabilitiesUpdated
             | AcpThreadEvent::Retry(_) => {}
         }
     }

crates/watch/src/watch.rs 🔗

@@ -162,6 +162,19 @@ impl<T> Receiver<T> {
             pending_waker_id: None,
         }
     }
+
+    /// Creates a new [`Receiver`] holding an initial value that will never change.
+    pub fn constant(value: T) -> Self {
+        let state = Arc::new(RwLock::new(State {
+            value,
+            wakers: BTreeMap::new(),
+            next_waker_id: WakerId::default(),
+            version: 0,
+            closed: false,
+        }));
+
+        Self { state, version: 0 }
+    }
 }
 
 impl<T: Clone> Receiver<T> {