test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{
 10    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 11    ConnectionId, Peer, Receipt, TypedEnvelope,
 12};
 13use std::{fmt, rc::Rc, sync::Arc};
 14
 15pub struct FakeServer {
 16    peer: Arc<Peer>,
 17    state: Arc<Mutex<FakeServerState>>,
 18    user_id: u64,
 19    executor: Rc<executor::Foreground>,
 20}
 21
 22#[derive(Default)]
 23struct FakeServerState {
 24    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 25    connection_id: Option<ConnectionId>,
 26    forbid_connections: bool,
 27    auth_count: usize,
 28    access_token: usize,
 29}
 30
 31impl FakeServer {
 32    pub async fn for_client(
 33        client_user_id: u64,
 34        client: &Arc<Client>,
 35        cx: &TestAppContext,
 36    ) -> Self {
 37        let server = Self {
 38            peer: Peer::new(),
 39            state: Default::default(),
 40            user_id: client_user_id,
 41            executor: cx.foreground(),
 42        };
 43
 44        client
 45            .override_authenticate({
 46                let state = Arc::downgrade(&server.state);
 47                move |cx| {
 48                    let state = state.clone();
 49                    cx.spawn(move |_| async move {
 50                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 51                        let mut state = state.lock();
 52                        state.auth_count += 1;
 53                        let access_token = state.access_token.to_string();
 54                        Ok(Credentials {
 55                            user_id: client_user_id,
 56                            access_token,
 57                        })
 58                    })
 59                }
 60            })
 61            .override_establish_connection({
 62                let peer = Arc::downgrade(&server.peer);
 63                let state = Arc::downgrade(&server.state);
 64                move |credentials, cx| {
 65                    let peer = peer.clone();
 66                    let state = state.clone();
 67                    let credentials = credentials.clone();
 68                    cx.spawn(move |cx| async move {
 69                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 70                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 71                        if state.lock().forbid_connections {
 72                            Err(EstablishConnectionError::Other(anyhow!(
 73                                "server is forbidding connections"
 74                            )))?
 75                        }
 76
 77                        assert_eq!(credentials.user_id, client_user_id);
 78
 79                        if credentials.access_token != state.lock().access_token.to_string() {
 80                            Err(EstablishConnectionError::Unauthorized)?
 81                        }
 82
 83                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 84                        let (connection_id, io, incoming) =
 85                            peer.add_test_connection(server_conn, cx.background());
 86                        cx.background().spawn(io).detach();
 87                        {
 88                            let mut state = state.lock();
 89                            state.connection_id = Some(connection_id);
 90                            state.incoming = Some(incoming);
 91                        }
 92                        peer.send(
 93                            connection_id,
 94                            proto::Hello {
 95                                peer_id: connection_id.0,
 96                            },
 97                        )
 98                        .unwrap();
 99
100                        Ok(client_conn)
101                    })
102                }
103            });
104
105        client
106            .authenticate_and_connect(false, &cx.to_async())
107            .await
108            .unwrap();
109
110        server
111    }
112
113    pub fn disconnect(&self) {
114        if self.state.lock().connection_id.is_some() {
115            self.peer.disconnect(self.connection_id());
116            let mut state = self.state.lock();
117            state.connection_id.take();
118            state.incoming.take();
119        }
120    }
121
122    pub fn auth_count(&self) -> usize {
123        self.state.lock().auth_count
124    }
125
126    pub fn roll_access_token(&self) {
127        self.state.lock().access_token += 1;
128    }
129
130    pub fn forbid_connections(&self) {
131        self.state.lock().forbid_connections = true;
132    }
133
134    pub fn allow_connections(&self) {
135        self.state.lock().forbid_connections = false;
136    }
137
138    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
139        self.peer.send(self.connection_id(), message).unwrap();
140    }
141
142    #[allow(clippy::await_holding_lock)]
143    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
144        self.executor.start_waiting();
145
146        loop {
147            let message = self
148                .state
149                .lock()
150                .incoming
151                .as_mut()
152                .expect("not connected")
153                .next()
154                .await
155                .ok_or_else(|| anyhow!("other half hung up"))?;
156            self.executor.finish_waiting();
157            let type_name = message.payload_type_name();
158            let message = message.into_any();
159
160            if message.is::<TypedEnvelope<M>>() {
161                return Ok(*message.downcast().unwrap());
162            }
163
164            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
165                self.respond(
166                    message
167                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
168                        .unwrap()
169                        .receipt(),
170                    GetPrivateUserInfoResponse {
171                        metrics_id: "the-metrics-id".into(),
172                        staff: false,
173                    },
174                )
175                .await;
176                continue;
177            }
178
179            panic!(
180                "fake server received unexpected message type: {:?}",
181                type_name
182            );
183        }
184    }
185
186    pub async fn respond<T: proto::RequestMessage>(
187        &self,
188        receipt: Receipt<T>,
189        response: T::Response,
190    ) {
191        self.peer.respond(receipt, response).unwrap()
192    }
193
194    fn connection_id(&self) -> ConnectionId {
195        self.state.lock().connection_id.expect("not connected")
196    }
197
198    pub async fn build_user_store(
199        &self,
200        client: Arc<Client>,
201        cx: &mut TestAppContext,
202    ) -> ModelHandle<UserStore> {
203        let http_client = FakeHttpClient::with_404_response();
204        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
205        assert_eq!(
206            self.receive::<proto::GetUsers>()
207                .await
208                .unwrap()
209                .payload
210                .user_ids,
211            &[self.user_id]
212        );
213        user_store
214    }
215}
216
217impl Drop for FakeServer {
218    fn drop(&mut self) {
219        self.disconnect();
220    }
221}
222
223pub struct FakeHttpClient {
224    handler: Box<
225        dyn 'static
226            + Send
227            + Sync
228            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
229    >,
230}
231
232impl FakeHttpClient {
233    pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
234    where
235        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
236        F: 'static + Send + Sync + Fn(Request) -> Fut,
237    {
238        Arc::new(Self {
239            handler: Box::new(move |req| Box::pin(handler(req))),
240        })
241    }
242
243    pub fn with_404_response() -> Arc<dyn HttpClient> {
244        Self::create(|_| async move {
245            Ok(isahc::Response::builder()
246                .status(404)
247                .body(Default::default())
248                .unwrap())
249        })
250    }
251}
252
253impl fmt::Debug for FakeHttpClient {
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        f.debug_struct("FakeHttpClient").finish()
256    }
257}
258
259impl HttpClient for FakeHttpClient {
260    fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
261        let future = (self.handler)(req);
262        Box::pin(async move { future.await.map(Into::into) })
263    }
264}