diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 94fbff72f780ab5f4a1fa00d53a1b068c8505247..c5e68ca6008dcea3ea8245c1da4418bf32d76c53 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -55,8 +55,8 @@ use ui::{ PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*, right_click_menu, }; -use util::defer; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; +use util::{debug_panic, defer}; use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId}; use zed_actions::agent::{Chat, ToggleModelSelector}; use zed_actions::assistant::OpenRulesLibrary; @@ -178,9 +178,9 @@ pub struct AcpServerView { } impl AcpServerView { - pub fn active_thread(&self) -> Option> { + pub fn active_thread(&self) -> Option<&Entity> { match &self.server_state { - ServerState::Connected(connected) => Some(connected.current.clone()), + ServerState::Connected(connected) => connected.active_view(), _ => None, } } @@ -188,15 +188,15 @@ impl AcpServerView { pub fn parent_thread(&self, cx: &App) -> Option> { match &self.server_state { ServerState::Connected(connected) => { - let mut current = connected.current.clone(); + let mut current = connected.active_view()?; while let Some(parent_id) = current.read(cx).parent_id.clone() { if let Some(parent) = connected.threads.get(&parent_id) { - current = parent.clone(); + current = parent; } else { break; } } - Some(current) + Some(current.clone()) } _ => None, } @@ -249,7 +249,7 @@ enum ServerState { // hashmap of threads, current becomes session_id pub struct ConnectedServerState { auth_state: AuthState, - current: Entity, + active_id: Option, threads: HashMap>, connection: Rc, } @@ -277,13 +277,18 @@ struct LoadingView { } impl ConnectedServerState { + pub fn active_view(&self) -> Option<&Entity> { + self.active_id.as_ref().and_then(|id| self.threads.get(id)) + } + pub fn has_thread_error(&self, cx: &App) -> bool { - self.current.read(cx).thread_error.is_some() + self.active_view() + .map_or(false, |view| view.read(cx).thread_error.is_some()) } pub fn navigate_to_session(&mut self, session_id: acp::SessionId) { - if let Some(session) = self.threads.get(&session_id) { - self.current = session.clone(); + if self.threads.contains_key(&session_id) { + self.active_id = Some(session_id); } } @@ -386,8 +391,8 @@ impl AcpServerView { ); self.set_server_state(state, cx); - if let Some(connected) = self.as_connected() { - connected.current.update(cx, |this, cx| { + if let Some(view) = self.active_thread() { + view.update(cx, |this, cx| { this.message_editor.update(cx, |editor, cx| { editor.set_command_state( this.prompt_capabilities.clone(), @@ -520,7 +525,14 @@ impl AcpServerView { Err(e) => match e.downcast::() { Ok(err) => { cx.update(|window, cx| { - Self::handle_auth_required(this, err, agent.name(), window, cx) + Self::handle_auth_required( + this, + err, + agent.name(), + connection, + window, + cx, + ) }) .log_err(); return; @@ -551,15 +563,13 @@ impl AcpServerView { .focus(window, cx); } + let id = current.read(cx).thread.read(cx).session_id().clone(); this.set_server_state( ServerState::Connected(ConnectedServerState { connection, auth_state: AuthState::Ok, - current: current.clone(), - threads: HashMap::from_iter([( - current.read(cx).thread.read(cx).session_id().clone(), - current, - )]), + active_id: Some(id.clone()), + threads: HashMap::from_iter([(id, current)]), }), cx, ); @@ -816,6 +826,7 @@ impl AcpServerView { this: WeakEntity, err: AuthRequired, agent_name: SharedString, + connection: Rc, window: &mut Window, cx: &mut App, ) { @@ -855,26 +866,36 @@ impl AcpServerView { }; this.update(cx, |this, cx| { + let description = err + .description + .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))); + let auth_state = AuthState::Unauthenticated { + pending_auth_method: None, + configuration_view, + description, + _subscription: subscription, + }; if let Some(connected) = this.as_connected_mut() { - let description = err - .description - .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))); - - connected.auth_state = AuthState::Unauthenticated { - pending_auth_method: None, - configuration_view, - description, - _subscription: subscription, - }; - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) + connected.auth_state = auth_state; + if let Some(view) = connected.active_view() + && view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) { this.focus_handle.focus(window, cx) } + } else { + this.set_server_state( + ServerState::Connected(ConnectedServerState { + auth_state, + active_id: None, + threads: HashMap::default(), + connection, + }), + cx, + ); } cx.notify(); }) @@ -887,19 +908,15 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - match &self.server_state { - ServerState::Connected(connected) => { - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) - { - self.focus_handle.focus(window, cx) - } + if let Some(view) = self.active_thread() { + if view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) + { + self.focus_handle.focus(window, cx) } - _ => {} } let load_error = if let Some(load_err) = err.downcast_ref::() { load_err.clone() @@ -1148,19 +1165,15 @@ impl AcpServerView { } } AcpThreadEvent::LoadError(error) => { - match &self.server_state { - ServerState::Connected(connected) => { - if connected - .current - .read(cx) - .message_editor - .focus_handle(cx) - .is_focused(window) - { - self.focus_handle.focus(window, cx) - } + if let Some(view) = self.active_thread() { + if view + .read(cx) + .message_editor + .focus_handle(cx) + .is_focused(window) + { + self.focus_handle.focus(window, cx) } - _ => {} } self.set_server_state(ServerState::LoadError(error.clone()), cx); } @@ -1397,6 +1410,7 @@ impl AcpServerView { provider_id: Some(language_model::GOOGLE_PROVIDER_ID), }, agent_name, + connection, window, cx, ); @@ -1422,6 +1436,7 @@ impl AcpServerView { provider_id: None, }, agent_name, + connection, window, cx, ) @@ -2397,8 +2412,19 @@ impl AcpServerView { active.update(cx, |active, cx| active.clear_thread_error(cx)); } let this = cx.weak_entity(); + let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else { + debug_panic!("This should not be possible"); + return; + }; window.defer(cx, |window, cx| { - Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx); + Self::handle_auth_required( + this, + AuthRequired::new(), + agent_name, + connection, + window, + cx, + ); }) } @@ -2508,7 +2534,14 @@ impl Render for AcpServerView { cx, )) .into_any_element(), - ServerState::Connected(connected) => connected.current.clone().into_any_element(), + ServerState::Connected(connected) => { + if let Some(view) = connected.active_view() { + view.clone().into_any_element() + } else { + debug_panic!("This state should never be reached"); + div().into_any_element() + } + } }) } } @@ -3589,7 +3622,13 @@ pub(crate) mod tests { thread_view: &Entity, cx: &TestAppContext, ) -> Entity { - cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone()) + cx.read(|cx| { + thread_view + .read(cx) + .active_thread() + .expect("No active thread") + .clone() + }) } fn message_editor( diff --git a/crates/agent_ui/src/acp/thread_view/active_thread.rs b/crates/agent_ui/src/acp/thread_view/active_thread.rs index bde91fbec9a77ca9554ca31c3e3b4d97b7a21c04..fe296c6f9e0abe7b5fc7926d23eed7a37e5e0633 100644 --- a/crates/agent_ui/src/acp/thread_view/active_thread.rs +++ b/crates/agent_ui/src/acp/thread_view/active_thread.rs @@ -630,6 +630,7 @@ impl AcpThreadView { if can_login && !logout_supported { message_editor.update(cx, |editor, cx| editor.clear(window, cx)); + let connection = self.thread.read(cx).connection().clone(); window.defer(cx, { let agent_name = self.agent_name.clone(); let server_view = self.server_view.clone(); @@ -638,6 +639,7 @@ impl AcpThreadView { server_view.clone(), AuthRequired::new(), agent_name, + connection, window, cx, ); @@ -6716,11 +6718,13 @@ impl AcpThreadView { editor.set_message(message, window, cx); }); } + let connection = this.thread.read(cx).connection().clone(); window.defer(cx, |window, cx| { AcpServerView::handle_auth_required( server_view, AuthRequired::new(), agent_name, + connection, window, cx, ); diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index ccfc0cd7073b08249a9bdc07cf3525f92e689e9a..3010bff352314e2a98a16bf7976c13bb6996e5f1 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -922,7 +922,7 @@ impl AgentPanel { return; }; - let Some(active_thread) = thread_view.read(cx).active_thread() else { + let Some(active_thread) = thread_view.read(cx).active_thread().cloned() else { return; }; @@ -1195,7 +1195,7 @@ impl AgentPanel { ) { if let Some(workspace) = self.workspace.upgrade() && let Some(thread_view) = self.active_thread_view() - && let Some(active_thread) = thread_view.read(cx).active_thread() + && let Some(active_thread) = thread_view.read(cx).active_thread().cloned() { active_thread.update(cx, |thread, cx| { thread