Backport "Fix ACP agents not loading when not authenticated"

Bennet Bo Fenner and Cameron created

Manual backport of https://github.com/zed-industries/zed/pull/48959 to
stable.

Co-Authored-By: Cameron <cameron@zed.dev>

Change summary

crates/agent_ui/src/acp/thread_view.rs               | 197 +++++++++----
crates/agent_ui/src/acp/thread_view/active_thread.rs |   1 
2 files changed, 129 insertions(+), 69 deletions(-)

Detailed changes

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -352,14 +352,14 @@ pub struct AcpServerView {
 impl AcpServerView {
     pub fn as_active_thread(&self) -> Option<&AcpThreadView> {
         match &self.server_state {
-            ServerState::Connected(connected) => Some(&connected.current),
+            ServerState::Connected(connected) => connected.current.as_ref(),
             _ => None,
         }
     }
 
     pub fn as_active_thread_mut(&mut self) -> Option<&mut AcpThreadView> {
         match &mut self.server_state {
-            ServerState::Connected(connected) => Some(&mut connected.current),
+            ServerState::Connected(connected) => connected.current.as_mut(),
             _ => None,
         }
     }
@@ -389,7 +389,7 @@ enum ServerState {
 // hashmap of threads, current becomes session_id
 pub struct ConnectedServerState {
     auth_state: AuthState,
-    current: AcpThreadView,
+    current: Option<AcpThreadView>,
     connection: Rc<dyn AgentConnection>,
 }
 
@@ -417,7 +417,9 @@ struct LoadingView {
 
 impl ConnectedServerState {
     pub fn has_thread_error(&self) -> bool {
-        self.current.thread_error.is_some()
+        self.current
+            .as_ref()
+            .map_or(false, |current| current.thread_error.is_some())
     }
 }
 
@@ -746,7 +748,14 @@ impl AcpServerView {
                 Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
                     Ok(err) => {
                         cx.update(|window, cx| {
-                            Self::handle_auth_required(this, err, agent.name(), window, cx)
+                            Self::handle_auth_required(
+                                this,
+                                err,
+                                agent.name(),
+                                connection.clone(),
+                                window,
+                                cx,
+                            )
                         })
                         .log_err();
                         return;
@@ -903,7 +912,7 @@ impl AcpServerView {
                         this.server_state = ServerState::Connected(ConnectedServerState {
                             connection,
                             auth_state: AuthState::Ok,
-                            current: AcpThreadView::new(
+                            current: Some(AcpThreadView::new(
                                 thread,
                                 workspace.clone(),
                                 entry_view_state,
@@ -921,7 +930,7 @@ impl AcpServerView {
                                 resume_thread.clone(),
                                 subscriptions,
                                 cx,
-                            ),
+                            )),
                         });
 
                         if this.focus_handle.contains_focused(window, cx) {
@@ -978,6 +987,7 @@ impl AcpServerView {
         this: WeakEntity<Self>,
         err: AuthRequired,
         agent_name: SharedString,
+        connection: Rc<dyn AgentConnection>,
         window: &mut Window,
         cx: &mut App,
     ) {
@@ -1017,17 +1027,25 @@ impl AcpServerView {
         };
 
         this.update(cx, |this, cx| {
+            let description = err
+                .description
+                .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
+
+            let auth_state = AuthState::Unauthenticated {
+                pending_auth_method: None,
+                configuration_view,
+                description,
+                _subscription: subscription,
+            };
+
             if let Some(connected) = this.as_connected_mut() {
-                let description = err
-                    .description
-                    .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
-
-                connected.auth_state = AuthState::Unauthenticated {
-                    pending_auth_method: None,
-                    configuration_view,
-                    description,
-                    _subscription: subscription,
-                };
+                connected.auth_state = auth_state;
+            } else {
+                this.server_state = ServerState::Connected(ConnectedServerState {
+                    auth_state,
+                    current: None,
+                    connection,
+                })
             }
             if this.message_editor.focus_handle(cx).is_focused(window) {
                 this.focus_handle.focus(window, cx)
@@ -1916,6 +1934,7 @@ impl AcpServerView {
                             provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     );
@@ -1941,6 +1960,7 @@ impl AcpServerView {
                             provider_id: None,
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     )
@@ -2306,7 +2326,7 @@ impl AcpServerView {
                                     .bg(cx.theme().colors().editor_background)
                                     .overflow_hidden();
 
-                                let is_loading_contents = matches!(&self.server_state, ServerState::Connected(ConnectedServerState { current: AcpThreadView { is_loading_contents: true, .. }, ..}));
+                                let is_loading_contents = matches!(&self.server_state, ServerState::Connected(ConnectedServerState { current: Some(AcpThreadView { is_loading_contents: true, .. }), ..}));
                                 if message.id.is_some() {
                                     this.child(
                                         base_container
@@ -2677,7 +2697,7 @@ impl AcpServerView {
 
         let key = (entry_ix, chunk_ix);
 
-        let is_open = matches!(&self.server_state, ServerState::Connected(ConnectedServerState {current: AcpThreadView { expanded_thinking_blocks, .. }, ..}) if expanded_thinking_blocks.contains(&key));
+        let is_open = matches!(&self.server_state, ServerState::Connected(ConnectedServerState {current: Some(AcpThreadView { expanded_thinking_blocks, .. }), ..}) if expanded_thinking_blocks.contains(&key));
 
         let scroll_handle = self
             .as_active_thread()
@@ -2823,9 +2843,10 @@ impl AcpServerView {
         let has_image_content = tool_call.content.iter().any(|c| c.image().is_some());
         let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
         let mut is_open = match &self.server_state {
-            ServerState::Connected(ConnectedServerState { current, .. }) => {
-                current.expanded_tool_calls.contains(&tool_call.id)
-            }
+            ServerState::Connected(ConnectedServerState {
+                current: Some(current),
+                ..
+            }) => current.expanded_tool_calls.contains(&tool_call.id),
             _ => false,
         };
 
@@ -2867,7 +2888,7 @@ impl AcpServerView {
                     )
                     .when(should_show_raw_input, |this| {
                         let is_raw_input_expanded =
-                            matches!(&self.server_state, ServerState::Connected(ConnectedServerState {current: AcpThreadView { expanded_tool_call_raw_inputs, .. }, ..}) if expanded_tool_call_raw_inputs.contains(&tool_call.id));
+                            matches!(&self.server_state, ServerState::Connected(ConnectedServerState {current: Some(AcpThreadView { expanded_tool_call_raw_inputs, .. }), ..}) if expanded_tool_call_raw_inputs.contains(&tool_call.id));
 
                         let input_header = if is_raw_input_expanded {
                             "Raw Input:"
@@ -3106,7 +3127,7 @@ impl AcpServerView {
                                         })
                                         .when_some(diff_for_discard, |this, diff| {
                                             let tool_call_id = tool_call.id.clone();
-                                            let is_discarded = matches!(&self.server_state, ServerState::Connected(ConnectedServerState{current: AcpThreadView { discarded_partial_edits, .. }, ..}) if discarded_partial_edits.contains(&tool_call_id));
+                                            let is_discarded = matches!(&self.server_state, ServerState::Connected(ConnectedServerState{current: Some(AcpThreadView { discarded_partial_edits, .. }), ..}) if discarded_partial_edits.contains(&tool_call_id));
                                             this.when(!is_discarded, |this| {
                                                 this.child(
                                                     IconButton::new(
@@ -4693,9 +4714,11 @@ impl AcpServerView {
         let command_element =
             self.render_collapsible_command(false, command_content, &tool_call.id, cx);
 
-        let is_expanded = self
-            .as_connected()
-            .is_some_and(|c| c.current.expanded_tool_calls.contains(&tool_call.id));
+        let is_expanded = self.as_connected().is_some_and(|c| {
+            c.current.as_ref().map_or(false, |view| {
+                view.expanded_tool_calls.contains(&tool_call.id)
+            })
+        });
 
         let header = h_flex()
             .id(header_id)
@@ -6007,7 +6030,8 @@ impl AcpServerView {
 
         let queued_message_editors = self
             .as_connected()
-            .map(|c| c.current.queued_message_editors.as_slice())
+            .and_then(|c| c.current.as_ref())
+            .map(|current| current.queued_message_editors.as_slice())
             .unwrap_or(&[]);
 
         let queue_len = queued_message_editors.len();
@@ -8172,13 +8196,17 @@ impl AcpServerView {
     }
 
     fn render_thread_error(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
-        let content = match self.as_active_thread()?.thread_error.as_ref()? {
+        let view = self.as_active_thread()?;
+
+        let connection = view.thread.read(cx).connection().clone();
+
+        let content = match view.thread_error.as_ref()? {
             ThreadError::Other { message, .. } => {
                 self.render_any_thread_error(message.clone(), window, cx)
             }
             ThreadError::Refusal => self.render_refusal_error(cx),
             ThreadError::AuthenticationRequired(error) => {
-                self.render_authentication_required_error(error.clone(), cx)
+                self.render_authentication_required_error(error.clone(), connection, cx)
             }
             ThreadError::PaymentRequired => self.render_payment_required_error(cx),
         };
@@ -8330,6 +8358,7 @@ impl AcpServerView {
     fn render_authentication_required_error(
         &self,
         error: SharedString,
+        connection: Rc<dyn AgentConnection>,
         cx: &mut Context<Self>,
     ) -> Callout {
         Callout::new()
@@ -8340,7 +8369,7 @@ impl AcpServerView {
             .actions_slot(
                 h_flex()
                     .gap_0p5()
-                    .child(self.authenticate_button(cx))
+                    .child(self.authenticate_button(connection, cx))
                     .child(self.create_copy_button(error)),
             )
             .dismiss_action(self.dismiss_error_button(cx))
@@ -8364,11 +8393,16 @@ impl AcpServerView {
             }))
     }
 
-    fn authenticate_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
+    fn authenticate_button(
+        &self,
+        connection: Rc<dyn AgentConnection>,
+        cx: &mut Context<Self>,
+    ) -> impl IntoElement {
         Button::new("authenticate", "Authenticate")
             .label_size(LabelSize::Small)
             .style(ButtonStyle::Filled)
             .on_click(cx.listener({
+                let connection = connection.clone();
                 move |this, _, window, cx| {
                     let agent_name = this.agent.name();
                     this.clear_thread_error(cx);
@@ -8378,25 +8412,40 @@ impl AcpServerView {
                         });
                     }
                     let this = cx.weak_entity();
-                    window.defer(cx, |window, cx| {
-                        Self::handle_auth_required(
-                            this,
-                            AuthRequired::new(),
-                            agent_name,
-                            window,
-                            cx,
-                        );
+                    window.defer(cx, {
+                        let connection = connection.clone();
+                        move |window, cx| {
+                            Self::handle_auth_required(
+                                this,
+                                AuthRequired::new(),
+                                agent_name,
+                                connection,
+                                window,
+                                cx,
+                            );
+                        }
                     })
                 }
             }))
     }
 
     pub(crate) fn reauthenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        let Some(view) = self.as_active_thread() else {
+            return;
+        };
+        let connection = view.thread.read(cx).connection().clone();
         let agent_name = self.agent.name();
         self.clear_thread_error(cx);
         let this = cx.weak_entity();
         window.defer(cx, |window, cx| {
-            Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx);
+            Self::handle_auth_required(
+                this,
+                AuthRequired::new(),
+                agent_name,
+                connection,
+                window,
+                cx,
+            );
         })
     }
 
@@ -8743,36 +8792,46 @@ impl Render for AcpServerView {
                         cx,
                     ))
                     .into_any_element(),
-                ServerState::Connected(connected) => v_flex().flex_1().map(|this| {
-                    let this = this.when(connected.current.resumed_without_history, |this| {
-                        this.child(self.render_resume_notice(cx))
-                    });
-                    if has_messages {
-                        this.child(
-                            list(
-                                connected.current.list_state.clone(),
-                                cx.processor(|this, index: usize, window, cx| {
-                                    let Some((entry, len)) =
-                                        this.as_active_thread().and_then(|active| {
-                                            let entries = &active.thread.read(cx).entries();
-                                            Some((entries.get(index)?, entries.len()))
-                                        })
-                                    else {
-                                        return Empty.into_any();
-                                    };
-                                    this.render_entry(index, len, entry, window, cx)
-                                }),
-                            )
-                            .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
-                            .flex_grow()
-                            .into_any(),
-                        )
-                        .vertical_scrollbar_for(&connected.current.list_state, window, cx)
-                        .into_any()
+                ServerState::Connected(connected) => {
+                    if let Some(current) = connected.current.as_ref() {
+                        v_flex()
+                            .flex_1()
+                            .map(|this| {
+                                let this = this.when(current.resumed_without_history, |this| {
+                                    this.child(self.render_resume_notice(cx))
+                                });
+                                if has_messages {
+                                    this.child(
+                                        list(
+                                            current.list_state.clone(),
+                                            cx.processor(|this, index: usize, window, cx| {
+                                                let Some((entry, len)) =
+                                                    this.as_active_thread().and_then(|active| {
+                                                        let entries =
+                                                            &active.thread.read(cx).entries();
+                                                        Some((entries.get(index)?, entries.len()))
+                                                    })
+                                                else {
+                                                    return Empty.into_any();
+                                                };
+                                                this.render_entry(index, len, entry, window, cx)
+                                            }),
+                                        )
+                                        .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
+                                        .flex_grow()
+                                        .into_any(),
+                                    )
+                                    .vertical_scrollbar_for(&current.list_state, window, cx)
+                                    .into_any()
+                                } else {
+                                    this.child(self.render_recent_history(cx)).into_any()
+                                }
+                            })
+                            .into_any_element()
                     } else {
-                        this.child(self.render_recent_history(cx)).into_any()
+                        div().into_any_element()
                     }
-                }),
+                }
             })
             // The activity bar is intentionally rendered outside of the ThreadState::Active match
             // above so that the scrollbar doesn't render behind it. The current setup allows