Fix ACP agents not loading when not authenticated (#48959)

Bennet Bo Fenner and cameron created

Closes #48857

- [x] Code Reviewed
- [x] Manual QA

Release Notes:

- Fixed an issue where some ACP agents would not be loading correctly
when unauthenticated

---------

Co-authored-by: cameron <cameron.studdstreet@gmail.com>

Change summary

crates/agent_ui/src/acp/thread_view.rs               | 161 ++++++++-----
crates/agent_ui/src/acp/thread_view/active_thread.rs |   4 
crates/agent_ui/src/agent_panel.rs                   |   4 
3 files changed, 106 insertions(+), 63 deletions(-)

Detailed changes

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -55,8 +55,8 @@ use ui::{
     PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*,
     right_click_menu,
 };
-use util::defer;
 use util::{ResultExt, size::format_file_size, time::duration_alt_display};
+use util::{debug_panic, defer};
 use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId};
 use zed_actions::agent::{Chat, ToggleModelSelector};
 use zed_actions::assistant::OpenRulesLibrary;
@@ -178,9 +178,9 @@ pub struct AcpServerView {
 }
 
 impl AcpServerView {
-    pub fn active_thread(&self) -> Option<Entity<AcpThreadView>> {
+    pub fn active_thread(&self) -> Option<&Entity<AcpThreadView>> {
         match &self.server_state {
-            ServerState::Connected(connected) => Some(connected.current.clone()),
+            ServerState::Connected(connected) => connected.active_view(),
             _ => None,
         }
     }
@@ -188,15 +188,15 @@ impl AcpServerView {
     pub fn parent_thread(&self, cx: &App) -> Option<Entity<AcpThreadView>> {
         match &self.server_state {
             ServerState::Connected(connected) => {
-                let mut current = connected.current.clone();
+                let mut current = connected.active_view()?;
                 while let Some(parent_id) = current.read(cx).parent_id.clone() {
                     if let Some(parent) = connected.threads.get(&parent_id) {
-                        current = parent.clone();
+                        current = parent;
                     } else {
                         break;
                     }
                 }
-                Some(current)
+                Some(current.clone())
             }
             _ => None,
         }
@@ -249,7 +249,7 @@ enum ServerState {
 // hashmap of threads, current becomes session_id
 pub struct ConnectedServerState {
     auth_state: AuthState,
-    current: Entity<AcpThreadView>,
+    active_id: Option<acp::SessionId>,
     threads: HashMap<acp::SessionId, Entity<AcpThreadView>>,
     connection: Rc<dyn AgentConnection>,
 }
@@ -277,13 +277,18 @@ struct LoadingView {
 }
 
 impl ConnectedServerState {
+    pub fn active_view(&self) -> Option<&Entity<AcpThreadView>> {
+        self.active_id.as_ref().and_then(|id| self.threads.get(id))
+    }
+
     pub fn has_thread_error(&self, cx: &App) -> bool {
-        self.current.read(cx).thread_error.is_some()
+        self.active_view()
+            .map_or(false, |view| view.read(cx).thread_error.is_some())
     }
 
     pub fn navigate_to_session(&mut self, session_id: acp::SessionId) {
-        if let Some(session) = self.threads.get(&session_id) {
-            self.current = session.clone();
+        if self.threads.contains_key(&session_id) {
+            self.active_id = Some(session_id);
         }
     }
 
@@ -386,8 +391,8 @@ impl AcpServerView {
         );
         self.set_server_state(state, cx);
 
-        if let Some(connected) = self.as_connected() {
-            connected.current.update(cx, |this, cx| {
+        if let Some(view) = self.active_thread() {
+            view.update(cx, |this, cx| {
                 this.message_editor.update(cx, |editor, cx| {
                     editor.set_command_state(
                         this.prompt_capabilities.clone(),
@@ -520,7 +525,14 @@ impl AcpServerView {
                 Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
                     Ok(err) => {
                         cx.update(|window, cx| {
-                            Self::handle_auth_required(this, err, agent.name(), window, cx)
+                            Self::handle_auth_required(
+                                this,
+                                err,
+                                agent.name(),
+                                connection,
+                                window,
+                                cx,
+                            )
                         })
                         .log_err();
                         return;
@@ -551,15 +563,13 @@ impl AcpServerView {
                                 .focus(window, cx);
                         }
 
+                        let id = current.read(cx).thread.read(cx).session_id().clone();
                         this.set_server_state(
                             ServerState::Connected(ConnectedServerState {
                                 connection,
                                 auth_state: AuthState::Ok,
-                                current: current.clone(),
-                                threads: HashMap::from_iter([(
-                                    current.read(cx).thread.read(cx).session_id().clone(),
-                                    current,
-                                )]),
+                                active_id: Some(id.clone()),
+                                threads: HashMap::from_iter([(id, current)]),
                             }),
                             cx,
                         );
@@ -816,6 +826,7 @@ impl AcpServerView {
         this: WeakEntity<Self>,
         err: AuthRequired,
         agent_name: SharedString,
+        connection: Rc<dyn AgentConnection>,
         window: &mut Window,
         cx: &mut App,
     ) {
@@ -855,26 +866,36 @@ impl AcpServerView {
         };
 
         this.update(cx, |this, cx| {
+            let description = err
+                .description
+                .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
+            let auth_state = AuthState::Unauthenticated {
+                pending_auth_method: None,
+                configuration_view,
+                description,
+                _subscription: subscription,
+            };
             if let Some(connected) = this.as_connected_mut() {
-                let description = err
-                    .description
-                    .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
-
-                connected.auth_state = AuthState::Unauthenticated {
-                    pending_auth_method: None,
-                    configuration_view,
-                    description,
-                    _subscription: subscription,
-                };
-                if connected
-                    .current
-                    .read(cx)
-                    .message_editor
-                    .focus_handle(cx)
-                    .is_focused(window)
+                connected.auth_state = auth_state;
+                if let Some(view) = connected.active_view()
+                    && view
+                        .read(cx)
+                        .message_editor
+                        .focus_handle(cx)
+                        .is_focused(window)
                 {
                     this.focus_handle.focus(window, cx)
                 }
+            } else {
+                this.set_server_state(
+                    ServerState::Connected(ConnectedServerState {
+                        auth_state,
+                        active_id: None,
+                        threads: HashMap::default(),
+                        connection,
+                    }),
+                    cx,
+                );
             }
             cx.notify();
         })
@@ -887,19 +908,15 @@ impl AcpServerView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        match &self.server_state {
-            ServerState::Connected(connected) => {
-                if connected
-                    .current
-                    .read(cx)
-                    .message_editor
-                    .focus_handle(cx)
-                    .is_focused(window)
-                {
-                    self.focus_handle.focus(window, cx)
-                }
+        if let Some(view) = self.active_thread() {
+            if view
+                .read(cx)
+                .message_editor
+                .focus_handle(cx)
+                .is_focused(window)
+            {
+                self.focus_handle.focus(window, cx)
             }
-            _ => {}
         }
         let load_error = if let Some(load_err) = err.downcast_ref::<LoadError>() {
             load_err.clone()
@@ -1148,19 +1165,15 @@ impl AcpServerView {
                 }
             }
             AcpThreadEvent::LoadError(error) => {
-                match &self.server_state {
-                    ServerState::Connected(connected) => {
-                        if connected
-                            .current
-                            .read(cx)
-                            .message_editor
-                            .focus_handle(cx)
-                            .is_focused(window)
-                        {
-                            self.focus_handle.focus(window, cx)
-                        }
+                if let Some(view) = self.active_thread() {
+                    if view
+                        .read(cx)
+                        .message_editor
+                        .focus_handle(cx)
+                        .is_focused(window)
+                    {
+                        self.focus_handle.focus(window, cx)
                     }
-                    _ => {}
                 }
                 self.set_server_state(ServerState::LoadError(error.clone()), cx);
             }
@@ -1397,6 +1410,7 @@ impl AcpServerView {
                             provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     );
@@ -1422,6 +1436,7 @@ impl AcpServerView {
                             provider_id: None,
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     )
@@ -2397,8 +2412,19 @@ impl AcpServerView {
             active.update(cx, |active, cx| active.clear_thread_error(cx));
         }
         let this = cx.weak_entity();
+        let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else {
+            debug_panic!("This should not be possible");
+            return;
+        };
         window.defer(cx, |window, cx| {
-            Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx);
+            Self::handle_auth_required(
+                this,
+                AuthRequired::new(),
+                agent_name,
+                connection,
+                window,
+                cx,
+            );
         })
     }
 
@@ -2508,7 +2534,14 @@ impl Render for AcpServerView {
                         cx,
                     ))
                     .into_any_element(),
-                ServerState::Connected(connected) => connected.current.clone().into_any_element(),
+                ServerState::Connected(connected) => {
+                    if let Some(view) = connected.active_view() {
+                        view.clone().into_any_element()
+                    } else {
+                        debug_panic!("This state should never be reached");
+                        div().into_any_element()
+                    }
+                }
             })
     }
 }
@@ -3589,7 +3622,13 @@ pub(crate) mod tests {
         thread_view: &Entity<AcpServerView>,
         cx: &TestAppContext,
     ) -> Entity<AcpThreadView> {
-        cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone())
+        cx.read(|cx| {
+            thread_view
+                .read(cx)
+                .active_thread()
+                .expect("No active thread")
+                .clone()
+        })
     }
 
     fn message_editor(

crates/agent_ui/src/acp/thread_view/active_thread.rs 🔗

@@ -630,6 +630,7 @@ impl AcpThreadView {
             if can_login && !logout_supported {
                 message_editor.update(cx, |editor, cx| editor.clear(window, cx));
 
+                let connection = self.thread.read(cx).connection().clone();
                 window.defer(cx, {
                     let agent_name = self.agent_name.clone();
                     let server_view = self.server_view.clone();
@@ -638,6 +639,7 @@ impl AcpThreadView {
                             server_view.clone(),
                             AuthRequired::new(),
                             agent_name,
+                            connection,
                             window,
                             cx,
                         );
@@ -6716,11 +6718,13 @@ impl AcpThreadView {
                             editor.set_message(message, window, cx);
                         });
                     }
+                    let connection = this.thread.read(cx).connection().clone();
                     window.defer(cx, |window, cx| {
                         AcpServerView::handle_auth_required(
                             server_view,
                             AuthRequired::new(),
                             agent_name,
+                            connection,
                             window,
                             cx,
                         );

crates/agent_ui/src/agent_panel.rs 🔗

@@ -922,7 +922,7 @@ impl AgentPanel {
             return;
         };
 
-        let Some(active_thread) = thread_view.read(cx).active_thread() else {
+        let Some(active_thread) = thread_view.read(cx).active_thread().cloned() else {
             return;
         };
 
@@ -1195,7 +1195,7 @@ impl AgentPanel {
     ) {
         if let Some(workspace) = self.workspace.upgrade()
             && let Some(thread_view) = self.active_thread_view()
-            && let Some(active_thread) = thread_view.read(cx).active_thread()
+            && let Some(active_thread) = thread_view.read(cx).active_thread().cloned()
         {
             active_thread.update(cx, |thread, cx| {
                 thread