Re-land #48959 (#48990)

Mikayla Maki created

- [x] Tests or screenshots needed?
- [x] Code Reviewed
- [x] Manual QA

Release Notes:

- N/A

Change summary

crates/agent_ui/src/acp/thread_view.rs               | 328 +++++++++++--
crates/agent_ui/src/acp/thread_view/active_thread.rs |   4 
crates/agent_ui/src/agent_panel.rs                   |   4 
3 files changed, 273 insertions(+), 63 deletions(-)

Detailed changes

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -55,8 +55,8 @@ use ui::{
     PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*,
     right_click_menu,
 };
-use util::defer;
 use util::{ResultExt, size::format_file_size, time::duration_alt_display};
+use util::{debug_panic, defer};
 use workspace::{
     CollaboratorId, MultiWorkspace, NewTerminal, Toast, Workspace, notifications::NotificationId,
 };
@@ -180,9 +180,9 @@ pub struct AcpServerView {
 }
 
 impl AcpServerView {
-    pub fn active_thread(&self) -> Option<Entity<AcpThreadView>> {
+    pub fn active_thread(&self) -> Option<&Entity<AcpThreadView>> {
         match &self.server_state {
-            ServerState::Connected(connected) => Some(connected.current.clone()),
+            ServerState::Connected(connected) => connected.active_view(),
             _ => None,
         }
     }
@@ -190,15 +190,15 @@ impl AcpServerView {
     pub fn parent_thread(&self, cx: &App) -> Option<Entity<AcpThreadView>> {
         match &self.server_state {
             ServerState::Connected(connected) => {
-                let mut current = connected.current.clone();
+                let mut current = connected.active_view()?;
                 while let Some(parent_id) = current.read(cx).parent_id.clone() {
                     if let Some(parent) = connected.threads.get(&parent_id) {
-                        current = parent.clone();
+                        current = parent;
                     } else {
                         break;
                     }
                 }
-                Some(current)
+                Some(current.clone())
             }
             _ => None,
         }
@@ -251,7 +251,7 @@ enum ServerState {
 // hashmap of threads, current becomes session_id
 pub struct ConnectedServerState {
     auth_state: AuthState,
-    current: Entity<AcpThreadView>,
+    active_id: Option<acp::SessionId>,
     threads: HashMap<acp::SessionId, Entity<AcpThreadView>>,
     connection: Rc<dyn AgentConnection>,
 }
@@ -279,13 +279,18 @@ struct LoadingView {
 }
 
 impl ConnectedServerState {
+    pub fn active_view(&self) -> Option<&Entity<AcpThreadView>> {
+        self.active_id.as_ref().and_then(|id| self.threads.get(id))
+    }
+
     pub fn has_thread_error(&self, cx: &App) -> bool {
-        self.current.read(cx).thread_error.is_some()
+        self.active_view()
+            .map_or(false, |view| view.read(cx).thread_error.is_some())
     }
 
     pub fn navigate_to_session(&mut self, session_id: acp::SessionId) {
-        if let Some(session) = self.threads.get(&session_id) {
-            self.current = session.clone();
+        if self.threads.contains_key(&session_id) {
+            self.active_id = Some(session_id);
         }
     }
 
@@ -388,8 +393,8 @@ impl AcpServerView {
         );
         self.set_server_state(state, cx);
 
-        if let Some(connected) = self.as_connected() {
-            connected.current.update(cx, |this, cx| {
+        if let Some(view) = self.active_thread() {
+            view.update(cx, |this, cx| {
                 this.message_editor.update(cx, |editor, cx| {
                     editor.set_command_state(
                         this.prompt_capabilities.clone(),
@@ -522,7 +527,14 @@ impl AcpServerView {
                 Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
                     Ok(err) => {
                         cx.update(|window, cx| {
-                            Self::handle_auth_required(this, err, agent.name(), window, cx)
+                            Self::handle_auth_required(
+                                this,
+                                err,
+                                agent.name(),
+                                connection,
+                                window,
+                                cx,
+                            )
                         })
                         .log_err();
                         return;
@@ -553,15 +565,13 @@ impl AcpServerView {
                                 .focus(window, cx);
                         }
 
+                        let id = current.read(cx).thread.read(cx).session_id().clone();
                         this.set_server_state(
                             ServerState::Connected(ConnectedServerState {
                                 connection,
                                 auth_state: AuthState::Ok,
-                                current: current.clone(),
-                                threads: HashMap::from_iter([(
-                                    current.read(cx).thread.read(cx).session_id().clone(),
-                                    current,
-                                )]),
+                                active_id: Some(id.clone()),
+                                threads: HashMap::from_iter([(id, current)]),
                             }),
                             cx,
                         );
@@ -818,6 +828,7 @@ impl AcpServerView {
         this: WeakEntity<Self>,
         err: AuthRequired,
         agent_name: SharedString,
+        connection: Rc<dyn AgentConnection>,
         window: &mut Window,
         cx: &mut App,
     ) {
@@ -857,26 +868,36 @@ impl AcpServerView {
         };
 
         this.update(cx, |this, cx| {
+            let description = err
+                .description
+                .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
+            let auth_state = AuthState::Unauthenticated {
+                pending_auth_method: None,
+                configuration_view,
+                description,
+                _subscription: subscription,
+            };
             if let Some(connected) = this.as_connected_mut() {
-                let description = err
-                    .description
-                    .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
-
-                connected.auth_state = AuthState::Unauthenticated {
-                    pending_auth_method: None,
-                    configuration_view,
-                    description,
-                    _subscription: subscription,
-                };
-                if connected
-                    .current
-                    .read(cx)
-                    .message_editor
-                    .focus_handle(cx)
-                    .is_focused(window)
+                connected.auth_state = auth_state;
+                if let Some(view) = connected.active_view()
+                    && view
+                        .read(cx)
+                        .message_editor
+                        .focus_handle(cx)
+                        .is_focused(window)
                 {
                     this.focus_handle.focus(window, cx)
                 }
+            } else {
+                this.set_server_state(
+                    ServerState::Connected(ConnectedServerState {
+                        auth_state,
+                        active_id: None,
+                        threads: HashMap::default(),
+                        connection,
+                    }),
+                    cx,
+                );
             }
             cx.notify();
         })
@@ -889,19 +910,15 @@ impl AcpServerView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        match &self.server_state {
-            ServerState::Connected(connected) => {
-                if connected
-                    .current
-                    .read(cx)
-                    .message_editor
-                    .focus_handle(cx)
-                    .is_focused(window)
-                {
-                    self.focus_handle.focus(window, cx)
-                }
+        if let Some(view) = self.active_thread() {
+            if view
+                .read(cx)
+                .message_editor
+                .focus_handle(cx)
+                .is_focused(window)
+            {
+                self.focus_handle.focus(window, cx)
             }
-            _ => {}
         }
         let load_error = if let Some(load_err) = err.downcast_ref::<LoadError>() {
             load_err.clone()
@@ -1150,19 +1167,15 @@ impl AcpServerView {
                 }
             }
             AcpThreadEvent::LoadError(error) => {
-                match &self.server_state {
-                    ServerState::Connected(connected) => {
-                        if connected
-                            .current
-                            .read(cx)
-                            .message_editor
-                            .focus_handle(cx)
-                            .is_focused(window)
-                        {
-                            self.focus_handle.focus(window, cx)
-                        }
+                if let Some(view) = self.active_thread() {
+                    if view
+                        .read(cx)
+                        .message_editor
+                        .focus_handle(cx)
+                        .is_focused(window)
+                    {
+                        self.focus_handle.focus(window, cx)
                     }
-                    _ => {}
                 }
                 self.set_server_state(ServerState::LoadError(error.clone()), cx);
             }
@@ -1391,6 +1404,7 @@ impl AcpServerView {
             if !provider.is_authenticated(cx) {
                 let this = cx.weak_entity();
                 let agent_name = self.agent.name();
+                let connection = connection.clone();
                 window.defer(cx, |window, cx| {
                     Self::handle_auth_required(
                         this,
@@ -1399,6 +1413,7 @@ impl AcpServerView {
                             provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     );
@@ -1412,6 +1427,7 @@ impl AcpServerView {
         {
             let this = cx.weak_entity();
             let agent_name = self.agent.name();
+            let connection = connection.clone();
 
             window.defer(cx, |window, cx| {
                     Self::handle_auth_required(
@@ -1424,6 +1440,7 @@ impl AcpServerView {
                             provider_id: None,
                         },
                         agent_name,
+                        connection,
                         window,
                         cx,
                     )
@@ -2416,8 +2433,19 @@ impl AcpServerView {
             active.update(cx, |active, cx| active.clear_thread_error(cx));
         }
         let this = cx.weak_entity();
+        let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else {
+            debug_panic!("This should not be possible");
+            return;
+        };
         window.defer(cx, |window, cx| {
-            Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx);
+            Self::handle_auth_required(
+                this,
+                AuthRequired::new(),
+                agent_name,
+                connection,
+                window,
+                cx,
+            );
         })
     }
 
@@ -2527,7 +2555,14 @@ impl Render for AcpServerView {
                         cx,
                     ))
                     .into_any_element(),
-                ServerState::Connected(connected) => connected.current.clone().into_any_element(),
+                ServerState::Connected(connected) => {
+                    if let Some(view) = connected.active_view() {
+                        view.clone().into_any_element()
+                    } else {
+                        debug_panic!("This state should never be reached");
+                        div().into_any_element()
+                    }
+                }
             })
     }
 }
@@ -2940,6 +2975,78 @@ pub(crate) mod tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_auth_required_on_initial_connect(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let connection = AuthGatedAgentConnection::new();
+        let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
+
+        // When new_session returns AuthRequired, the server should transition
+        // to Connected + Unauthenticated rather than getting stuck in Loading.
+        thread_view.read_with(cx, |view, _cx| {
+            let connected = view
+                .as_connected()
+                .expect("Should be in Connected state even though auth is required");
+            assert!(
+                !connected.auth_state.is_ok(),
+                "Auth state should be Unauthenticated"
+            );
+            assert!(
+                connected.active_id.is_none(),
+                "There should be no active thread since no session was created"
+            );
+            assert!(
+                connected.threads.is_empty(),
+                "There should be no threads since no session was created"
+            );
+        });
+
+        thread_view.read_with(cx, |view, _cx| {
+            assert!(
+                view.active_thread().is_none(),
+                "active_thread() should be None when unauthenticated without a session"
+            );
+        });
+
+        // Authenticate using the real authenticate flow on AcpServerView.
+        // This calls connection.authenticate(), which flips the internal flag,
+        // then on success triggers reset() -> new_session() which now succeeds.
+        thread_view.update_in(cx, |view, window, cx| {
+            view.authenticate(
+                acp::AuthMethodId::new(AuthGatedAgentConnection::AUTH_METHOD_ID),
+                window,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        // After auth, the server should have an active thread in the Ok state.
+        thread_view.read_with(cx, |view, cx| {
+            let connected = view
+                .as_connected()
+                .expect("Should still be in Connected state after auth");
+            assert!(connected.auth_state.is_ok(), "Auth state should be Ok");
+            assert!(
+                connected.active_id.is_some(),
+                "There should be an active thread after successful auth"
+            );
+            assert_eq!(
+                connected.threads.len(),
+                1,
+                "There should be exactly one thread"
+            );
+
+            let active = view
+                .active_thread()
+                .expect("active_thread() should return the new thread");
+            assert!(
+                active.read(cx).thread_error.is_none(),
+                "The new thread should have no errors"
+            );
+        });
+    }
+
     #[gpui::test]
     async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
         init_test(cx);
@@ -3497,6 +3604,99 @@ pub(crate) mod tests {
         }
     }
 
+    /// Simulates an agent that requires authentication before a session can be
+    /// created. `new_session` returns `AuthRequired` until `authenticate` is
+    /// called with the correct method, after which sessions are created normally.
+    #[derive(Clone)]
+    struct AuthGatedAgentConnection {
+        authenticated: Arc<Mutex<bool>>,
+        auth_method: acp::AuthMethod,
+    }
+
+    impl AuthGatedAgentConnection {
+        const AUTH_METHOD_ID: &str = "test-login";
+
+        fn new() -> Self {
+            Self {
+                authenticated: Arc::new(Mutex::new(false)),
+                auth_method: acp::AuthMethod::new(Self::AUTH_METHOD_ID, "Test Login"),
+            }
+        }
+    }
+
+    impl AgentConnection for AuthGatedAgentConnection {
+        fn telemetry_id(&self) -> SharedString {
+            "auth-gated".into()
+        }
+
+        fn new_session(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            _cwd: &Path,
+            cx: &mut gpui::App,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            if !*self.authenticated.lock() {
+                return Task::ready(Err(acp_thread::AuthRequired::new()
+                    .with_description("Sign in to continue".to_string())
+                    .into()));
+            }
+
+            let session_id = acp::SessionId::new("auth-gated-session");
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            Task::ready(Ok(cx.new(|cx| {
+                AcpThread::new(
+                    None,
+                    "AuthGatedAgent",
+                    self,
+                    project,
+                    action_log,
+                    session_id,
+                    watch::Receiver::constant(
+                        acp::PromptCapabilities::new()
+                            .image(true)
+                            .audio(true)
+                            .embedded_context(true),
+                    ),
+                    cx,
+                )
+            })))
+        }
+
+        fn auth_methods(&self) -> &[acp::AuthMethod] {
+            std::slice::from_ref(&self.auth_method)
+        }
+
+        fn authenticate(
+            &self,
+            method_id: acp::AuthMethodId,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<()>> {
+            if method_id == self.auth_method.id {
+                *self.authenticated.lock() = true;
+                Task::ready(Ok(()))
+            } else {
+                Task::ready(Err(anyhow::anyhow!("Unknown auth method")))
+            }
+        }
+
+        fn prompt(
+            &self,
+            _id: Option<acp_thread::UserMessageId>,
+            _params: acp::PromptRequest,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<acp::PromptResponse>> {
+            unimplemented!()
+        }
+
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
+            unimplemented!()
+        }
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
+    }
+
     #[derive(Clone)]
     struct SaboteurAgentConnection;
 
@@ -3749,7 +3949,13 @@ pub(crate) mod tests {
         thread_view: &Entity<AcpServerView>,
         cx: &TestAppContext,
     ) -> Entity<AcpThreadView> {
-        cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone())
+        cx.read(|cx| {
+            thread_view
+                .read(cx)
+                .active_thread()
+                .expect("No active thread")
+                .clone()
+        })
     }
 
     fn message_editor(

crates/agent_ui/src/acp/thread_view/active_thread.rs 🔗

@@ -630,6 +630,7 @@ impl AcpThreadView {
             if can_login && !logout_supported {
                 message_editor.update(cx, |editor, cx| editor.clear(window, cx));
 
+                let connection = self.thread.read(cx).connection().clone();
                 window.defer(cx, {
                     let agent_name = self.agent_name.clone();
                     let server_view = self.server_view.clone();
@@ -638,6 +639,7 @@ impl AcpThreadView {
                             server_view.clone(),
                             AuthRequired::new(),
                             agent_name,
+                            connection,
                             window,
                             cx,
                         );
@@ -6716,11 +6718,13 @@ impl AcpThreadView {
                             editor.set_message(message, window, cx);
                         });
                     }
+                    let connection = this.thread.read(cx).connection().clone();
                     window.defer(cx, |window, cx| {
                         AcpServerView::handle_auth_required(
                             server_view,
                             AuthRequired::new(),
                             agent_name,
+                            connection,
                             window,
                             cx,
                         );

crates/agent_ui/src/agent_panel.rs 🔗

@@ -1006,7 +1006,7 @@ impl AgentPanel {
             return;
         };
 
-        let Some(active_thread) = thread_view.read(cx).active_thread() else {
+        let Some(active_thread) = thread_view.read(cx).active_thread().cloned() else {
             return;
         };
 
@@ -1280,7 +1280,7 @@ impl AgentPanel {
     ) {
         if let Some(workspace) = self.workspace.upgrade()
             && let Some(thread_view) = self.active_thread_view()
-            && let Some(active_thread) = thread_view.read(cx).active_thread()
+            && let Some(active_thread) = thread_view.read(cx).active_thread().cloned()
         {
             active_thread.update(cx, |thread, cx| {
                 thread