// test.rs

use std::sync::atomic::Ordering::SeqCst;
use std::sync::{
    atomic::{AtomicBool, AtomicUsize},
    Arc,
};

use gpui::TestAppContext;
use parking_lot::Mutex;
use postage::{mpsc, prelude::Stream};
use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};

use super::*;
use super::Client;

 14pub struct FakeServer {
 15    peer: Arc<Peer>,
 16    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
 17    connection_id: Mutex<Option<ConnectionId>>,
 18    forbid_connections: AtomicBool,
 19    auth_count: AtomicUsize,
 20    access_token: AtomicUsize,
 21    user_id: u64,
 22}
 23
 24impl FakeServer {
 25    pub async fn for_client(
 26        client_user_id: u64,
 27        client: &mut Arc<Client>,
 28        cx: &TestAppContext,
 29    ) -> Arc<Self> {
 30        let server = Arc::new(Self {
 31            peer: Peer::new(),
 32            incoming: Default::default(),
 33            connection_id: Default::default(),
 34            forbid_connections: Default::default(),
 35            auth_count: Default::default(),
 36            access_token: Default::default(),
 37            user_id: client_user_id,
 38        });
 39
 40        Arc::get_mut(client)
 41            .unwrap()
 42            .override_authenticate({
 43                let server = server.clone();
 44                move |cx| {
 45                    server.auth_count.fetch_add(1, SeqCst);
 46                    let access_token = server.access_token.load(SeqCst).to_string();
 47                    cx.spawn(move |_| async move {
 48                        Ok(Credentials {
 49                            user_id: client_user_id,
 50                            access_token,
 51                        })
 52                    })
 53                }
 54            })
 55            .override_establish_connection({
 56                let server = server.clone();
 57                move |credentials, cx| {
 58                    let credentials = credentials.clone();
 59                    cx.spawn({
 60                        let server = server.clone();
 61                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
 62                    })
 63                }
 64            });
 65
 66        client
 67            .authenticate_and_connect(&cx.to_async())
 68            .await
 69            .unwrap();
 70        server
 71    }
 72
 73    pub async fn disconnect(&self) {
 74        self.peer.disconnect(self.connection_id()).await;
 75        self.connection_id.lock().take();
 76        self.incoming.lock().take();
 77    }
 78
 79    async fn establish_connection(
 80        &self,
 81        credentials: &Credentials,
 82        cx: &AsyncAppContext,
 83    ) -> Result<Connection, EstablishConnectionError> {
 84        assert_eq!(credentials.user_id, self.user_id);
 85
 86        if self.forbid_connections.load(SeqCst) {
 87            Err(EstablishConnectionError::Other(anyhow!(
 88                "server is forbidding connections"
 89            )))?
 90        }
 91
 92        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
 93            Err(EstablishConnectionError::Unauthorized)?
 94        }
 95
 96        let (client_conn, server_conn, _) = Connection::in_memory();
 97        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
 98        cx.background().spawn(io).detach();
 99        *self.incoming.lock() = Some(incoming);
100        *self.connection_id.lock() = Some(connection_id);
101        Ok(client_conn)
102    }
103
104    pub fn auth_count(&self) -> usize {
105        self.auth_count.load(SeqCst)
106    }
107
108    pub fn roll_access_token(&self) {
109        self.access_token.fetch_add(1, SeqCst);
110    }
111
112    pub fn forbid_connections(&self) {
113        self.forbid_connections.store(true, SeqCst);
114    }
115
116    pub fn allow_connections(&self) {
117        self.forbid_connections.store(false, SeqCst);
118    }
119
120    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
121        self.peer.send(self.connection_id(), message).await.unwrap();
122    }
123
124    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
125        let message = self
126            .incoming
127            .lock()
128            .as_mut()
129            .expect("not connected")
130            .recv()
131            .await
132            .ok_or_else(|| anyhow!("other half hung up"))?;
133        let type_name = message.payload_type_name();
134        Ok(*message
135            .into_any()
136            .downcast::<TypedEnvelope<M>>()
137            .unwrap_or_else(|_| {
138                panic!(
139                    "fake server received unexpected message type: {:?}",
140                    type_name
141                );
142            }))
143    }
144
145    pub async fn respond<T: proto::RequestMessage>(
146        &self,
147        receipt: Receipt<T>,
148        response: T::Response,
149    ) {
150        self.peer.respond(receipt, response).await.unwrap()
151    }
152
153    fn connection_id(&self) -> ConnectionId {
154        self.connection_id.lock().expect("not connected")
155    }
156}