diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index dd15ab75113835bc345c8c071382c22fa8d88ba4..cffc90ea278e24fb81aba287c2668b2ac9a6655a 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -55,8 +55,8 @@ use ui::{ PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*, right_click_menu, }; -use util::defer; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; +use util::{debug_panic, defer}; use workspace::{ CollaboratorId, MultiWorkspace, NewTerminal, Toast, Workspace, notifications::NotificationId, }; @@ -180,9 +180,9 @@ pub struct AcpServerView { } impl AcpServerView { - pub fn active_thread(&self) -> Option> { + pub fn active_thread(&self) -> Option<&Entity> { match &self.server_state { - ServerState::Connected(connected) => Some(connected.current.clone()), + ServerState::Connected(connected) => connected.active_view(), _ => None, } } @@ -190,15 +190,15 @@ impl AcpServerView { pub fn parent_thread(&self, cx: &App) -> Option> { match &self.server_state { ServerState::Connected(connected) => { - let mut current = connected.current.clone(); + let mut current = connected.active_view()?; while let Some(parent_id) = current.read(cx).parent_id.clone() { if let Some(parent) = connected.threads.get(&parent_id) { - current = parent.clone(); + current = parent; } else { break; } } - Some(current) + Some(current.clone()) } _ => None, } @@ -251,7 +251,7 @@ enum ServerState { // hashmap of threads, current becomes session_id pub struct ConnectedServerState { auth_state: AuthState, - current: Entity, + active_id: Option, threads: HashMap>, connection: Rc, } @@ -279,13 +279,18 @@ struct LoadingView { } impl ConnectedServerState { + pub fn active_view(&self) -> Option<&Entity> { + self.active_id.as_ref().and_then(|id| self.threads.get(id)) + } + pub fn has_thread_error(&self, cx: &App) -> bool { - self.current.read(cx).thread_error.is_some() + self.active_view() + .map_or(false, |view| view.read(cx).thread_error.is_some()) } pub fn navigate_to_session(&mut self, session_id: acp::SessionId) { - if let Some(session) = self.threads.get(&session_id) { - self.current = session.clone(); + if self.threads.contains_key(&session_id) { + self.active_id = Some(session_id); } } @@ -388,8 +393,8 @@ impl AcpServerView { ); self.set_server_state(state, cx); - if let Some(connected) = self.as_connected() { - connected.current.update(cx, |this, cx| { + if let Some(view) = self.active_thread() { + view.update(cx, |this, cx| { this.message_editor.update(cx, |editor, cx| { editor.set_command_state( this.prompt_capabilities.clone(), @@ -522,7 +527,14 @@ impl AcpServerView { Err(e) => match e.downcast::() { Ok(err) => { cx.update(|window, cx| { - Self::handle_auth_required(this, err, agent.name(), window, cx) + Self::handle_auth_required( + this, + err, + agent.name(), + connection, + window, + cx, + ) }) .log_err(); return; @@ -553,15 +565,13 @@ impl AcpServerView { .focus(window, cx); } + let id = current.read(cx).thread.read(cx).session_id().clone(); this.set_server_state( ServerState::Connected(ConnectedServerState { connection, auth_state: AuthState::Ok, - current: current.clone(), - threads: HashMap::from_iter([( - current.read(cx).thread.read(cx).session_id().clone(), - current, - )]), + active_id: Some(id.clone()), + threads: HashMap::from_iter([(id, current)]), }), cx, ); @@ -818,6 +828,7 @@ impl AcpServerView { this: WeakEntity, err: AuthRequired, agent_name: SharedString, + connection: Rc, window: &mut Window, cx: &mut App, ) { @@ -857,26 +868,36 @@ impl AcpServerView { }; this.update(cx, |this, cx| { + let description = err + .description + .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))); + let auth_state = AuthState::Unauthenticated { + pending_auth_method: None, + configuration_view, + description, + _subscription: subscription, + }; if let Some(connected) = this.as_connected_mut() { - let description = err - .description - .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))); - - connected.auth_state = AuthState::Unauthenticated { - pending_auth_method: None, - configuration_view, - description, - _subscription: subscription, - }; - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) + connected.auth_state = auth_state; + if let Some(view) = connected.active_view() + && view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) { this.focus_handle.focus(window, cx) } + } else { + this.set_server_state( + ServerState::Connected(ConnectedServerState { + auth_state, + active_id: None, + threads: HashMap::default(), + connection, + }), + cx, + ); } cx.notify(); }) @@ -889,19 +910,15 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - match &self.server_state { - ServerState::Connected(connected) => { - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) - { - self.focus_handle.focus(window, cx) - } + if let Some(view) = self.active_thread() { + if view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) + { + self.focus_handle.focus(window, cx) } - _ => {} } let load_error = if let Some(load_err) = err.downcast_ref::() { load_err.clone() @@ -1150,19 +1167,15 @@ impl AcpServerView { } } AcpThreadEvent::LoadError(error) => { - match &self.server_state { - ServerState::Connected(connected) => { - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) - { - self.focus_handle.focus(window, cx) - } + if let Some(view) = self.active_thread() { + if view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) + { + self.focus_handle.focus(window, cx) } - _ => {} } self.set_server_state(ServerState::LoadError(error.clone()), cx); } @@ -1391,6 +1404,7 @@ impl AcpServerView { if !provider.is_authenticated(cx) { let this = cx.weak_entity(); let agent_name = self.agent.name(); + let connection = connection.clone(); window.defer(cx, |window, cx| { Self::handle_auth_required( this, @@ -1399,6 +1413,7 @@ impl AcpServerView { provider_id: Some(language_model::GOOGLE_PROVIDER_ID), }, agent_name, + connection, window, cx, ); @@ -1412,6 +1427,7 @@ impl AcpServerView { { let this = cx.weak_entity(); let agent_name = self.agent.name(); + let connection = connection.clone(); window.defer(cx, |window, cx| { Self::handle_auth_required( @@ -1424,6 +1440,7 @@ impl AcpServerView { provider_id: None, }, agent_name, + connection, window, cx, ) @@ -2416,8 +2433,19 @@ impl AcpServerView { active.update(cx, |active, cx| active.clear_thread_error(cx)); } let this = cx.weak_entity(); + let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else { + debug_panic!("This should not be possible"); + return; + }; window.defer(cx, |window, cx| { - Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx); + Self::handle_auth_required( + this, + AuthRequired::new(), + agent_name, + connection, + window, + cx, + ); }) } @@ -2527,7 +2555,14 @@ impl Render for AcpServerView { cx, )) .into_any_element(), - ServerState::Connected(connected) => connected.current.clone().into_any_element(), + ServerState::Connected(connected) => { + if let Some(view) = connected.active_view() { + view.clone().into_any_element() + } else { + debug_panic!("This state should never be reached"); + div().into_any_element() + } + } }) } } @@ -2940,6 +2975,78 @@ pub(crate) mod tests { }); } + #[gpui::test] + async fn test_auth_required_on_initial_connect(cx: &mut TestAppContext) { + init_test(cx); + + let connection = AuthGatedAgentConnection::new(); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + // When new_session returns AuthRequired, the server should transition + // to Connected + Unauthenticated rather than getting stuck in Loading. + thread_view.read_with(cx, |view, _cx| { + let connected = view + .as_connected() + .expect("Should be in Connected state even though auth is required"); + assert!( + !connected.auth_state.is_ok(), + "Auth state should be Unauthenticated" + ); + assert!( + connected.active_id.is_none(), + "There should be no active thread since no session was created" + ); + assert!( + connected.threads.is_empty(), + "There should be no threads since no session was created" + ); + }); + + thread_view.read_with(cx, |view, _cx| { + assert!( + view.active_thread().is_none(), + "active_thread() should be None when unauthenticated without a session" + ); + }); + + // Authenticate using the real authenticate flow on AcpServerView. + // This calls connection.authenticate(), which flips the internal flag, + // then on success triggers reset() -> new_session() which now succeeds. + thread_view.update_in(cx, |view, window, cx| { + view.authenticate( + acp::AuthMethodId::new(AuthGatedAgentConnection::AUTH_METHOD_ID), + window, + cx, + ); + }); + cx.run_until_parked(); + + // After auth, the server should have an active thread in the Ok state. + thread_view.read_with(cx, |view, cx| { + let connected = view + .as_connected() + .expect("Should still be in Connected state after auth"); + assert!(connected.auth_state.is_ok(), "Auth state should be Ok"); + assert!( + connected.active_id.is_some(), + "There should be an active thread after successful auth" + ); + assert_eq!( + connected.threads.len(), + 1, + "There should be exactly one thread" + ); + + let active = view + .active_thread() + .expect("active_thread() should return the new thread"); + assert!( + active.read(cx).thread_error.is_none(), + "The new thread should have no errors" + ); + }); + } + #[gpui::test] async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { init_test(cx); @@ -3497,6 +3604,99 @@ pub(crate) mod tests { } } + /// Simulates an agent that requires authentication before a session can be + /// created. `new_session` returns `AuthRequired` until `authenticate` is + /// called with the correct method, after which sessions are created normally. + #[derive(Clone)] + struct AuthGatedAgentConnection { + authenticated: Arc>, + auth_method: acp::AuthMethod, + } + + impl AuthGatedAgentConnection { + const AUTH_METHOD_ID: &str = "test-login"; + + fn new() -> Self { + Self { + authenticated: Arc::new(Mutex::new(false)), + auth_method: acp::AuthMethod::new(Self::AUTH_METHOD_ID, "Test Login"), + } + } + } + + impl AgentConnection for AuthGatedAgentConnection { + fn telemetry_id(&self) -> SharedString { + "auth-gated".into() + } + + fn new_session( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + if !*self.authenticated.lock() { + return Task::ready(Err(acp_thread::AuthRequired::new() + .with_description("Sign in to continue".to_string()) + .into())); + } + + let session_id = acp::SessionId::new("auth-gated-session"); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + Task::ready(Ok(cx.new(|cx| { + AcpThread::new( + None, + "AuthGatedAgent", + self, + project, + action_log, + session_id, + watch::Receiver::constant( + acp::PromptCapabilities::new() + .image(true) + .audio(true) + .embedded_context(true), + ), + cx, + ) + }))) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + std::slice::from_ref(&self.auth_method) + } + + fn authenticate( + &self, + method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + if method_id == self.auth_method.id { + *self.authenticated.lock() = true; + Task::ready(Ok(())) + } else { + Task::ready(Err(anyhow::anyhow!("Unknown auth method"))) + } + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn into_any(self: Rc) -> Rc { + self + } + } + #[derive(Clone)] struct SaboteurAgentConnection; @@ -3749,7 +3949,13 @@ pub(crate) mod tests { thread_view: &Entity, cx: &TestAppContext, ) -> Entity { - cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone()) + cx.read(|cx| { + thread_view + .read(cx) + .active_thread() + .expect("No active thread") + .clone() + }) } fn message_editor( diff --git a/crates/agent_ui/src/acp/thread_view/active_thread.rs b/crates/agent_ui/src/acp/thread_view/active_thread.rs index bde91fbec9a77ca9554ca31c3e3b4d97b7a21c04..fe296c6f9e0abe7b5fc7926d23eed7a37e5e0633 100644 --- a/crates/agent_ui/src/acp/thread_view/active_thread.rs +++ b/crates/agent_ui/src/acp/thread_view/active_thread.rs @@ -630,6 +630,7 @@ impl AcpThreadView { if can_login && !logout_supported { message_editor.update(cx, |editor, cx| editor.clear(window, cx)); + let connection = self.thread.read(cx).connection().clone(); window.defer(cx, { let agent_name = self.agent_name.clone(); let server_view = self.server_view.clone(); @@ -638,6 +639,7 @@ impl AcpThreadView { server_view.clone(), AuthRequired::new(), agent_name, + connection, window, cx, ); @@ -6716,11 +6718,13 @@ impl AcpThreadView { editor.set_message(message, window, cx); }); } + let connection = this.thread.read(cx).connection().clone(); window.defer(cx, |window, cx| { AcpServerView::handle_auth_required( server_view, AuthRequired::new(), agent_name, + connection, window, cx, ); diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index c57f156db693c5c24a4428994f7db7f32cb351e1..9338cde0da066bea295ea7bb0e68fb5844288852 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1006,7 +1006,7 @@ impl AgentPanel { return; }; - let Some(active_thread) = thread_view.read(cx).active_thread() else { + let Some(active_thread) = thread_view.read(cx).active_thread().cloned() else { return; }; @@ -1280,7 +1280,7 @@ impl AgentPanel { ) { if let Some(workspace) = self.workspace.upgrade() && let Some(thread_view) = self.active_thread_view() - && let Some(active_thread) = thread_view.read(cx).active_thread() + && let Some(active_thread) = thread_view.read(cx).active_thread().cloned() { active_thread.update(cx, |thread, cx| { thread