Cache credentials in memory separately from connection status

Nathan Sobo and Max Brunsfeld created

This prevents us from re-prompting for keychain access when we retry connections after the connection is lost.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

server/src/rpc.rs |  67 +++++++++++-----------
zed/src/rpc.rs    | 146 +++++++++++++++++++++++++++---------------------
zed/src/test.rs   |  39 ++++++------
zed/src/user.rs   |   4 
4 files changed, 139 insertions(+), 117 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -1023,7 +1023,7 @@ mod tests {
         editor::{Editor, Insert},
         fs::{FakeFs, Fs as _},
         language::LanguageRegistry,
-        rpc::{self, Client},
+        rpc::{self, Client, Credentials},
         settings,
         test::FakeHttpClient,
         user::UserStore,
@@ -1922,39 +1922,40 @@ mod tests {
             let forbid_connections = self.forbid_connections.clone();
             Arc::get_mut(&mut client)
                 .unwrap()
-                .set_login_and_connect_callbacks(
-                    move |cx| {
-                        cx.spawn(|_| async move {
-                            let access_token = "the-token".to_string();
-                            Ok((client_user_id.0 as u64, access_token))
+                .override_authenticate(move |cx| {
+                    cx.spawn(|_| async move {
+                        let access_token = "the-token".to_string();
+                        Ok(Credentials {
+                            user_id: client_user_id.0 as u64,
+                            access_token,
                         })
-                    },
-                    move |user_id, access_token, cx| {
-                        assert_eq!(user_id, client_user_id.0 as u64);
-                        assert_eq!(access_token, "the-token");
-
-                        let server = server.clone();
-                        let connection_killers = connection_killers.clone();
-                        let forbid_connections = forbid_connections.clone();
-                        let client_name = client_name.clone();
-                        cx.spawn(move |cx| async move {
-                            if forbid_connections.load(SeqCst) {
-                                Err(anyhow!("server is forbidding connections"))
-                            } else {
-                                let (client_conn, server_conn, kill_conn) = Conn::in_memory();
-                                connection_killers.lock().insert(client_user_id, kill_conn);
-                                cx.background()
-                                    .spawn(server.handle_connection(
-                                        server_conn,
-                                        client_name,
-                                        client_user_id,
-                                    ))
-                                    .detach();
-                                Ok(client_conn)
-                            }
-                        })
-                    },
-                );
+                    })
+                })
+                .override_establish_connection(move |credentials, cx| {
+                    assert_eq!(credentials.user_id, client_user_id.0 as u64);
+                    assert_eq!(credentials.access_token, "the-token");
+
+                    let server = server.clone();
+                    let connection_killers = connection_killers.clone();
+                    let forbid_connections = forbid_connections.clone();
+                    let client_name = client_name.clone();
+                    cx.spawn(move |cx| async move {
+                        if forbid_connections.load(SeqCst) {
+                            Err(anyhow!("server is forbidding connections"))
+                        } else {
+                            let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+                            connection_killers.lock().insert(client_user_id, kill_conn);
+                            cx.background()
+                                .spawn(server.handle_connection(
+                                    server_conn,
+                                    client_name,
+                                    client_user_id,
+                                ))
+                                .detach();
+                            Ok(client_conn)
+                        }
+                    })
+                });
 
             client
                 .authenticate_and_connect(&cx.to_async())

zed/src/rpc.rs 🔗

@@ -29,11 +29,10 @@ lazy_static! {
 pub struct Client {
     peer: Arc<Peer>,
     state: RwLock<ClientState>,
-    auth_callback: Option<
-        Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
-    >,
-    connect_callback: Option<
-        Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+    authenticate:
+        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
+    establish_connection: Option<
+        Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>,
     >,
 }
 
@@ -41,25 +40,17 @@ pub struct Client {
 pub enum Status {
     SignedOut,
     Authenticating,
-    Connecting {
-        user_id: u64,
-    },
+    Connecting,
     ConnectionError,
-    Connected {
-        connection_id: ConnectionId,
-        user_id: u64,
-    },
+    Connected { connection_id: ConnectionId },
     ConnectionLost,
     Reauthenticating,
-    Reconnecting {
-        user_id: u64,
-    },
-    ReconnectionError {
-        next_reconnection: Instant,
-    },
+    Reconnecting,
+    ReconnectionError { next_reconnection: Instant },
 }
 
 struct ClientState {
+    credentials: Option<Credentials>,
     status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
@@ -70,9 +61,16 @@ struct ClientState {
     heartbeat_interval: Duration,
 }
 
+#[derive(Clone)]
+pub struct Credentials {
+    pub user_id: u64,
+    pub access_token: String,
+}
+
 impl Default for ClientState {
     fn default() -> Self {
         Self {
+            credentials: None,
             status: watch::channel_with(Status::SignedOut),
             entity_id_extractors: Default::default(),
             model_handlers: Default::default(),
@@ -107,22 +105,35 @@ impl Client {
         Arc::new(Self {
             peer: Peer::new(),
             state: Default::default(),
-            auth_callback: None,
-            connect_callback: None,
+            authenticate: None,
+            establish_connection: None,
         })
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn set_login_and_connect_callbacks<Login, Connect>(
-        &mut self,
-        login: Login,
-        connect: Connect,
-    ) where
-        Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
-        Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
+    where
+        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
     {
-        self.auth_callback = Some(Box::new(login));
-        self.connect_callback = Some(Box::new(connect));
+        self.authenticate = Some(Box::new(authenticate));
+        self
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
+    where
+        F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>,
+    {
+        self.establish_connection = Some(Box::new(connect));
+        self
+    }
+
+    pub fn user_id(&self) -> Option<u64> {
+        self.state
+            .read()
+            .credentials
+            .as_ref()
+            .map(|credentials| credentials.user_id)
     }
 
     pub fn status(&self) -> watch::Receiver<Status> {
@@ -249,23 +260,31 @@ impl Client {
             self.set_status(Status::Reauthenticating, cx)
         }
 
-        let (user_id, access_token) = match self.authenticate(&cx).await {
-            Ok(result) => result,
-            Err(err) => {
-                self.set_status(Status::ConnectionError, cx);
-                return Err(err);
-            }
+        let credentials = self.state.read().credentials.clone();
+        let credentials = if let Some(credentials) = credentials {
+            credentials
+        } else {
+            let credentials = match self.authenticate(&cx).await {
+                Ok(credentials) => credentials,
+                Err(err) => {
+                    self.set_status(Status::ConnectionError, cx);
+                    return Err(err);
+                }
+            };
+            self.state.write().credentials = Some(credentials.clone());
+            credentials
         };
 
         if was_disconnected {
-            self.set_status(Status::Connecting { user_id }, cx);
+            self.set_status(Status::Connecting, cx);
         } else {
-            self.set_status(Status::Reconnecting { user_id }, cx);
+            self.set_status(Status::Reconnecting, cx);
         }
-        match self.connect(user_id, &access_token, cx).await {
+
+        match self.establish_connection(&credentials, cx).await {
             Ok(conn) => {
                 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-                self.set_connection(user_id, conn, cx).await;
+                self.set_connection(conn, cx).await;
                 Ok(())
             }
             Err(err) => {
@@ -275,7 +294,7 @@ impl Client {
         }
     }
 
-    async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
+    async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) {
         let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
         cx.foreground()
             .spawn({
@@ -310,13 +329,7 @@ impl Client {
             })
             .detach();
 
-        self.set_status(
-            Status::Connected {
-                connection_id,
-                user_id,
-            },
-            cx,
-        );
+        self.set_status(Status::Connected { connection_id }, cx);
 
         let handle_io = cx.background().spawn(handle_io);
         let this = self.clone();
@@ -334,35 +347,35 @@ impl Client {
             .detach();
     }
 
-    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
-        if let Some(callback) = self.auth_callback.as_ref() {
+    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
+        if let Some(callback) = self.authenticate.as_ref() {
             callback(cx)
         } else {
             self.authenticate_with_browser(cx)
         }
     }
 
-    fn connect(
+    fn establish_connection(
         self: &Arc<Self>,
-        user_id: u64,
-        access_token: &str,
+        credentials: &Credentials,
         cx: &AsyncAppContext,
     ) -> Task<Result<Conn>> {
-        if let Some(callback) = self.connect_callback.as_ref() {
-            callback(user_id, access_token, cx)
+        if let Some(callback) = self.establish_connection.as_ref() {
+            callback(credentials, cx)
         } else {
-            self.connect_with_websocket(user_id, access_token, cx)
+            self.establish_websocket_connection(credentials, cx)
         }
     }
 
-    fn connect_with_websocket(
+    fn establish_websocket_connection(
         self: &Arc<Self>,
-        user_id: u64,
-        access_token: &str,
+        credentials: &Credentials,
         cx: &AsyncAppContext,
     ) -> Task<Result<Conn>> {
-        let request =
-            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+        let request = Request::builder().header(
+            "Authorization",
+            format!("{} {}", credentials.user_id, credentials.access_token),
+        );
         cx.background().spawn(async move {
             if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
                 let stream = smol::net::TcpStream::connect(host).await?;
@@ -387,7 +400,7 @@ impl Client {
     pub fn authenticate_with_browser(
         self: &Arc<Self>,
         cx: &AsyncAppContext,
-    ) -> Task<Result<(u64, String)>> {
+    ) -> Task<Result<Credentials>> {
         let platform = cx.platform();
         let executor = cx.background();
         executor.clone().spawn(async move {
@@ -397,7 +410,10 @@ impl Client {
                 .flatten()
             {
                 log::info!("already signed in. user_id: {}", user_id);
-                return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
+                return Ok(Credentials {
+                    user_id: user_id.parse()?,
+                    access_token: String::from_utf8(access_token).unwrap(),
+                });
             }
 
             // Generate a pair of asymmetric encryption keys. The public key will be used by the
@@ -463,7 +479,11 @@ impl Client {
             platform
                 .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
                 .log_err();
-            Ok((user_id.parse()?, access_token))
+
+            Ok(Credentials {
+                user_id: user_id.parse()?,
+                access_token,
+            })
         })
     }
 

zed/src/test.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     fs::RealFs,
     http::{HttpClient, Request, Response, ServerResponse},
     language::LanguageRegistry,
-    rpc::{self, Client},
+    rpc::{self, Client, Credentials},
     settings::{self, ThemeRegistry},
     time::ReplicaId,
     user::UserStore,
@@ -226,25 +226,26 @@ impl FakeServer {
 
         Arc::get_mut(client)
             .unwrap()
-            .set_login_and_connect_callbacks(
-                move |cx| {
-                    cx.spawn(|_| async move {
-                        let access_token = "the-token".to_string();
-                        Ok((client_user_id, access_token))
+            .override_authenticate(move |cx| {
+                cx.spawn(|_| async move {
+                    let access_token = "the-token".to_string();
+                    Ok(Credentials {
+                        user_id: client_user_id,
+                        access_token,
                     })
-                },
-                {
-                    let server = result.clone();
-                    move |user_id, access_token, cx| {
-                        assert_eq!(user_id, client_user_id);
-                        assert_eq!(access_token, "the-token");
-                        cx.spawn({
-                            let server = server.clone();
-                            move |cx| async move { server.connect(&cx).await }
-                        })
-                    }
-                },
-            );
+                })
+            })
+            .override_establish_connection({
+                let server = result.clone();
+                move |credentials, cx| {
+                    assert_eq!(credentials.user_id, client_user_id);
+                    assert_eq!(credentials.access_token, "the-token");
+                    cx.spawn({
+                        let server = server.clone();
+                        move |cx| async move { server.connect(&cx).await }
+                    })
+                }
+            });
 
         client
             .authenticate_and_connect(&cx.to_async())

zed/src/user.rs 🔗

@@ -51,8 +51,8 @@ impl UserStore {
                 let mut status = rpc.status();
                 while let Some(status) = status.recv().await {
                     match status {
-                        Status::Connected { user_id, .. } => {
-                            if let Some(this) = this.upgrade() {
+                        Status::Connected { .. } => {
+                            if let Some((this, user_id)) = this.upgrade().zip(rpc.user_id()) {
                                 current_user_tx
                                     .send(this.fetch_user(user_id).log_err().await)
                                     .await