test.rs

  1use crate::{
  2    http::{HttpClient, Request, Response, ServerResponse},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 10use std::{fmt, rc::Rc, sync::Arc};
 11
 12pub struct FakeServer {
 13    peer: Arc<Peer>,
 14    state: Arc<Mutex<FakeServerState>>,
 15    user_id: u64,
 16    executor: Rc<executor::Foreground>,
 17}
 18
 19#[derive(Default)]
 20struct FakeServerState {
 21    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 22    connection_id: Option<ConnectionId>,
 23    forbid_connections: bool,
 24    auth_count: usize,
 25    access_token: usize,
 26}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &mut Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.foreground(),
 39        };
 40
 41        Arc::get_mut(client)
 42            .unwrap()
 43            .override_authenticate({
 44                let state = server.state.clone();
 45                move |cx| {
 46                    let mut state = state.lock();
 47                    state.auth_count += 1;
 48                    let access_token = state.access_token.to_string();
 49                    cx.spawn(move |_| async move {
 50                        Ok(Credentials {
 51                            user_id: client_user_id,
 52                            access_token,
 53                        })
 54                    })
 55                }
 56            })
 57            .override_establish_connection({
 58                let peer = server.peer.clone();
 59                let state = server.state.clone();
 60                move |credentials, cx| {
 61                    let peer = peer.clone();
 62                    let state = state.clone();
 63                    let credentials = credentials.clone();
 64                    cx.spawn(move |cx| async move {
 65                        assert_eq!(credentials.user_id, client_user_id);
 66
 67                        if state.lock().forbid_connections {
 68                            Err(EstablishConnectionError::Other(anyhow!(
 69                                "server is forbidding connections"
 70                            )))?
 71                        }
 72
 73                        if credentials.access_token != state.lock().access_token.to_string() {
 74                            Err(EstablishConnectionError::Unauthorized)?
 75                        }
 76
 77                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 78                        let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
 79                        cx.background().spawn(io).detach();
 80                        let mut state = state.lock();
 81                        state.connection_id = Some(connection_id);
 82                        state.incoming = Some(incoming);
 83                        Ok(client_conn)
 84                    })
 85                }
 86            });
 87
 88        client
 89            .authenticate_and_connect(&cx.to_async())
 90            .await
 91            .unwrap();
 92        server
 93    }
 94
 95    pub fn disconnect(&self) {
 96        self.peer.disconnect(self.connection_id());
 97        let mut state = self.state.lock();
 98        state.connection_id.take();
 99        state.incoming.take();
100    }
101
102    pub fn auth_count(&self) -> usize {
103        self.state.lock().auth_count
104    }
105
106    pub fn roll_access_token(&self) {
107        self.state.lock().access_token += 1;
108    }
109
110    pub fn forbid_connections(&self) {
111        self.state.lock().forbid_connections = true;
112    }
113
114    pub fn allow_connections(&self) {
115        self.state.lock().forbid_connections = false;
116    }
117
118    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
119        self.peer.send(self.connection_id(), message).unwrap();
120    }
121
122    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
123        self.executor.start_waiting();
124        let message = self
125            .state
126            .lock()
127            .incoming
128            .as_mut()
129            .expect("not connected")
130            .next()
131            .await
132            .ok_or_else(|| anyhow!("other half hung up"))?;
133        self.executor.finish_waiting();
134        let type_name = message.payload_type_name();
135        Ok(*message
136            .into_any()
137            .downcast::<TypedEnvelope<M>>()
138            .unwrap_or_else(|_| {
139                panic!(
140                    "fake server received unexpected message type: {:?}",
141                    type_name
142                );
143            }))
144    }
145
146    pub async fn respond<T: proto::RequestMessage>(
147        &self,
148        receipt: Receipt<T>,
149        response: T::Response,
150    ) {
151        self.peer.respond(receipt, response).unwrap()
152    }
153
154    fn connection_id(&self) -> ConnectionId {
155        self.state.lock().connection_id.expect("not connected")
156    }
157
158    pub async fn build_user_store(
159        &self,
160        client: Arc<Client>,
161        cx: &mut TestAppContext,
162    ) -> ModelHandle<UserStore> {
163        let http_client = FakeHttpClient::with_404_response();
164        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
165        assert_eq!(
166            self.receive::<proto::GetUsers>()
167                .await
168                .unwrap()
169                .payload
170                .user_ids,
171            &[self.user_id]
172        );
173        user_store
174    }
175}
176
177pub struct FakeHttpClient {
178    handler:
179        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
180}
181
182impl FakeHttpClient {
183    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
184    where
185        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
186        F: 'static + Send + Sync + Fn(Request) -> Fut,
187    {
188        Arc::new(Self {
189            handler: Box::new(move |req| Box::pin(handler(req))),
190        })
191    }
192
193    pub fn with_404_response() -> Arc<dyn HttpClient> {
194        Self::new(|_| async move { Ok(ServerResponse::new(404)) })
195    }
196}
197
198impl fmt::Debug for FakeHttpClient {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        f.debug_struct("FakeHttpClient").finish()
201    }
202}
203
204impl HttpClient for FakeHttpClient {
205    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
206        let future = (self.handler)(req);
207        Box::pin(async move { future.await.map(Into::into) })
208    }
209}