From 44a457e8b6e01fdeca6fbade1bdf2119d657214d Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 14 Sep 2021 18:21:46 -0600 Subject: [PATCH] Cache credentials in memory separately from connection status This prevents us from re-prompting for keychain access when we retry connections after the connection is lost. Co-Authored-By: Max Brunsfeld --- server/src/rpc.rs | 67 ++++++++++----------- zed/src/rpc.rs | 146 ++++++++++++++++++++++++++-------------------- zed/src/test.rs | 39 +++++++------ zed/src/user.rs | 4 +- 4 files changed, 139 insertions(+), 117 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 912f57fcf714e8a892c27accf73ee14ee5c76ba3..86f369fb8a12296569ab8100baa67ce33f9e45d1 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1023,7 +1023,7 @@ mod tests { editor::{Editor, Insert}, fs::{FakeFs, Fs as _}, language::LanguageRegistry, - rpc::{self, Client}, + rpc::{self, Client, Credentials}, settings, test::FakeHttpClient, user::UserStore, @@ -1922,39 +1922,40 @@ mod tests { let forbid_connections = self.forbid_connections.clone(); Arc::get_mut(&mut client) .unwrap() - .set_login_and_connect_callbacks( - move |cx| { - cx.spawn(|_| async move { - let access_token = "the-token".to_string(); - Ok((client_user_id.0 as u64, access_token)) + .override_authenticate(move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok(Credentials { + user_id: client_user_id.0 as u64, + access_token, }) - }, - move |user_id, access_token, cx| { - assert_eq!(user_id, client_user_id.0 as u64); - assert_eq!(access_token, "the-token"); - - let server = server.clone(); - let connection_killers = connection_killers.clone(); - let forbid_connections = forbid_connections.clone(); - let client_name = client_name.clone(); - cx.spawn(move |cx| async move { - if forbid_connections.load(SeqCst) { - Err(anyhow!("server is forbidding connections")) - } else { - let (client_conn, server_conn, kill_conn) = Conn::in_memory(); - connection_killers.lock().insert(client_user_id, kill_conn); - cx.background() - .spawn(server.handle_connection( - server_conn, - client_name, - client_user_id, - )) - .detach(); - Ok(client_conn) - } - }) - }, - ); + }) + }) + .override_establish_connection(move |credentials, cx| { + assert_eq!(credentials.user_id, client_user_id.0 as u64); + assert_eq!(credentials.access_token, "the-token"); + + let server = server.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(anyhow!("server is forbidding connections")) + } else { + let (client_conn, server_conn, kill_conn) = Conn::in_memory(); + connection_killers.lock().insert(client_user_id, kill_conn); + cx.background() + .spawn(server.handle_connection( + server_conn, + client_name, + client_user_id, + )) + .detach(); + Ok(client_conn) + } + }) + }); client .authenticate_and_connect(&cx.to_async()) diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 69bc33a62e1c4065f7c946b239e72d7039ef5c49..bc6b41dd5f523e40e239f78465ed246b38c77bfd 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -29,11 +29,10 @@ lazy_static! { pub struct Client { peer: Arc, state: RwLock, - auth_callback: Option< - Box Task>>, - >, - connect_callback: Option< - Box Task>>, + authenticate: + Option Task>>>, + establish_connection: Option< + Box Task>>, >, } @@ -41,25 +40,17 @@ pub struct Client { pub enum Status { SignedOut, Authenticating, - Connecting { - user_id: u64, - }, + Connecting, ConnectionError, - Connected { - connection_id: ConnectionId, - user_id: u64, - }, + Connected { connection_id: ConnectionId }, ConnectionLost, Reauthenticating, - Reconnecting { - user_id: u64, - }, - ReconnectionError { - next_reconnection: Instant, - }, + Reconnecting, + ReconnectionError { next_reconnection: Instant }, } struct ClientState { + credentials: Option, status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>>, model_handlers: HashMap< @@ -70,9 +61,16 @@ struct ClientState { heartbeat_interval: Duration, } +#[derive(Clone)] +pub struct Credentials { + pub user_id: u64, + pub access_token: String, +} + impl Default for ClientState { fn default() -> Self { Self { + credentials: None, status: watch::channel_with(Status::SignedOut), entity_id_extractors: Default::default(), model_handlers: Default::default(), @@ -107,22 +105,35 @@ impl Client { Arc::new(Self { peer: Peer::new(), state: Default::default(), - auth_callback: None, - connect_callback: None, + authenticate: None, + establish_connection: None, }) } #[cfg(any(test, feature = "test-support"))] - pub fn set_login_and_connect_callbacks( - &mut self, - login: Login, - connect: Connect, - ) where - Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task>, - Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task>, + pub fn override_authenticate(&mut self, authenticate: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task>, { - self.auth_callback = Some(Box::new(login)); - self.connect_callback = Some(Box::new(connect)); + self.authenticate = Some(Box::new(authenticate)); + self + } + + #[cfg(any(test, feature = "test-support"))] + pub fn override_establish_connection(&mut self, connect: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task>, + { + self.establish_connection = Some(Box::new(connect)); + self + } + + pub fn user_id(&self) -> Option { + self.state + .read() + .credentials + .as_ref() + .map(|credentials| credentials.user_id) } pub fn status(&self) -> watch::Receiver { @@ -249,23 +260,31 @@ impl Client { self.set_status(Status::Reauthenticating, cx) } - let (user_id, access_token) = match self.authenticate(&cx).await { - Ok(result) => result, - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return Err(err); - } + let credentials = self.state.read().credentials.clone(); + let credentials = if let Some(credentials) = credentials { + credentials + } else { + let credentials = match self.authenticate(&cx).await { + Ok(credentials) => credentials, + Err(err) => { + self.set_status(Status::ConnectionError, cx); + return Err(err); + } + }; + self.state.write().credentials = Some(credentials.clone()); + credentials }; if was_disconnected { - self.set_status(Status::Connecting { user_id }, cx); + self.set_status(Status::Connecting, cx); } else { - self.set_status(Status::Reconnecting { user_id }, cx); + self.set_status(Status::Reconnecting, cx); } - match self.connect(user_id, &access_token, cx).await { + + match self.establish_connection(&credentials, cx).await { Ok(conn) => { log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.set_connection(user_id, conn, cx).await; + self.set_connection(conn, cx).await; Ok(()) } Err(err) => { @@ -275,7 +294,7 @@ impl Client { } } - async fn set_connection(self: &Arc, user_id: u64, conn: Conn, cx: &AsyncAppContext) { + async fn set_connection(self: &Arc, conn: Conn, cx: &AsyncAppContext) { let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; cx.foreground() .spawn({ @@ -310,13 +329,7 @@ impl Client { }) .detach(); - self.set_status( - Status::Connected { - connection_id, - user_id, - }, - cx, - ); + self.set_status(Status::Connected { connection_id }, cx); let handle_io = cx.background().spawn(handle_io); let this = self.clone(); @@ -334,35 +347,35 @@ impl Client { .detach(); } - fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { - if let Some(callback) = self.auth_callback.as_ref() { + fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + if let Some(callback) = self.authenticate.as_ref() { callback(cx) } else { self.authenticate_with_browser(cx) } } - fn connect( + fn establish_connection( self: &Arc, - user_id: u64, - access_token: &str, + credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - if let Some(callback) = self.connect_callback.as_ref() { - callback(user_id, access_token, cx) + if let Some(callback) = self.establish_connection.as_ref() { + callback(credentials, cx) } else { - self.connect_with_websocket(user_id, access_token, cx) + self.establish_websocket_connection(credentials, cx) } } - fn connect_with_websocket( + fn establish_websocket_connection( self: &Arc, - user_id: u64, - access_token: &str, + credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - let request = - Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); + let request = Request::builder().header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ); cx.background().spawn(async move { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { let stream = smol::net::TcpStream::connect(host).await?; @@ -387,7 +400,7 @@ impl Client { pub fn authenticate_with_browser( self: &Arc, cx: &AsyncAppContext, - ) -> Task> { + ) -> Task> { let platform = cx.platform(); let executor = cx.background(); executor.clone().spawn(async move { @@ -397,7 +410,10 @@ impl Client { .flatten() { log::info!("already signed in. user_id: {}", user_id); - return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap())); + return Ok(Credentials { + user_id: user_id.parse()?, + access_token: String::from_utf8(access_token).unwrap(), + }); } // Generate a pair of asymmetric encryption keys. The public key will be used by the @@ -463,7 +479,11 @@ impl Client { platform .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes()) .log_err(); - Ok((user_id.parse()?, access_token)) + + Ok(Credentials { + user_id: user_id.parse()?, + access_token, + }) }) } diff --git a/zed/src/test.rs b/zed/src/test.rs index e8527a4ed762a057fbcd72355fa4783024e57f40..b9948cc460f66d7891b3321dc6bb09d34b207c5a 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -4,7 +4,7 @@ use crate::{ fs::RealFs, http::{HttpClient, Request, Response, ServerResponse}, language::LanguageRegistry, - rpc::{self, Client}, + rpc::{self, Client, Credentials}, settings::{self, ThemeRegistry}, time::ReplicaId, user::UserStore, @@ -226,25 +226,26 @@ impl FakeServer { Arc::get_mut(client) .unwrap() - .set_login_and_connect_callbacks( - move |cx| { - cx.spawn(|_| async move { - let access_token = "the-token".to_string(); - Ok((client_user_id, access_token)) + .override_authenticate(move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok(Credentials { + user_id: client_user_id, + access_token, }) - }, - { - let server = result.clone(); - move |user_id, access_token, cx| { - assert_eq!(user_id, client_user_id); - assert_eq!(access_token, "the-token"); - cx.spawn({ - let server = server.clone(); - move |cx| async move { server.connect(&cx).await } - }) - } - }, - ); + }) + }) + .override_establish_connection({ + let server = result.clone(); + move |credentials, cx| { + assert_eq!(credentials.user_id, client_user_id); + assert_eq!(credentials.access_token, "the-token"); + cx.spawn({ + let server = server.clone(); + move |cx| async move { server.connect(&cx).await } + }) + } + }); client .authenticate_and_connect(&cx.to_async()) diff --git a/zed/src/user.rs b/zed/src/user.rs index 3a119474e1e7397b5c5352ffb13cddd27f98ae9c..06aab321934dd1ffb985430157b7cfd02385dbc7 100644 --- a/zed/src/user.rs +++ b/zed/src/user.rs @@ -51,8 +51,8 @@ impl UserStore { let mut status = rpc.status(); while let Some(status) = status.recv().await { match status { - Status::Connected { user_id, .. } => { - if let Some(this) = this.upgrade() { + Status::Connected { .. } => { + if let Some((this, user_id)) = this.upgrade().zip(rpc.user_id()) { current_user_tx .send(this.fetch_user(user_id).log_err().await) .await