@@ -55,8 +55,8 @@ use ui::{
PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*,
right_click_menu,
};
-use util::defer;
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
+use util::{debug_panic, defer};
use workspace::{
CollaboratorId, MultiWorkspace, NewTerminal, Toast, Workspace, notifications::NotificationId,
};
@@ -180,9 +180,9 @@ pub struct AcpServerView {
}
impl AcpServerView {
- pub fn active_thread(&self) -> Option<Entity<AcpThreadView>> {
+ pub fn active_thread(&self) -> Option<&Entity<AcpThreadView>> {
match &self.server_state {
- ServerState::Connected(connected) => Some(connected.current.clone()),
+ ServerState::Connected(connected) => connected.active_view(),
_ => None,
}
}
@@ -190,15 +190,15 @@ impl AcpServerView {
pub fn parent_thread(&self, cx: &App) -> Option<Entity<AcpThreadView>> {
match &self.server_state {
ServerState::Connected(connected) => {
- let mut current = connected.current.clone();
+ let mut current = connected.active_view()?;
while let Some(parent_id) = current.read(cx).parent_id.clone() {
if let Some(parent) = connected.threads.get(&parent_id) {
- current = parent.clone();
+ current = parent;
} else {
break;
}
}
- Some(current)
+ Some(current.clone())
}
_ => None,
}
@@ -251,7 +251,7 @@ enum ServerState {
// hashmap of threads, current becomes session_id
pub struct ConnectedServerState {
auth_state: AuthState,
- current: Entity<AcpThreadView>,
+ active_id: Option<acp::SessionId>,
threads: HashMap<acp::SessionId, Entity<AcpThreadView>>,
connection: Rc<dyn AgentConnection>,
}
@@ -279,13 +279,18 @@ struct LoadingView {
}
impl ConnectedServerState {
+ pub fn active_view(&self) -> Option<&Entity<AcpThreadView>> {
+ self.active_id.as_ref().and_then(|id| self.threads.get(id))
+ }
+
pub fn has_thread_error(&self, cx: &App) -> bool {
- self.current.read(cx).thread_error.is_some()
+ self.active_view()
+ .map_or(false, |view| view.read(cx).thread_error.is_some())
}
pub fn navigate_to_session(&mut self, session_id: acp::SessionId) {
- if let Some(session) = self.threads.get(&session_id) {
- self.current = session.clone();
+ if self.threads.contains_key(&session_id) {
+ self.active_id = Some(session_id);
}
}
@@ -388,8 +393,8 @@ impl AcpServerView {
);
self.set_server_state(state, cx);
- if let Some(connected) = self.as_connected() {
- connected.current.update(cx, |this, cx| {
+ if let Some(view) = self.active_thread() {
+ view.update(cx, |this, cx| {
this.message_editor.update(cx, |editor, cx| {
editor.set_command_state(
this.prompt_capabilities.clone(),
@@ -522,7 +527,14 @@ impl AcpServerView {
Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
Ok(err) => {
cx.update(|window, cx| {
- Self::handle_auth_required(this, err, agent.name(), window, cx)
+ Self::handle_auth_required(
+ this,
+ err,
+ agent.name(),
+ connection,
+ window,
+ cx,
+ )
})
.log_err();
return;
@@ -553,15 +565,13 @@ impl AcpServerView {
.focus(window, cx);
}
+ let id = current.read(cx).thread.read(cx).session_id().clone();
this.set_server_state(
ServerState::Connected(ConnectedServerState {
connection,
auth_state: AuthState::Ok,
- current: current.clone(),
- threads: HashMap::from_iter([(
- current.read(cx).thread.read(cx).session_id().clone(),
- current,
- )]),
+ active_id: Some(id.clone()),
+ threads: HashMap::from_iter([(id, current)]),
}),
cx,
);
@@ -818,6 +828,7 @@ impl AcpServerView {
this: WeakEntity<Self>,
err: AuthRequired,
agent_name: SharedString,
+ connection: Rc<dyn AgentConnection>,
window: &mut Window,
cx: &mut App,
) {
@@ -857,26 +868,36 @@ impl AcpServerView {
};
this.update(cx, |this, cx| {
+ let description = err
+ .description
+ .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
+ let auth_state = AuthState::Unauthenticated {
+ pending_auth_method: None,
+ configuration_view,
+ description,
+ _subscription: subscription,
+ };
if let Some(connected) = this.as_connected_mut() {
- let description = err
- .description
- .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx)));
-
- connected.auth_state = AuthState::Unauthenticated {
- pending_auth_method: None,
- configuration_view,
- description,
- _subscription: subscription,
- };
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
+ connected.auth_state = auth_state;
+ if let Some(view) = connected.active_view()
+ && view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
{
this.focus_handle.focus(window, cx)
}
+ } else {
+ this.set_server_state(
+ ServerState::Connected(ConnectedServerState {
+ auth_state,
+ active_id: None,
+ threads: HashMap::default(),
+ connection,
+ }),
+ cx,
+ );
}
cx.notify();
})
@@ -889,19 +910,15 @@ impl AcpServerView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- match &self.server_state {
- ServerState::Connected(connected) => {
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
- {
- self.focus_handle.focus(window, cx)
- }
+ if let Some(view) = self.active_thread() {
+ if view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
+ {
+ self.focus_handle.focus(window, cx)
}
- _ => {}
}
let load_error = if let Some(load_err) = err.downcast_ref::<LoadError>() {
load_err.clone()
@@ -1150,19 +1167,15 @@ impl AcpServerView {
}
}
AcpThreadEvent::LoadError(error) => {
- match &self.server_state {
- ServerState::Connected(connected) => {
- if connected
- .current
- .read(cx)
- .message_editor
- .focus_handle(cx)
- .is_focused(window)
- {
- self.focus_handle.focus(window, cx)
- }
+ if let Some(view) = self.active_thread() {
+ if view
+ .read(cx)
+ .message_editor
+ .focus_handle(cx)
+ .is_focused(window)
+ {
+ self.focus_handle.focus(window, cx)
}
- _ => {}
}
self.set_server_state(ServerState::LoadError(error.clone()), cx);
}
@@ -1391,6 +1404,7 @@ impl AcpServerView {
if !provider.is_authenticated(cx) {
let this = cx.weak_entity();
let agent_name = self.agent.name();
+ let connection = connection.clone();
window.defer(cx, |window, cx| {
Self::handle_auth_required(
this,
@@ -1399,6 +1413,7 @@ impl AcpServerView {
provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
},
agent_name,
+ connection,
window,
cx,
);
@@ -1412,6 +1427,7 @@ impl AcpServerView {
{
let this = cx.weak_entity();
let agent_name = self.agent.name();
+ let connection = connection.clone();
window.defer(cx, |window, cx| {
Self::handle_auth_required(
@@ -1424,6 +1440,7 @@ impl AcpServerView {
provider_id: None,
},
agent_name,
+ connection,
window,
cx,
)
@@ -2416,8 +2433,19 @@ impl AcpServerView {
active.update(cx, |active, cx| active.clear_thread_error(cx));
}
let this = cx.weak_entity();
+ let Some(connection) = self.as_connected().map(|c| c.connection.clone()) else {
+ debug_panic!("This should not be possible");
+ return;
+ };
window.defer(cx, |window, cx| {
- Self::handle_auth_required(this, AuthRequired::new(), agent_name, window, cx);
+ Self::handle_auth_required(
+ this,
+ AuthRequired::new(),
+ agent_name,
+ connection,
+ window,
+ cx,
+ );
})
}
@@ -2527,7 +2555,14 @@ impl Render for AcpServerView {
cx,
))
.into_any_element(),
- ServerState::Connected(connected) => connected.current.clone().into_any_element(),
+ ServerState::Connected(connected) => {
+ if let Some(view) = connected.active_view() {
+ view.clone().into_any_element()
+ } else {
+ debug_panic!("This state should never be reached");
+ div().into_any_element()
+ }
+ }
})
}
}
@@ -2940,6 +2975,78 @@ pub(crate) mod tests {
});
}
+ #[gpui::test]
+ async fn test_auth_required_on_initial_connect(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let connection = AuthGatedAgentConnection::new();
+ let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
+
+ // When new_session returns AuthRequired, the server should transition
+ // to Connected + Unauthenticated rather than getting stuck in Loading.
+ thread_view.read_with(cx, |view, _cx| {
+ let connected = view
+ .as_connected()
+ .expect("Should be in Connected state even though auth is required");
+ assert!(
+ !connected.auth_state.is_ok(),
+ "Auth state should be Unauthenticated"
+ );
+ assert!(
+ connected.active_id.is_none(),
+ "There should be no active thread since no session was created"
+ );
+ assert!(
+ connected.threads.is_empty(),
+ "There should be no threads since no session was created"
+ );
+ });
+
+ thread_view.read_with(cx, |view, _cx| {
+ assert!(
+ view.active_thread().is_none(),
+ "active_thread() should be None when unauthenticated without a session"
+ );
+ });
+
+ // Authenticate using the real authenticate flow on AcpServerView.
+ // This calls connection.authenticate(), which flips the internal flag,
+ // then on success triggers reset() -> new_session() which now succeeds.
+ thread_view.update_in(cx, |view, window, cx| {
+ view.authenticate(
+ acp::AuthMethodId::new(AuthGatedAgentConnection::AUTH_METHOD_ID),
+ window,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+
+ // After auth, the server should have an active thread in the Ok state.
+ thread_view.read_with(cx, |view, cx| {
+ let connected = view
+ .as_connected()
+ .expect("Should still be in Connected state after auth");
+ assert!(connected.auth_state.is_ok(), "Auth state should be Ok");
+ assert!(
+ connected.active_id.is_some(),
+ "There should be an active thread after successful auth"
+ );
+ assert_eq!(
+ connected.threads.len(),
+ 1,
+ "There should be exactly one thread"
+ );
+
+ let active = view
+ .active_thread()
+ .expect("active_thread() should return the new thread");
+ assert!(
+ active.read(cx).thread_error.is_none(),
+ "The new thread should have no errors"
+ );
+ });
+ }
+
#[gpui::test]
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
init_test(cx);
@@ -3497,6 +3604,99 @@ pub(crate) mod tests {
}
}
+ /// Simulates an agent that requires authentication before a session can be
+ /// created. `new_session` returns `AuthRequired` until `authenticate` is
+ /// called with the correct method, after which sessions are created normally.
+ #[derive(Clone)]
+ struct AuthGatedAgentConnection {
+ authenticated: Arc<Mutex<bool>>,
+ auth_method: acp::AuthMethod,
+ }
+
+ impl AuthGatedAgentConnection {
+ const AUTH_METHOD_ID: &str = "test-login";
+
+ fn new() -> Self {
+ Self {
+ authenticated: Arc::new(Mutex::new(false)),
+ auth_method: acp::AuthMethod::new(Self::AUTH_METHOD_ID, "Test Login"),
+ }
+ }
+ }
+
+ impl AgentConnection for AuthGatedAgentConnection {
+ fn telemetry_id(&self) -> SharedString {
+ "auth-gated".into()
+ }
+
+ fn new_session(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ _cwd: &Path,
+ cx: &mut gpui::App,
+ ) -> Task<gpui::Result<Entity<AcpThread>>> {
+ if !*self.authenticated.lock() {
+ return Task::ready(Err(acp_thread::AuthRequired::new()
+ .with_description("Sign in to continue".to_string())
+ .into()));
+ }
+
+ let session_id = acp::SessionId::new("auth-gated-session");
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ Task::ready(Ok(cx.new(|cx| {
+ AcpThread::new(
+ None,
+ "AuthGatedAgent",
+ self,
+ project,
+ action_log,
+ session_id,
+ watch::Receiver::constant(
+ acp::PromptCapabilities::new()
+ .image(true)
+ .audio(true)
+ .embedded_context(true),
+ ),
+ cx,
+ )
+ })))
+ }
+
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ std::slice::from_ref(&self.auth_method)
+ }
+
+ fn authenticate(
+ &self,
+ method_id: acp::AuthMethodId,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<()>> {
+ if method_id == self.auth_method.id {
+ *self.authenticated.lock() = true;
+ Task::ready(Ok(()))
+ } else {
+ Task::ready(Err(anyhow::anyhow!("Unknown auth method")))
+ }
+ }
+
+ fn prompt(
+ &self,
+ _id: Option<acp_thread::UserMessageId>,
+ _params: acp::PromptRequest,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<acp::PromptResponse>> {
+ unimplemented!()
+ }
+
+ fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
+ unimplemented!()
+ }
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
+ }
+
#[derive(Clone)]
struct SaboteurAgentConnection;
@@ -3749,7 +3949,13 @@ pub(crate) mod tests {
thread_view: &Entity<AcpServerView>,
cx: &TestAppContext,
) -> Entity<AcpThreadView> {
- cx.read(|cx| thread_view.read(cx).as_connected().unwrap().current.clone())
+ cx.read(|cx| {
+ thread_view
+ .read(cx)
+ .active_thread()
+ .expect("No active thread")
+ .clone()
+ })
}
fn message_editor(