test.rs

  1use crate::{
  2    http::{self, HttpClient, Request, Response},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 10use std::{fmt, rc::Rc, sync::Arc};
 11
/// An in-memory fake RPC server for client tests.
///
/// Constructed with [`FakeServer::for_client`], which routes the client's
/// authentication and connection establishment through this server.
pub struct FakeServer {
    /// Server-side RPC peer that test connections are registered on.
    peer: Arc<Peer>,
    /// Connection and auth state; the hooks installed on the client hold only
    /// weak references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client is expected to authenticate as.
    user_id: u64,
    /// Foreground executor of the test's app context; `receive` marks it as
    /// waiting while it awaits the next message.
    executor: Rc<executor::Foreground>,
}
 18
/// Mutable state of a [`FakeServer`], shared with the hooks it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages arriving from the client on the current connection; `None`
    /// while disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id the peer assigned to the current connection; `None` while disconnected.
    connection_id: Option<ConnectionId>,
    /// When `true`, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client's authenticate hook has run.
    auth_count: usize,
    /// Current access token. It is compared, in decimal string form, against
    /// the token the client presents when connecting.
    access_token: usize,
}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.foreground(),
 39        };
 40
 41        client
 42            .override_authenticate({
 43                let state = Arc::downgrade(&server.state);
 44                move |cx| {
 45                    let state = state.clone();
 46                    cx.spawn(move |_| async move {
 47                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 48                        let mut state = state.lock();
 49                        state.auth_count += 1;
 50                        let access_token = state.access_token.to_string();
 51                        Ok(Credentials {
 52                            user_id: client_user_id,
 53                            access_token,
 54                        })
 55                    })
 56                }
 57            })
 58            .override_establish_connection({
 59                let peer = Arc::downgrade(&server.peer);
 60                let state = Arc::downgrade(&server.state);
 61                move |credentials, cx| {
 62                    let peer = peer.clone();
 63                    let state = state.clone();
 64                    let credentials = credentials.clone();
 65                    cx.spawn(move |cx| async move {
 66                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 67                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 68                        if state.lock().forbid_connections {
 69                            Err(EstablishConnectionError::Other(anyhow!(
 70                                "server is forbidding connections"
 71                            )))?
 72                        }
 73
 74                        assert_eq!(credentials.user_id, client_user_id);
 75
 76                        if credentials.access_token != state.lock().access_token.to_string() {
 77                            Err(EstablishConnectionError::Unauthorized)?
 78                        }
 79
 80                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 81                        let (connection_id, io, incoming) =
 82                            peer.add_test_connection(server_conn, cx.background()).await;
 83                        cx.background().spawn(io).detach();
 84                        let mut state = state.lock();
 85                        state.connection_id = Some(connection_id);
 86                        state.incoming = Some(incoming);
 87                        Ok(client_conn)
 88                    })
 89                }
 90            });
 91
 92        client
 93            .authenticate_and_connect(false, &cx.to_async())
 94            .await
 95            .unwrap();
 96        server
 97    }
 98
 99    pub fn disconnect(&self) {
100        self.peer.disconnect(self.connection_id());
101        let mut state = self.state.lock();
102        state.connection_id.take();
103        state.incoming.take();
104    }
105
106    pub fn auth_count(&self) -> usize {
107        self.state.lock().auth_count
108    }
109
110    pub fn roll_access_token(&self) {
111        self.state.lock().access_token += 1;
112    }
113
114    pub fn forbid_connections(&self) {
115        self.state.lock().forbid_connections = true;
116    }
117
118    pub fn allow_connections(&self) {
119        self.state.lock().forbid_connections = false;
120    }
121
122    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
123        self.peer.send(self.connection_id(), message).unwrap();
124    }
125
126    #[allow(clippy::await_holding_lock)]
127    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
128        self.executor.start_waiting();
129        let message = self
130            .state
131            .lock()
132            .incoming
133            .as_mut()
134            .expect("not connected")
135            .next()
136            .await
137            .ok_or_else(|| anyhow!("other half hung up"))?;
138        self.executor.finish_waiting();
139        let type_name = message.payload_type_name();
140        Ok(*message
141            .into_any()
142            .downcast::<TypedEnvelope<M>>()
143            .unwrap_or_else(|_| {
144                panic!(
145                    "fake server received unexpected message type: {:?}",
146                    type_name
147                );
148            }))
149    }
150
151    pub async fn respond<T: proto::RequestMessage>(
152        &self,
153        receipt: Receipt<T>,
154        response: T::Response,
155    ) {
156        self.peer.respond(receipt, response).unwrap()
157    }
158
159    fn connection_id(&self) -> ConnectionId {
160        self.state.lock().connection_id.expect("not connected")
161    }
162
163    pub async fn build_user_store(
164        &self,
165        client: Arc<Client>,
166        cx: &mut TestAppContext,
167    ) -> ModelHandle<UserStore> {
168        let http_client = FakeHttpClient::with_404_response();
169        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
170        assert_eq!(
171            self.receive::<proto::GetUsers>()
172                .await
173                .unwrap()
174                .payload
175                .user_ids,
176            &[self.user_id]
177        );
178        user_store
179    }
180}
181
182impl Drop for FakeServer {
183    fn drop(&mut self) {
184        self.disconnect();
185    }
186}
187
188pub struct FakeHttpClient {
189    handler: Box<
190        dyn 'static
191            + Send
192            + Sync
193            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
194    >,
195}
196
197impl FakeHttpClient {
198    pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
199    where
200        Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
201        F: 'static + Send + Sync + Fn(Request) -> Fut,
202    {
203        Arc::new(Self {
204            handler: Box::new(move |req| Box::pin(handler(req))),
205        })
206    }
207
208    pub fn with_404_response() -> Arc<dyn HttpClient> {
209        Self::create(|_| async move {
210            Ok(isahc::Response::builder()
211                .status(404)
212                .body(Default::default())
213                .unwrap())
214        })
215    }
216}
217
218impl fmt::Debug for FakeHttpClient {
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        f.debug_struct("FakeHttpClient").finish()
221    }
222}
223
224impl HttpClient for FakeHttpClient {
225    fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
226        let future = (self.handler)(req);
227        Box::pin(async move { future.await.map(Into::into) })
228    }
229}