test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 10use std::{fmt, rc::Rc, sync::Arc};
 11
 12pub struct FakeServer {
 13    peer: Arc<Peer>,
 14    state: Arc<Mutex<FakeServerState>>,
 15    user_id: u64,
 16    executor: Rc<executor::Foreground>,
 17}
 18
 19#[derive(Default)]
 20struct FakeServerState {
 21    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 22    connection_id: Option<ConnectionId>,
 23    forbid_connections: bool,
 24    auth_count: usize,
 25    access_token: usize,
 26}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &mut Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.foreground(),
 39        };
 40
 41        Arc::get_mut(client)
 42            .unwrap()
 43            .override_authenticate({
 44                let state = server.state.clone();
 45                move |cx| {
 46                    let mut state = state.lock();
 47                    state.auth_count += 1;
 48                    let access_token = state.access_token.to_string();
 49                    cx.spawn(move |_| async move {
 50                        Ok(Credentials {
 51                            user_id: client_user_id,
 52                            access_token,
 53                        })
 54                    })
 55                }
 56            })
 57            .override_establish_connection({
 58                let peer = server.peer.clone();
 59                let state = server.state.clone();
 60                move |credentials, cx| {
 61                    let peer = peer.clone();
 62                    let state = state.clone();
 63                    let credentials = credentials.clone();
 64                    cx.spawn(move |cx| async move {
 65                        assert_eq!(credentials.user_id, client_user_id);
 66
 67                        if state.lock().forbid_connections {
 68                            Err(EstablishConnectionError::Other(anyhow!(
 69                                "server is forbidding connections"
 70                            )))?
 71                        }
 72
 73                        if credentials.access_token != state.lock().access_token.to_string() {
 74                            Err(EstablishConnectionError::Unauthorized)?
 75                        }
 76
 77                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 78                        let (connection_id, io, incoming) =
 79                            peer.add_test_connection(server_conn, cx.background()).await;
 80                        cx.background().spawn(io).detach();
 81                        let mut state = state.lock();
 82                        state.connection_id = Some(connection_id);
 83                        state.incoming = Some(incoming);
 84                        Ok(client_conn)
 85                    })
 86                }
 87            });
 88
 89        client
 90            .authenticate_and_connect(false, &cx.to_async())
 91            .await
 92            .unwrap();
 93        server
 94    }
 95
 96    pub fn disconnect(&self) {
 97        self.peer.disconnect(self.connection_id());
 98        let mut state = self.state.lock();
 99        state.connection_id.take();
100        state.incoming.take();
101    }
102
103    pub fn auth_count(&self) -> usize {
104        self.state.lock().auth_count
105    }
106
107    pub fn roll_access_token(&self) {
108        self.state.lock().access_token += 1;
109    }
110
111    pub fn forbid_connections(&self) {
112        self.state.lock().forbid_connections = true;
113    }
114
115    pub fn allow_connections(&self) {
116        self.state.lock().forbid_connections = false;
117    }
118
119    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
120        self.peer.send(self.connection_id(), message).unwrap();
121    }
122
123    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
124        self.executor.start_waiting();
125        let message = self
126            .state
127            .lock()
128            .incoming
129            .as_mut()
130            .expect("not connected")
131            .next()
132            .await
133            .ok_or_else(|| anyhow!("other half hung up"))?;
134        self.executor.finish_waiting();
135        let type_name = message.payload_type_name();
136        Ok(*message
137            .into_any()
138            .downcast::<TypedEnvelope<M>>()
139            .unwrap_or_else(|_| {
140                panic!(
141                    "fake server received unexpected message type: {:?}",
142                    type_name
143                );
144            }))
145    }
146
147    pub async fn respond<T: proto::RequestMessage>(
148        &self,
149        receipt: Receipt<T>,
150        response: T::Response,
151    ) {
152        self.peer.respond(receipt, response).unwrap()
153    }
154
155    fn connection_id(&self) -> ConnectionId {
156        self.state.lock().connection_id.expect("not connected")
157    }
158
159    pub async fn build_user_store(
160        &self,
161        client: Arc<Client>,
162        cx: &mut TestAppContext,
163    ) -> ModelHandle<UserStore> {
164        let http_client = FakeHttpClient::with_404_response();
165        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
166        assert_eq!(
167            self.receive::<proto::GetUsers>()
168                .await
169                .unwrap()
170                .payload
171                .user_ids,
172            &[self.user_id]
173        );
174        user_store
175    }
176}
177
178pub struct FakeHttpClient {
179    handler: Box<
180        dyn 'static
181            + Send
182            + Sync
183            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
184    >,
185}
186
187impl FakeHttpClient {
188    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
189    where
190        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
191        F: 'static + Send + Sync + Fn(Request) -> Fut,
192    {
193        Arc::new(Self {
194            handler: Box::new(move |req| Box::pin(handler(req))),
195        })
196    }
197
198    pub fn with_404_response() -> Arc<dyn HttpClient> {
199        Self::new(|_| async move {
200            Ok(isahc::Response::builder()
201                .status(404)
202                .body(Default::default())
203                .unwrap())
204        })
205    }
206}
207
208impl fmt::Debug for FakeHttpClient {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        f.debug_struct("FakeHttpClient").finish()
211    }
212}
213
214impl HttpClient for FakeHttpClient {
215    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response, crate::http::Error>> {
216        let future = (self.handler)(req);
217        Box::pin(async move { future.await.map(Into::into) })
218    }
219}