agent: Fix session close capability check (#51479)

Bennet Bo Fenner created

Release Notes:

- agent: Fixed an issue where external agents would return an error
because unsupported ACP method was called

Change summary

crates/agent_servers/src/acp.rs        |   2 
crates/agent_ui/src/connection_view.rs | 235 +++++++++++++++++++++++++++
2 files changed, 232 insertions(+), 5 deletions(-)

Detailed changes

crates/agent_servers/src/acp.rs 🔗

@@ -753,7 +753,7 @@ impl AgentConnection for AcpConnection {
         session_id: &acp::SessionId,
         cx: &mut App,
     ) -> Task<Result<()>> {
-        if !self.agent_capabilities.session_capabilities.close.is_none() {
+        if !self.supports_close_session() {
             return Task::ready(Err(anyhow!(LoadError::Other(
                 "Closing sessions is not supported by this agent.".into()
             ))));

crates/agent_ui/src/connection_view.rs 🔗

@@ -462,10 +462,13 @@ impl ConnectedServerState {
     }
 
     pub fn close_all_sessions(&self, cx: &mut App) -> Task<()> {
-        let tasks = self
-            .threads
-            .keys()
-            .map(|id| self.connection.clone().close_session(id, cx));
+        let tasks = self.threads.keys().filter_map(|id| {
+            if self.connection.supports_close_session() {
+                Some(self.connection.clone().close_session(id, cx))
+            } else {
+                None
+            }
+        });
         let task = futures::future::join_all(tasks);
         cx.background_spawn(async move {
             task.await;
@@ -6536,4 +6539,228 @@ pub(crate) mod tests {
             "Main editor should have existing content and queued message separated by two newlines"
         );
     }
+
+    #[gpui::test]
+    async fn test_close_all_sessions_skips_when_unsupported(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let (multi_workspace, cx) =
+            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+
+        let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
+        let connection_store =
+            cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+
+        // StubAgentConnection defaults to supports_close_session() -> false
+        let thread_view = cx.update(|window, cx| {
+            cx.new(|cx| {
+                ConnectionView::new(
+                    Rc::new(StubAgentServer::default_response()),
+                    connection_store,
+                    Agent::Custom {
+                        name: "Test".into(),
+                    },
+                    None,
+                    None,
+                    None,
+                    None,
+                    workspace.downgrade(),
+                    project,
+                    Some(thread_store),
+                    None,
+                    window,
+                    cx,
+                )
+            })
+        });
+
+        cx.run_until_parked();
+
+        thread_view.read_with(cx, |view, _cx| {
+            let connected = view.as_connected().expect("Should be connected");
+            assert!(
+                !connected.threads.is_empty(),
+                "There should be at least one thread"
+            );
+            assert!(
+                !connected.connection.supports_close_session(),
+                "StubAgentConnection should not support close"
+            );
+        });
+
+        thread_view
+            .update(cx, |view, cx| {
+                view.as_connected()
+                    .expect("Should be connected")
+                    .close_all_sessions(cx)
+            })
+            .await;
+    }
+
+    #[gpui::test]
+    async fn test_close_all_sessions_calls_close_when_supported(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let (thread_view, cx) =
+            setup_thread_view(StubAgentServer::new(CloseCapableConnection::new()), cx).await;
+
+        cx.run_until_parked();
+
+        let close_capable = thread_view.read_with(cx, |view, _cx| {
+            let connected = view.as_connected().expect("Should be connected");
+            assert!(
+                !connected.threads.is_empty(),
+                "There should be at least one thread"
+            );
+            assert!(
+                connected.connection.supports_close_session(),
+                "CloseCapableConnection should support close"
+            );
+            connected
+                .connection
+                .clone()
+                .into_any()
+                .downcast::<CloseCapableConnection>()
+                .expect("Should be CloseCapableConnection")
+        });
+
+        thread_view
+            .update(cx, |view, cx| {
+                view.as_connected()
+                    .expect("Should be connected")
+                    .close_all_sessions(cx)
+            })
+            .await;
+
+        let closed_count = close_capable.closed_sessions.lock().len();
+        assert!(
+            closed_count > 0,
+            "close_session should have been called for each thread"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_close_session_returns_error_when_unsupported(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let (thread_view, cx) = setup_thread_view(StubAgentServer::default_response(), cx).await;
+
+        cx.run_until_parked();
+
+        let result = thread_view
+            .update(cx, |view, cx| {
+                let connected = view.as_connected().expect("Should be connected");
+                assert!(
+                    !connected.connection.supports_close_session(),
+                    "StubAgentConnection should not support close"
+                );
+                let session_id = connected
+                    .threads
+                    .keys()
+                    .next()
+                    .expect("Should have at least one thread")
+                    .clone();
+                connected.connection.clone().close_session(&session_id, cx)
+            })
+            .await;
+
+        assert!(
+            result.is_err(),
+            "close_session should return an error when close is not supported"
+        );
+        assert!(
+            result.unwrap_err().to_string().contains("not supported"),
+            "Error message should indicate that closing is not supported"
+        );
+    }
+
+    #[derive(Clone)]
+    struct CloseCapableConnection {
+        closed_sessions: Arc<Mutex<Vec<acp::SessionId>>>,
+    }
+
+    impl CloseCapableConnection {
+        fn new() -> Self {
+            Self {
+                closed_sessions: Arc::new(Mutex::new(Vec::new())),
+            }
+        }
+    }
+
+    impl AgentConnection for CloseCapableConnection {
+        fn telemetry_id(&self) -> SharedString {
+            "close-capable".into()
+        }
+
+        fn new_session(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            cwd: &Path,
+            cx: &mut gpui::App,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            let thread = cx.new(|cx| {
+                AcpThread::new(
+                    None,
+                    "CloseCapableConnection",
+                    Some(cwd.to_path_buf()),
+                    self,
+                    project,
+                    action_log,
+                    SessionId::new("close-capable-session"),
+                    watch::Receiver::constant(
+                        acp::PromptCapabilities::new()
+                            .image(true)
+                            .audio(true)
+                            .embedded_context(true),
+                    ),
+                    cx,
+                )
+            });
+            Task::ready(Ok(thread))
+        }
+
+        fn supports_close_session(&self) -> bool {
+            true
+        }
+
+        fn close_session(
+            self: Rc<Self>,
+            session_id: &acp::SessionId,
+            _cx: &mut App,
+        ) -> Task<Result<()>> {
+            self.closed_sessions.lock().push(session_id.clone());
+            Task::ready(Ok(()))
+        }
+
+        fn auth_methods(&self) -> &[acp::AuthMethod] {
+            &[]
+        }
+
+        fn authenticate(
+            &self,
+            _method_id: acp::AuthMethodId,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<()>> {
+            Task::ready(Ok(()))
+        }
+
+        fn prompt(
+            &self,
+            _id: Option<acp_thread::UserMessageId>,
+            _params: acp::PromptRequest,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<acp::PromptResponse>> {
+            Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
+        }
+
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
+    }
 }