Introduce test-only APIs for configuring how Client reconnects

Max Brunsfeld created

Change summary

zed/src/channel.rs |  4 +-
zed/src/rpc.rs     | 62 +++++++++++++++++++++++++++++++++++++++++++----
zed/src/test.rs    | 30 +++++++++++++++++++++-
3 files changed, 86 insertions(+), 10 deletions(-)

Detailed changes

zed/src/channel.rs 🔗

@@ -450,8 +450,8 @@ mod tests {
     #[gpui::test]
     async fn test_channel_messages(mut cx: TestAppContext) {
         let user_id = 5;
-        let client = Client::new();
-        let server = FakeServer::for_client(user_id, &client, &cx).await;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
         let user_store = Arc::new(UserStore::new(client.clone()));
 
         let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));

zed/src/rpc.rs 🔗

@@ -28,12 +28,21 @@ lazy_static! {
 pub struct Client {
     peer: Arc<Peer>,
     state: RwLock<ClientState>,
+    auth_callback: Option<
+        Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
+    >,
+    connect_callback: Option<
+        Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+    >,
 }
 
 #[derive(Copy, Clone, Debug)]
 pub enum Status {
     Disconnected,
-    Connecting,
+    Authenticating,
+    Connecting {
+        user_id: u64,
+    },
     ConnectionError,
     Connected {
         connection_id: ConnectionId,
@@ -94,9 +103,24 @@ impl Client {
         Arc::new(Self {
             peer: Peer::new(),
             state: Default::default(),
+            auth_callback: None,
+            connect_callback: None,
         })
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn set_login_and_connect_callbacks<Login, Connect>(
+        &mut self,
+        login: Login,
+        connect: Connect,
+    ) where
+        Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
+        Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+    {
+        self.auth_callback = Some(Box::new(login));
+        self.connect_callback = Some(Box::new(connect));
+    }
+
     pub fn status(&self) -> watch::Receiver<Status> {
         self.state.read().status.1.clone()
     }
@@ -192,11 +216,13 @@ impl Client {
     ) -> anyhow::Result<()> {
         if matches!(
             *self.status().borrow(),
-            Status::Connecting { .. } | Status::Connected { .. }
+            Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
         ) {
             return Ok(());
         }
 
+        self.set_status(Status::Authenticating, cx);
+
         let (user_id, access_token) = match self.authenticate(&cx).await {
             Ok(result) => result,
             Err(err) => {
@@ -205,7 +231,7 @@ impl Client {
             }
         };
 
-        self.set_status(Status::Connecting, cx);
+        self.set_status(Status::Connecting { user_id }, cx);
 
         let conn = match self.connect(user_id, &access_token, cx).await {
             Ok(conn) => conn,
@@ -285,11 +311,32 @@ impl Client {
         Ok(())
     }
 
+    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+        if let Some(callback) = self.auth_callback.as_ref() {
+            callback(cx)
+        } else {
+            self.authenticate_with_browser(cx)
+        }
+    }
+
     fn connect(
         self: &Arc<Self>,
         user_id: u64,
         access_token: &str,
         cx: &AsyncAppContext,
+    ) -> Task<Result<Conn>> {
+        if let Some(callback) = self.connect_callback.as_ref() {
+            callback(user_id, access_token, cx)
+        } else {
+            self.connect_with_websocket(user_id, access_token, cx)
+        }
+    }
+
+    fn connect_with_websocket(
+        self: &Arc<Self>,
+        user_id: u64,
+        access_token: &str,
+        cx: &AsyncAppContext,
     ) -> Task<Result<Conn>> {
         let request =
             Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
@@ -314,7 +361,10 @@ impl Client {
         })
     }
 
-    pub fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+    pub fn authenticate_with_browser(
+        self: &Arc<Self>,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<(u64, String)>> {
         let platform = cx.platform();
         let executor = cx.background();
         executor.clone().spawn(async move {
@@ -488,8 +538,8 @@ mod tests {
     #[gpui::test(iterations = 10)]
     async fn test_heartbeat(cx: TestAppContext) {
         let user_id = 5;
-        let client = Client::new();
-        let server = FakeServer::for_client(user_id, &client, &cx).await;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
 
         cx.foreground().advance_clock(Duration::from_secs(10));
         let ping = server.receive::<proto::Ping>().await.unwrap();

zed/src/test.rs 🔗

@@ -203,16 +203,42 @@ pub struct FakeServer {
 }
 
 impl FakeServer {
-    pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Arc<Self> {
+    pub async fn for_client(
+        client_user_id: u64,
+        client: &mut Arc<Client>,
+        cx: &TestAppContext,
+    ) -> Arc<Self> {
         let result = Arc::new(Self {
             peer: Peer::new(),
             incoming: Default::default(),
             connection_id: Default::default(),
         });
 
+        Arc::get_mut(client)
+            .unwrap()
+            .set_login_and_connect_callbacks(
+                move |cx| {
+                    cx.spawn(|_| async move {
+                        let access_token = "the-token".to_string();
+                        Ok((client_user_id, access_token))
+                    })
+                },
+                {
+                    let server = result.clone();
+                    move |user_id, access_token, cx| {
+                        assert_eq!(user_id, client_user_id);
+                        assert_eq!(access_token, "the-token");
+                        cx.spawn({
+                            let server = server.clone();
+                            move |cx| async move { Ok(server.connect(&cx).await) }
+                        })
+                    }
+                },
+            );
+
         let conn = result.connect(&cx.to_async()).await;
         client
-            .set_connection(user_id, conn, &cx.to_async())
+            .set_connection(client_user_id, conn, &cx.to_async())
             .await
             .unwrap();
         result