test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 10use std::{fmt, rc::Rc, sync::Arc};
 11
 12pub struct FakeServer {
 13    peer: Arc<Peer>,
 14    state: Arc<Mutex<FakeServerState>>,
 15    user_id: u64,
 16    executor: Rc<executor::Foreground>,
 17}
 18
 19#[derive(Default)]
 20struct FakeServerState {
 21    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 22    connection_id: Option<ConnectionId>,
 23    forbid_connections: bool,
 24    auth_count: usize,
 25    access_token: usize,
 26}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.foreground(),
 39        };
 40
 41        client
 42            .override_authenticate({
 43                let state = Arc::downgrade(&server.state);
 44                move |cx| {
 45                    let state = state.clone();
 46                    cx.spawn(move |_| async move {
 47                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 48                        let mut state = state.lock();
 49                        state.auth_count += 1;
 50                        let access_token = state.access_token.to_string();
 51                        Ok(Credentials {
 52                            user_id: client_user_id,
 53                            access_token,
 54                        })
 55                    })
 56                }
 57            })
 58            .override_establish_connection({
 59                let peer = Arc::downgrade(&server.peer).clone();
 60                let state = Arc::downgrade(&server.state);
 61                move |credentials, cx| {
 62                    let peer = peer.clone();
 63                    let state = state.clone();
 64                    let credentials = credentials.clone();
 65                    cx.spawn(move |cx| async move {
 66                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 67                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 68                        if state.lock().forbid_connections {
 69                            Err(EstablishConnectionError::Other(anyhow!(
 70                                "server is forbidding connections"
 71                            )))?
 72                        }
 73
 74                        assert_eq!(credentials.user_id, client_user_id);
 75
 76                        if credentials.access_token != state.lock().access_token.to_string() {
 77                            Err(EstablishConnectionError::Unauthorized)?
 78                        }
 79
 80                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 81                        let (connection_id, io, incoming) =
 82                            peer.add_test_connection(server_conn, cx.background()).await;
 83                        cx.background().spawn(io).detach();
 84                        let mut state = state.lock();
 85                        state.connection_id = Some(connection_id);
 86                        state.incoming = Some(incoming);
 87                        Ok(client_conn)
 88                    })
 89                }
 90            });
 91
 92        client
 93            .authenticate_and_connect(false, &cx.to_async())
 94            .await
 95            .unwrap();
 96        server
 97    }
 98
 99    pub fn disconnect(&self) {
100        self.peer.disconnect(self.connection_id());
101        let mut state = self.state.lock();
102        state.connection_id.take();
103        state.incoming.take();
104    }
105
106    pub fn auth_count(&self) -> usize {
107        self.state.lock().auth_count
108    }
109
110    pub fn roll_access_token(&self) {
111        self.state.lock().access_token += 1;
112    }
113
114    pub fn forbid_connections(&self) {
115        self.state.lock().forbid_connections = true;
116    }
117
118    pub fn allow_connections(&self) {
119        self.state.lock().forbid_connections = false;
120    }
121
122    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
123        self.peer.send(self.connection_id(), message).unwrap();
124    }
125
126    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
127        self.executor.start_waiting();
128        let message = self
129            .state
130            .lock()
131            .incoming
132            .as_mut()
133            .expect("not connected")
134            .next()
135            .await
136            .ok_or_else(|| anyhow!("other half hung up"))?;
137        self.executor.finish_waiting();
138        let type_name = message.payload_type_name();
139        Ok(*message
140            .into_any()
141            .downcast::<TypedEnvelope<M>>()
142            .unwrap_or_else(|_| {
143                panic!(
144                    "fake server received unexpected message type: {:?}",
145                    type_name
146                );
147            }))
148    }
149
150    pub async fn respond<T: proto::RequestMessage>(
151        &self,
152        receipt: Receipt<T>,
153        response: T::Response,
154    ) {
155        self.peer.respond(receipt, response).unwrap()
156    }
157
158    fn connection_id(&self) -> ConnectionId {
159        self.state.lock().connection_id.expect("not connected")
160    }
161
162    pub async fn build_user_store(
163        &self,
164        client: Arc<Client>,
165        cx: &mut TestAppContext,
166    ) -> ModelHandle<UserStore> {
167        let http_client = FakeHttpClient::with_404_response();
168        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
169        assert_eq!(
170            self.receive::<proto::GetUsers>()
171                .await
172                .unwrap()
173                .payload
174                .user_ids,
175            &[self.user_id]
176        );
177        user_store
178    }
179}
180
181impl Drop for FakeServer {
182    fn drop(&mut self) {
183        self.disconnect();
184    }
185}
186
187pub struct FakeHttpClient {
188    handler: Box<
189        dyn 'static
190            + Send
191            + Sync
192            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
193    >,
194}
195
196impl FakeHttpClient {
197    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
198    where
199        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
200        F: 'static + Send + Sync + Fn(Request) -> Fut,
201    {
202        Arc::new(Self {
203            handler: Box::new(move |req| Box::pin(handler(req))),
204        })
205    }
206
207    pub fn with_404_response() -> Arc<dyn HttpClient> {
208        Self::new(|_| async move {
209            Ok(isahc::Response::builder()
210                .status(404)
211                .body(Default::default())
212                .unwrap())
213        })
214    }
215}
216
217impl fmt::Debug for FakeHttpClient {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        f.debug_struct("FakeHttpClient").finish()
220    }
221}
222
223impl HttpClient for FakeHttpClient {
224    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response, crate::http::Error>> {
225        let future = (self.handler)(req);
226        Box::pin(async move { future.await.map(Into::into) })
227    }
228}