test.rs

  1use super::Client;
  2use super::*;
  3use crate::http::{HttpClient, Request, Response, ServerResponse};
  4use futures::{future::BoxFuture, Future};
  5use gpui::{ModelHandle, TestAppContext};
  6use parking_lot::Mutex;
  7use postage::{mpsc, prelude::Stream};
  8use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
  9use std::fmt;
 10use std::sync::atomic::Ordering::SeqCst;
 11use std::sync::{
 12    atomic::{AtomicBool, AtomicUsize},
 13    Arc,
 14};
 15
 16pub struct FakeServer {
 17    peer: Arc<Peer>,
 18    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
 19    connection_id: Mutex<Option<ConnectionId>>,
 20    forbid_connections: AtomicBool,
 21    auth_count: AtomicUsize,
 22    access_token: AtomicUsize,
 23    user_id: u64,
 24}
 25
 26impl FakeServer {
 27    pub async fn for_client(
 28        client_user_id: u64,
 29        client: &mut Arc<Client>,
 30        cx: &TestAppContext,
 31    ) -> Arc<Self> {
 32        let server = Arc::new(Self {
 33            peer: Peer::new(),
 34            incoming: Default::default(),
 35            connection_id: Default::default(),
 36            forbid_connections: Default::default(),
 37            auth_count: Default::default(),
 38            access_token: Default::default(),
 39            user_id: client_user_id,
 40        });
 41
 42        Arc::get_mut(client)
 43            .unwrap()
 44            .override_authenticate({
 45                let server = server.clone();
 46                move |cx| {
 47                    server.auth_count.fetch_add(1, SeqCst);
 48                    let access_token = server.access_token.load(SeqCst).to_string();
 49                    cx.spawn(move |_| async move {
 50                        Ok(Credentials {
 51                            user_id: client_user_id,
 52                            access_token,
 53                        })
 54                    })
 55                }
 56            })
 57            .override_establish_connection({
 58                let server = server.clone();
 59                move |credentials, cx| {
 60                    let credentials = credentials.clone();
 61                    cx.spawn({
 62                        let server = server.clone();
 63                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
 64                    })
 65                }
 66            });
 67
 68        client
 69            .authenticate_and_connect(&cx.to_async())
 70            .await
 71            .unwrap();
 72        server
 73    }
 74
 75    pub async fn disconnect(&self) {
 76        self.peer.disconnect(self.connection_id()).await;
 77        self.connection_id.lock().take();
 78        self.incoming.lock().take();
 79    }
 80
 81    async fn establish_connection(
 82        &self,
 83        credentials: &Credentials,
 84        cx: &AsyncAppContext,
 85    ) -> Result<Connection, EstablishConnectionError> {
 86        assert_eq!(credentials.user_id, self.user_id);
 87
 88        if self.forbid_connections.load(SeqCst) {
 89            Err(EstablishConnectionError::Other(anyhow!(
 90                "server is forbidding connections"
 91            )))?
 92        }
 93
 94        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
 95            Err(EstablishConnectionError::Unauthorized)?
 96        }
 97
 98        let (client_conn, server_conn, _) = Connection::in_memory();
 99        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
100        cx.background().spawn(io).detach();
101        *self.incoming.lock() = Some(incoming);
102        *self.connection_id.lock() = Some(connection_id);
103        Ok(client_conn)
104    }
105
106    pub fn auth_count(&self) -> usize {
107        self.auth_count.load(SeqCst)
108    }
109
110    pub fn roll_access_token(&self) {
111        self.access_token.fetch_add(1, SeqCst);
112    }
113
114    pub fn forbid_connections(&self) {
115        self.forbid_connections.store(true, SeqCst);
116    }
117
118    pub fn allow_connections(&self) {
119        self.forbid_connections.store(false, SeqCst);
120    }
121
122    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
123        self.peer.send(self.connection_id(), message).await.unwrap();
124    }
125
126    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
127        let message = self
128            .incoming
129            .lock()
130            .as_mut()
131            .expect("not connected")
132            .recv()
133            .await
134            .ok_or_else(|| anyhow!("other half hung up"))?;
135        let type_name = message.payload_type_name();
136        Ok(*message
137            .into_any()
138            .downcast::<TypedEnvelope<M>>()
139            .unwrap_or_else(|_| {
140                panic!(
141                    "fake server received unexpected message type: {:?}",
142                    type_name
143                );
144            }))
145    }
146
147    pub async fn respond<T: proto::RequestMessage>(
148        &self,
149        receipt: Receipt<T>,
150        response: T::Response,
151    ) {
152        self.peer.respond(receipt, response).await.unwrap()
153    }
154
155    fn connection_id(&self) -> ConnectionId {
156        self.connection_id.lock().expect("not connected")
157    }
158
159    pub async fn build_user_store(
160        &self,
161        client: Arc<Client>,
162        cx: &mut TestAppContext,
163    ) -> ModelHandle<UserStore> {
164        let http_client = FakeHttpClient::with_404_response();
165        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
166        assert_eq!(
167            self.receive::<proto::GetUsers>()
168                .await
169                .unwrap()
170                .payload
171                .user_ids,
172            &[self.user_id]
173        );
174        user_store
175    }
176}
177
178pub struct FakeHttpClient {
179    handler:
180        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
181}
182
183impl FakeHttpClient {
184    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
185    where
186        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
187        F: 'static + Send + Sync + Fn(Request) -> Fut,
188    {
189        Arc::new(Self {
190            handler: Box::new(move |req| Box::pin(handler(req))),
191        })
192    }
193
194    pub fn with_404_response() -> Arc<dyn HttpClient> {
195        Self::new(|_| async move { Ok(ServerResponse::new(404)) })
196    }
197}
198
199impl fmt::Debug for FakeHttpClient {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        f.debug_struct("FakeHttpClient").finish()
202    }
203}
204
205impl HttpClient for FakeHttpClient {
206    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
207        let future = (self.handler)(req);
208        Box::pin(async move { future.await.map(Into::into) })
209    }
210}