test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{
 10    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 11    ConnectionId, Peer, Receipt, TypedEnvelope,
 12};
 13use std::{fmt, rc::Rc, sync::Arc};
 14
 15pub struct FakeServer {
 16    peer: Arc<Peer>,
 17    state: Arc<Mutex<FakeServerState>>,
 18    user_id: u64,
 19    executor: Rc<executor::Foreground>,
 20}
 21
 22#[derive(Default)]
 23struct FakeServerState {
 24    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 25    connection_id: Option<ConnectionId>,
 26    forbid_connections: bool,
 27    auth_count: usize,
 28    access_token: usize,
 29}
 30
 31impl FakeServer {
 32    pub async fn for_client(
 33        client_user_id: u64,
 34        client: &Arc<Client>,
 35        cx: &TestAppContext,
 36    ) -> Self {
 37        let server = Self {
 38            peer: Peer::new(),
 39            state: Default::default(),
 40            user_id: client_user_id,
 41            executor: cx.foreground(),
 42        };
 43
 44        client
 45            .override_authenticate({
 46                let state = Arc::downgrade(&server.state);
 47                move |cx| {
 48                    let state = state.clone();
 49                    cx.spawn(move |_| async move {
 50                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 51                        let mut state = state.lock();
 52                        state.auth_count += 1;
 53                        let access_token = state.access_token.to_string();
 54                        Ok(Credentials {
 55                            user_id: client_user_id,
 56                            access_token,
 57                        })
 58                    })
 59                }
 60            })
 61            .override_establish_connection({
 62                let peer = Arc::downgrade(&server.peer);
 63                let state = Arc::downgrade(&server.state);
 64                move |credentials, cx| {
 65                    let peer = peer.clone();
 66                    let state = state.clone();
 67                    let credentials = credentials.clone();
 68                    cx.spawn(move |cx| async move {
 69                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 70                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 71                        if state.lock().forbid_connections {
 72                            Err(EstablishConnectionError::Other(anyhow!(
 73                                "server is forbidding connections"
 74                            )))?
 75                        }
 76
 77                        assert_eq!(credentials.user_id, client_user_id);
 78
 79                        if credentials.access_token != state.lock().access_token.to_string() {
 80                            Err(EstablishConnectionError::Unauthorized)?
 81                        }
 82
 83                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 84                        let (connection_id, io, incoming) =
 85                            peer.add_test_connection(server_conn, cx.background());
 86                        cx.background().spawn(io).detach();
 87                        let mut state = state.lock();
 88                        state.connection_id = Some(connection_id);
 89                        state.incoming = Some(incoming);
 90                        Ok(client_conn)
 91                    })
 92                }
 93            });
 94
 95        client
 96            .authenticate_and_connect(false, &cx.to_async())
 97            .await
 98            .unwrap();
 99
100        server
101    }
102
103    pub fn disconnect(&self) {
104        if self.state.lock().connection_id.is_some() {
105            self.peer.disconnect(self.connection_id());
106            let mut state = self.state.lock();
107            state.connection_id.take();
108            state.incoming.take();
109        }
110    }
111
112    pub fn auth_count(&self) -> usize {
113        self.state.lock().auth_count
114    }
115
116    pub fn roll_access_token(&self) {
117        self.state.lock().access_token += 1;
118    }
119
120    pub fn forbid_connections(&self) {
121        self.state.lock().forbid_connections = true;
122    }
123
124    pub fn allow_connections(&self) {
125        self.state.lock().forbid_connections = false;
126    }
127
128    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
129        self.peer.send(self.connection_id(), message).unwrap();
130    }
131
132    #[allow(clippy::await_holding_lock)]
133    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
134        self.executor.start_waiting();
135
136        loop {
137            let message = self
138                .state
139                .lock()
140                .incoming
141                .as_mut()
142                .expect("not connected")
143                .next()
144                .await
145                .ok_or_else(|| anyhow!("other half hung up"))?;
146            self.executor.finish_waiting();
147            let type_name = message.payload_type_name();
148            let message = message.into_any();
149
150            if message.is::<TypedEnvelope<M>>() {
151                return Ok(*message.downcast().unwrap());
152            }
153
154            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
155                self.respond(
156                    message
157                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
158                        .unwrap()
159                        .receipt(),
160                    GetPrivateUserInfoResponse {
161                        metrics_id: "the-metrics-id".into(),
162                        staff: false,
163                    },
164                )
165                .await;
166                continue;
167            }
168
169            panic!(
170                "fake server received unexpected message type: {:?}",
171                type_name
172            );
173        }
174    }
175
176    pub async fn respond<T: proto::RequestMessage>(
177        &self,
178        receipt: Receipt<T>,
179        response: T::Response,
180    ) {
181        self.peer.respond(receipt, response).unwrap()
182    }
183
184    fn connection_id(&self) -> ConnectionId {
185        self.state.lock().connection_id.expect("not connected")
186    }
187
188    pub async fn build_user_store(
189        &self,
190        client: Arc<Client>,
191        cx: &mut TestAppContext,
192    ) -> ModelHandle<UserStore> {
193        let http_client = FakeHttpClient::with_404_response();
194        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
195        assert_eq!(
196            self.receive::<proto::GetUsers>()
197                .await
198                .unwrap()
199                .payload
200                .user_ids,
201            &[self.user_id]
202        );
203        user_store
204    }
205}
206
207impl Drop for FakeServer {
208    fn drop(&mut self) {
209        self.disconnect();
210    }
211}
212
213pub struct FakeHttpClient {
214    handler: Box<
215        dyn 'static
216            + Send
217            + Sync
218            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
219    >,
220}
221
222impl FakeHttpClient {
223    pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
224    where
225        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
226        F: 'static + Send + Sync + Fn(Request) -> Fut,
227    {
228        Arc::new(Self {
229            handler: Box::new(move |req| Box::pin(handler(req))),
230        })
231    }
232
233    pub fn with_404_response() -> Arc<dyn HttpClient> {
234        Self::create(|_| async move {
235            Ok(isahc::Response::builder()
236                .status(404)
237                .body(Default::default())
238                .unwrap())
239        })
240    }
241}
242
243impl fmt::Debug for FakeHttpClient {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        f.debug_struct("FakeHttpClient").finish()
246    }
247}
248
249impl HttpClient for FakeHttpClient {
250    fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
251        let future = (self.handler)(req);
252        Box::pin(async move { future.await.map(Into::into) })
253    }
254}