test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{
 10    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 11    ConnectionId, Peer, Receipt, TypedEnvelope,
 12};
 13use std::{fmt, rc::Rc, sync::Arc};
 14
 15pub struct FakeServer {
 16    peer: Arc<Peer>,
 17    state: Arc<Mutex<FakeServerState>>,
 18    user_id: u64,
 19    executor: Rc<executor::Foreground>,
 20}
 21
 22#[derive(Default)]
 23struct FakeServerState {
 24    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 25    connection_id: Option<ConnectionId>,
 26    forbid_connections: bool,
 27    auth_count: usize,
 28    access_token: usize,
 29}
 30
 31impl FakeServer {
 32    pub async fn for_client(
 33        client_user_id: u64,
 34        client: &Arc<Client>,
 35        cx: &TestAppContext,
 36    ) -> Self {
 37        let server = Self {
 38            peer: Peer::new(),
 39            state: Default::default(),
 40            user_id: client_user_id,
 41            executor: cx.foreground(),
 42        };
 43
 44        client
 45            .override_authenticate({
 46                let state = Arc::downgrade(&server.state);
 47                move |cx| {
 48                    let state = state.clone();
 49                    cx.spawn(move |_| async move {
 50                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 51                        let mut state = state.lock();
 52                        state.auth_count += 1;
 53                        let access_token = state.access_token.to_string();
 54                        Ok(Credentials {
 55                            user_id: client_user_id,
 56                            access_token,
 57                        })
 58                    })
 59                }
 60            })
 61            .override_establish_connection({
 62                let peer = Arc::downgrade(&server.peer);
 63                let state = Arc::downgrade(&server.state);
 64                move |credentials, cx| {
 65                    let peer = peer.clone();
 66                    let state = state.clone();
 67                    let credentials = credentials.clone();
 68                    cx.spawn(move |cx| async move {
 69                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 70                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 71                        if state.lock().forbid_connections {
 72                            Err(EstablishConnectionError::Other(anyhow!(
 73                                "server is forbidding connections"
 74                            )))?
 75                        }
 76
 77                        assert_eq!(credentials.user_id, client_user_id);
 78
 79                        if credentials.access_token != state.lock().access_token.to_string() {
 80                            Err(EstablishConnectionError::Unauthorized)?
 81                        }
 82
 83                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 84                        let (connection_id, io, incoming) =
 85                            peer.add_test_connection(server_conn, cx.background()).await;
 86                        cx.background().spawn(io).detach();
 87                        let mut state = state.lock();
 88                        state.connection_id = Some(connection_id);
 89                        state.incoming = Some(incoming);
 90                        Ok(client_conn)
 91                    })
 92                }
 93            });
 94
 95        client
 96            .authenticate_and_connect(false, &cx.to_async())
 97            .await
 98            .unwrap();
 99
100        server
101    }
102
103    pub fn disconnect(&self) {
104        self.peer.disconnect(self.connection_id());
105        let mut state = self.state.lock();
106        state.connection_id.take();
107        state.incoming.take();
108    }
109
110    pub fn auth_count(&self) -> usize {
111        self.state.lock().auth_count
112    }
113
114    pub fn roll_access_token(&self) {
115        self.state.lock().access_token += 1;
116    }
117
118    pub fn forbid_connections(&self) {
119        self.state.lock().forbid_connections = true;
120    }
121
122    pub fn allow_connections(&self) {
123        self.state.lock().forbid_connections = false;
124    }
125
126    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
127        self.peer.send(self.connection_id(), message).unwrap();
128    }
129
130    #[allow(clippy::await_holding_lock)]
131    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
132        self.executor.start_waiting();
133
134        loop {
135            let message = self
136                .state
137                .lock()
138                .incoming
139                .as_mut()
140                .expect("not connected")
141                .next()
142                .await
143                .ok_or_else(|| anyhow!("other half hung up"))?;
144            self.executor.finish_waiting();
145            let type_name = message.payload_type_name();
146            let message = message.into_any();
147
148            if message.is::<TypedEnvelope<M>>() {
149                return Ok(*message.downcast().unwrap());
150            }
151
152            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
153                self.respond(
154                    message
155                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
156                        .unwrap()
157                        .receipt(),
158                    GetPrivateUserInfoResponse {
159                        metrics_id: "the-metrics-id".into(),
160                    },
161                )
162                .await;
163                continue;
164            }
165
166            panic!(
167                "fake server received unexpected message type: {:?}",
168                type_name
169            );
170        }
171    }
172
173    pub async fn respond<T: proto::RequestMessage>(
174        &self,
175        receipt: Receipt<T>,
176        response: T::Response,
177    ) {
178        self.peer.respond(receipt, response).unwrap()
179    }
180
181    fn connection_id(&self) -> ConnectionId {
182        self.state.lock().connection_id.expect("not connected")
183    }
184
185    pub async fn build_user_store(
186        &self,
187        client: Arc<Client>,
188        cx: &mut TestAppContext,
189    ) -> ModelHandle<UserStore> {
190        let http_client = FakeHttpClient::with_404_response();
191        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
192        assert_eq!(
193            self.receive::<proto::GetUsers>()
194                .await
195                .unwrap()
196                .payload
197                .user_ids,
198            &[self.user_id]
199        );
200        user_store
201    }
202}
203
204impl Drop for FakeServer {
205    fn drop(&mut self) {
206        self.disconnect();
207    }
208}
209
210pub struct FakeHttpClient {
211    handler: Box<
212        dyn 'static
213            + Send
214            + Sync
215            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
216    >,
217}
218
219impl FakeHttpClient {
220    pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
221    where
222        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
223        F: 'static + Send + Sync + Fn(Request) -> Fut,
224    {
225        Arc::new(Self {
226            handler: Box::new(move |req| Box::pin(handler(req))),
227        })
228    }
229
230    pub fn with_404_response() -> Arc<dyn HttpClient> {
231        Self::create(|_| async move {
232            Ok(isahc::Response::builder()
233                .status(404)
234                .body(Default::default())
235                .unwrap())
236        })
237    }
238}
239
240impl fmt::Debug for FakeHttpClient {
241    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242        f.debug_struct("FakeHttpClient").finish()
243    }
244}
245
246impl HttpClient for FakeHttpClient {
247    fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
248        let future = (self.handler)(req);
249        Box::pin(async move { future.await.map(Into::into) })
250    }
251}