test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{
 10    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 11    ConnectionId, Peer, Receipt, TypedEnvelope,
 12};
 13use std::{fmt, rc::Rc, sync::Arc};
 14
 15pub struct FakeServer {
 16    peer: Arc<Peer>,
 17    state: Arc<Mutex<FakeServerState>>,
 18    user_id: u64,
 19    executor: Rc<executor::Foreground>,
 20}
 21
 22#[derive(Default)]
 23struct FakeServerState {
 24    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 25    connection_id: Option<ConnectionId>,
 26    forbid_connections: bool,
 27    auth_count: usize,
 28    access_token: usize,
 29}
 30
 31impl FakeServer {
 32    pub async fn for_client(
 33        client_user_id: u64,
 34        client: &Arc<Client>,
 35        cx: &TestAppContext,
 36    ) -> Self {
 37        let server = Self {
 38            peer: Peer::new(),
 39            state: Default::default(),
 40            user_id: client_user_id,
 41            executor: cx.foreground(),
 42        };
 43
 44        client
 45            .override_authenticate({
 46                let state = Arc::downgrade(&server.state);
 47                move |cx| {
 48                    let state = state.clone();
 49                    cx.spawn(move |_| async move {
 50                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 51                        let mut state = state.lock();
 52                        state.auth_count += 1;
 53                        let access_token = state.access_token.to_string();
 54                        Ok(Credentials {
 55                            user_id: client_user_id,
 56                            access_token,
 57                        })
 58                    })
 59                }
 60            })
 61            .override_establish_connection({
 62                let peer = Arc::downgrade(&server.peer);
 63                let state = Arc::downgrade(&server.state);
 64                move |credentials, cx| {
 65                    let peer = peer.clone();
 66                    let state = state.clone();
 67                    let credentials = credentials.clone();
 68                    cx.spawn(move |cx| async move {
 69                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 70                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 71                        if state.lock().forbid_connections {
 72                            Err(EstablishConnectionError::Other(anyhow!(
 73                                "server is forbidding connections"
 74                            )))?
 75                        }
 76
 77                        assert_eq!(credentials.user_id, client_user_id);
 78
 79                        if credentials.access_token != state.lock().access_token.to_string() {
 80                            Err(EstablishConnectionError::Unauthorized)?
 81                        }
 82
 83                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 84                        let (connection_id, io, incoming) =
 85                            peer.add_test_connection(server_conn, cx.background()).await;
 86                        cx.background().spawn(io).detach();
 87                        let mut state = state.lock();
 88                        state.connection_id = Some(connection_id);
 89                        state.incoming = Some(incoming);
 90                        Ok(client_conn)
 91                    })
 92                }
 93            });
 94
 95        client
 96            .authenticate_and_connect(false, &cx.to_async())
 97            .await
 98            .unwrap();
 99
100        server
101    }
102
103    pub fn disconnect(&self) {
104        self.peer.disconnect(self.connection_id());
105        let mut state = self.state.lock();
106        state.connection_id.take();
107        state.incoming.take();
108    }
109
110    pub fn auth_count(&self) -> usize {
111        self.state.lock().auth_count
112    }
113
114    pub fn roll_access_token(&self) {
115        self.state.lock().access_token += 1;
116    }
117
118    pub fn forbid_connections(&self) {
119        self.state.lock().forbid_connections = true;
120    }
121
122    pub fn allow_connections(&self) {
123        self.state.lock().forbid_connections = false;
124    }
125
126    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
127        self.peer.send(self.connection_id(), message).unwrap();
128    }
129
130    #[allow(clippy::await_holding_lock)]
131    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
132        self.executor.start_waiting();
133
134        loop {
135            let message = self
136                .state
137                .lock()
138                .incoming
139                .as_mut()
140                .expect("not connected")
141                .next()
142                .await
143                .ok_or_else(|| anyhow!("other half hung up"))?;
144            self.executor.finish_waiting();
145            let type_name = message.payload_type_name();
146            let message = message.into_any();
147
148            if message.is::<TypedEnvelope<M>>() {
149                return Ok(*message.downcast().unwrap());
150            }
151
152            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
153                self.respond(
154                    message
155                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
156                        .unwrap()
157                        .receipt(),
158                    GetPrivateUserInfoResponse {
159                        metrics_id: "the-metrics-id".into(),
160                        staff: false,
161                    },
162                )
163                .await;
164                continue;
165            }
166
167            panic!(
168                "fake server received unexpected message type: {:?}",
169                type_name
170            );
171        }
172    }
173
174    pub async fn respond<T: proto::RequestMessage>(
175        &self,
176        receipt: Receipt<T>,
177        response: T::Response,
178    ) {
179        self.peer.respond(receipt, response).unwrap()
180    }
181
182    fn connection_id(&self) -> ConnectionId {
183        self.state.lock().connection_id.expect("not connected")
184    }
185
186    pub async fn build_user_store(
187        &self,
188        client: Arc<Client>,
189        cx: &mut TestAppContext,
190    ) -> ModelHandle<UserStore> {
191        let http_client = FakeHttpClient::with_404_response();
192        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
193        assert_eq!(
194            self.receive::<proto::GetUsers>()
195                .await
196                .unwrap()
197                .payload
198                .user_ids,
199            &[self.user_id]
200        );
201        user_store
202    }
203}
204
205impl Drop for FakeServer {
206    fn drop(&mut self) {
207        self.disconnect();
208    }
209}
210
211pub struct FakeHttpClient {
212    handler: Box<
213        dyn 'static
214            + Send
215            + Sync
216            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
217    >,
218}
219
220impl FakeHttpClient {
221    pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
222    where
223        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
224        F: 'static + Send + Sync + Fn(Request) -> Fut,
225    {
226        Arc::new(Self {
227            handler: Box::new(move |req| Box::pin(handler(req))),
228        })
229    }
230
231    pub fn with_404_response() -> Arc<dyn HttpClient> {
232        Self::create(|_| async move {
233            Ok(isahc::Response::builder()
234                .status(404)
235                .body(Default::default())
236                .unwrap())
237        })
238    }
239}
240
241impl fmt::Debug for FakeHttpClient {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        f.debug_struct("FakeHttpClient").finish()
244    }
245}
246
247impl HttpClient for FakeHttpClient {
248    fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
249        let future = (self.handler)(req);
250        Box::pin(async move { future.await.map(Into::into) })
251    }
252}