@@ -55,8 +55,8 @@ use ui::{
PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*,
right_click_menu,
};
-use util::defer;
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
+use util::{debug_panic, defer};
use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId};
use zed_actions::agent::{Chat, ToggleModelSelector};
use zed_actions::assistant::OpenRulesLibrary;
@@ -178,9 +178,9 @@ pub struct AcpServerView {
}
impl AcpServerView {
- pub fn active_thread(&self) -> Option<Entity<AcpThreadView>> {
+ pub fn active_thread(&self) -> Option<&Entity<AcpThreadView>> {
match &self.server_state {
- ServerState::Connected(connected) => Some(connected.current.clone()),
+ ServerState::Connected(connected) => connected.active_view(),
_ => None,
}
}
@@ -188,15 +188,15 @@ impl AcpServerView {
pub fn parent_thread(&self, cx: &App) -> Option<Entity<AcpThreadView>> {
match &self.server_state {
ServerState::Connected(connected) => {
- let mut current = connected.current.clone();
+ let mut current = connected.active_view()?;
while let Some(parent_id) = current.read(cx).parent_id.clone() {
if let Some(parent) = connected.threads.get(&parent_id) {
- current = parent.clone();
+ current = parent;
} else {
break;
}
}
- Some(current)
+ Some(current.clone())
}
_ => None,
}
@@ -249,7 +249,7 @@ enum ServerState {
// hashmap of threads, current becomes session_id
pub struct ConnectedServerState {
auth_state: AuthState,
- current: Entity<AcpThreadView>,
+ active_id: Option<acp::SessionId>,
threads: HashMap<acp::SessionId, Entity<AcpThreadView>>,
connection: Rc<dyn AgentConnection>,
}
@@ -277,13 +277,18 @@ struct LoadingView {
}
impl ConnectedServerState {
+ pub fn active_view(&self) -> Option<&Entity<AcpThreadView>> {
+ self.active_id.as_ref().and_then(|id| self.threads.get(id))
+ }
+
pub fn has_thread_error(&self, cx: &App) -> bool {
- self.current.read(cx).thread_error.is_some()
+ self.active_view()
+ .map_or(false, |view| view.read(cx).thread_error.is_some())
}
pub fn navigate_to_session(&mut self, session_id: acp::SessionId) {
- if let Some(session) = self.threads.get(&session_id) {
- self.current = session.clone();
+ if self.threads.contains_key(&session_id) {
+ self.active_id = Some(session_id);
}
}
@@ -386,8 +391,8 @@ impl AcpServerView {
);
self.set_server_state(state, cx);
- if let Some(connected) = self.as_connected() {
- connected.current.update(cx, |this, cx| {
+ if let Some(view) = self.active_thread() {
+ view.update(cx, |this, cx| {
this.message_editor.update(cx, |editor, cx| {
editor.set_command_state(
this.prompt_capabilities.clone(),
@@ -520,7 +525,14 @@ impl AcpServerView {
Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
Ok(err) => {
cx.update(|window, cx| {
- Self::handle_auth_required(this, err, agent.name(), window, cx)
+ Self::handle_auth_required(
+ this,
+ err,
+ agent.name(),
+ connection,
+ window,
+ cx,
+ )
})
.log_err();
return;
@@ -551,15 +563,13 @@ impl AcpServerView {
.focus(window, cx);
}
+ let id = current.read(cx).thread.read(cx).session_id().clone();
this.set_server_state(
ServerState::Connected(ConnectedServerState {
connection,
auth_state: AuthState::Ok,
- current: current.clone(),
- threads: HashMap::from_iter([(
- current.read(cx).thread.read(cx).session_id().clone(),
- current,
- )]),
+ active_id: Some(id.clone()),
+ threads: HashMap::from_iter([(id, current)]),
}),
cx,
);
@@ -816,6 +826,7 @@ impl AcpServerView {
this: WeakEntity<Self>,
err: AuthRequired,
agent_name: SharedString,
+ connection: Rc<dyn AgentConnection>,
window: &mut Window,
cx: &mut App,
) {
@@ -855,26 +866,36 @@ impl AcpServerView {
};
this.update(cx, |this, cx| {
+ let description = err
+ .description
+ .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
+ let auth_state = AuthState::Unauthenticated {
+ pending_auth_method: None,
+ configuration_view,
+ description,
+ _subscription: subscription,
+ };
if let Some(connected) = this.as_connected_mut() {
- let description = err
- .description
- .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
-
- connected.auth_state = AuthState::Unauthenticated {
- pending_auth_method: None,
- configuration_view,
- description,
- _subscription: subscription,
- };
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
+ connected.auth_state = auth_state;
+ if let Some(view) = connected.active_view()
+ && view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
{
this.focus_handle.focus(window, cx)
}
+ } else {
+ this.set_server_state(
+ ServerState::Connected(ConnectedServerState {
+ auth_state,
+ active_id: None,
+ threads: HashMap::default(),
+ connection,
+ }),
+ cx,
+ );
}
cx.notify();
})
@@ -887,19 +908,15 @@ impl AcpServerView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- match &self.server_state {
- ServerState::Connected(connected) => {
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
- {
- self.focus_handle.focus(window, cx)
- }
+ if let Some(view) = self.active_thread() {
+ if view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
+ {
+ self.focus_handle.focus(window, cx)
}
- _ => {}
}
let load_error = if let Some(load_err) = err.downcast_ref::<LoadError>() {
load_err.clone()
@@ -1148,19 +1165,15 @@ impl AcpServerView {
}
}
AcpThreadEvent::LoadError(error) => {
- match &self.server_state {
- ServerState::Connected(connected) => {
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
- {
- self.focus_handle.focus(window, cx)
- }
+ if let Some(view) = self.active_thread() {
+ if view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
+ {
+ self.focus_handle.focus(window, cx)
}
- _ => {}
}
self.set_server_state(ServerState::LoadError(error.clone()), cx);
}
@@ -1397,6 +1410,7 @@ impl AcpServerView {
provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
},
agent_name,
+ connection,
window,
cx,
);
@@ -1422,6 +1436,7 @@ impl AcpServerView {
provider_id: None,
},
agent_name,
+ connection,
window,
cx,
)
@@ -2397,8 +2412,19 @@ impl AcpServerView {
active.update(cx, |active, cx| active.clear_thread_error(cx));
}
let this = cx.weak_entity();
+ let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else {
+ debug_panic!("This should not be possible");
+ return;
+ };
window.defer(cx, |window, cx| {
- Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx);
+ Self::handle_auth_required(
+ this,
+ AuthRequired::new(),
+ agent_name,
+ connection,
+ window,
+ cx,
+ );
})
}
@@ -2508,7 +2534,14 @@ impl Render for AcpServerView {
cx,
))
.into_any_element(),
- ServerState::Connected(connected) => connected.current.clone().into_any_element(),
+ ServerState::Connected(connected) => {
+ if let Some(view) = connected.active_view() {
+ view.clone().into_any_element()
+ } else {
+ debug_panic!("This state should never be reached");
+ div().into_any_element()
+ }
+ }
})
}
}
@@ -3589,7 +3622,13 @@ pub(crate) mod tests {
thread_view: &Entity<AcpServerView>,
cx: &TestAppContext,
) -> Entity<AcpThreadView> {
- cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone())
+ cx.read(|cx| {
+ thread_view
+ .read(cx)
+ .active_thread()
+ .expect("No active thread")
+ .clone()
+ })
}
fn message_editor(