Fix chat channel unit test

Max Brunsfeld created

Also, improve error in tests when FakeServer never receives a request,
using the new `start_waiting` method on the DeterministicExecutor.

Change summary

crates/client/src/channel.rs |   4 +
crates/client/src/test.rs    | 126 +++++++++++++++++++-------------------
2 files changed, 67 insertions(+), 63 deletions(-)

Detailed changes

crates/client/src/channel.rs 🔗

@@ -598,10 +598,14 @@ mod tests {
 
     #[gpui::test]
     async fn test_channel_messages(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
         let user_id = 5;
         let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
         let mut client = Client::new(http_client.clone());
         let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        Channel::init(&client);
         let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx));
 
         let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));

crates/client/src/test.rs 🔗

@@ -1,25 +1,28 @@
-use super::Client;
-use super::*;
-use crate::http::{HttpClient, Request, Response, ServerResponse};
+use crate::{
+    http::{HttpClient, Request, Response, ServerResponse},
+    Client, Connection, Credentials, EstablishConnectionError, UserStore,
+};
+use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
-use gpui::{ModelHandle, TestAppContext};
+use gpui::{executor, ModelHandle, TestAppContext};
 use parking_lot::Mutex;
 use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
-use std::fmt;
-use std::sync::atomic::Ordering::SeqCst;
-use std::sync::{
-    atomic::{AtomicBool, AtomicUsize},
-    Arc,
-};
+use std::{fmt, rc::Rc, sync::Arc};
 
 pub struct FakeServer {
     peer: Arc<Peer>,
-    incoming: Mutex<Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>>,
-    connection_id: Mutex<Option<ConnectionId>>,
-    forbid_connections: AtomicBool,
-    auth_count: AtomicUsize,
-    access_token: AtomicUsize,
+    state: Arc<Mutex<FakeServerState>>,
     user_id: u64,
+    executor: Rc<executor::Foreground>,
+}
+
+#[derive(Default)]
+struct FakeServerState {
+    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
+    connection_id: Option<ConnectionId>,
+    forbid_connections: bool,
+    auth_count: usize,
+    access_token: usize,
 }
 
 impl FakeServer {
@@ -27,24 +30,22 @@ impl FakeServer {
         client_user_id: u64,
         client: &mut Arc<Client>,
         cx: &TestAppContext,
-    ) -> Arc<Self> {
-        let server = Arc::new(Self {
+    ) -> Self {
+        let server = Self {
             peer: Peer::new(),
-            incoming: Default::default(),
-            connection_id: Default::default(),
-            forbid_connections: Default::default(),
-            auth_count: Default::default(),
-            access_token: Default::default(),
+            state: Default::default(),
             user_id: client_user_id,
-        });
+            executor: cx.foreground(),
+        };
 
         Arc::get_mut(client)
             .unwrap()
             .override_authenticate({
-                let server = server.clone();
+                let state = server.state.clone();
                 move |cx| {
-                    server.auth_count.fetch_add(1, SeqCst);
-                    let access_token = server.access_token.load(SeqCst).to_string();
+                    let mut state = state.lock();
+                    state.auth_count += 1;
+                    let access_token = state.access_token.to_string();
                     cx.spawn(move |_| async move {
                         Ok(Credentials {
                             user_id: client_user_id,
@@ -54,12 +55,32 @@ impl FakeServer {
                 }
             })
             .override_establish_connection({
-                let server = server.clone();
+                let peer = server.peer.clone();
+                let state = server.state.clone();
                 move |credentials, cx| {
+                    let peer = peer.clone();
+                    let state = state.clone();
                     let credentials = credentials.clone();
-                    cx.spawn({
-                        let server = server.clone();
-                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
+                    cx.spawn(move |cx| async move {
+                        assert_eq!(credentials.user_id, client_user_id);
+
+                        if state.lock().forbid_connections {
+                            Err(EstablishConnectionError::Other(anyhow!(
+                                "server is forbidding connections"
+                            )))?
+                        }
+
+                        if credentials.access_token != state.lock().access_token.to_string() {
+                            Err(EstablishConnectionError::Unauthorized)?
+                        }
+
+                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
+                        let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+                        cx.background().spawn(io).detach();
+                        let mut state = state.lock();
+                        state.connection_id = Some(connection_id);
+                        state.incoming = Some(incoming);
+                        Ok(client_conn)
                     })
                 }
             });
@@ -73,49 +94,25 @@ impl FakeServer {
 
     pub fn disconnect(&self) {
         self.peer.disconnect(self.connection_id());
-        self.connection_id.lock().take();
-        self.incoming.lock().take();
-    }
-
-    async fn establish_connection(
-        &self,
-        credentials: &Credentials,
-        cx: &AsyncAppContext,
-    ) -> Result<Connection, EstablishConnectionError> {
-        assert_eq!(credentials.user_id, self.user_id);
-
-        if self.forbid_connections.load(SeqCst) {
-            Err(EstablishConnectionError::Other(anyhow!(
-                "server is forbidding connections"
-            )))?
-        }
-
-        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
-            Err(EstablishConnectionError::Unauthorized)?
-        }
-
-        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
-        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
-        cx.background().spawn(io).detach();
-        *self.incoming.lock() = Some(incoming);
-        *self.connection_id.lock() = Some(connection_id);
-        Ok(client_conn)
+        let mut state = self.state.lock();
+        state.connection_id.take();
+        state.incoming.take();
     }
 
     pub fn auth_count(&self) -> usize {
-        self.auth_count.load(SeqCst)
+        self.state.lock().auth_count
     }
 
     pub fn roll_access_token(&self) {
-        self.access_token.fetch_add(1, SeqCst);
+        self.state.lock().access_token += 1;
     }
 
     pub fn forbid_connections(&self) {
-        self.forbid_connections.store(true, SeqCst);
+        self.state.lock().forbid_connections = true;
     }
 
     pub fn allow_connections(&self) {
-        self.forbid_connections.store(false, SeqCst);
+        self.state.lock().forbid_connections = false;
     }
 
     pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
@@ -123,14 +120,17 @@ impl FakeServer {
     }
 
     pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+        self.executor.start_waiting();
         let message = self
-            .incoming
+            .state
             .lock()
+            .incoming
             .as_mut()
             .expect("not connected")
             .next()
             .await
             .ok_or_else(|| anyhow!("other half hung up"))?;
+        self.executor.finish_waiting();
         let type_name = message.payload_type_name();
         Ok(*message
             .into_any()
@@ -152,7 +152,7 @@ impl FakeServer {
     }
 
     fn connection_id(&self) -> ConnectionId {
-        self.connection_id.lock().expect("not connected")
+        self.state.lock().connection_id.expect("not connected")
     }
 
     pub async fn build_user_store(