test.rs

  1use crate::{
  2    http::{HttpClient, Request, Response, ServerResponse},
  3    Client, Connection, Credentials, EstablishConnectionError, UserStore,
  4};
  5use anyhow::{anyhow, Result};
  6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  7use gpui::{executor, ModelHandle, TestAppContext};
  8use parking_lot::Mutex;
  9use postage::barrier;
 10use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 11use std::{fmt, rc::Rc, sync::Arc};
 12
 13pub struct FakeServer {
 14    peer: Arc<Peer>,
 15    state: Arc<Mutex<FakeServerState>>,
 16    user_id: u64,
 17    executor: Rc<executor::Foreground>,
 18}
 19
 20#[derive(Default)]
 21struct FakeServerState {
 22    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 23    connection_id: Option<ConnectionId>,
 24    forbid_connections: bool,
 25    auth_count: usize,
 26    connection_killer: Option<barrier::Sender>,
 27    access_token: usize,
 28}
 29
 30impl FakeServer {
 31    pub async fn for_client(
 32        client_user_id: u64,
 33        client: &mut Arc<Client>,
 34        cx: &TestAppContext,
 35    ) -> Self {
 36        let server = Self {
 37            peer: Peer::new(),
 38            state: Default::default(),
 39            user_id: client_user_id,
 40            executor: cx.foreground(),
 41        };
 42
 43        Arc::get_mut(client)
 44            .unwrap()
 45            .override_authenticate({
 46                let state = server.state.clone();
 47                move |cx| {
 48                    let mut state = state.lock();
 49                    state.auth_count += 1;
 50                    let access_token = state.access_token.to_string();
 51                    cx.spawn(move |_| async move {
 52                        Ok(Credentials {
 53                            user_id: client_user_id,
 54                            access_token,
 55                        })
 56                    })
 57                }
 58            })
 59            .override_establish_connection({
 60                let peer = server.peer.clone();
 61                let state = server.state.clone();
 62                move |credentials, cx| {
 63                    let peer = peer.clone();
 64                    let state = state.clone();
 65                    let credentials = credentials.clone();
 66                    cx.spawn(move |cx| async move {
 67                        assert_eq!(credentials.user_id, client_user_id);
 68
 69                        if state.lock().forbid_connections {
 70                            Err(EstablishConnectionError::Other(anyhow!(
 71                                "server is forbidding connections"
 72                            )))?
 73                        }
 74
 75                        if credentials.access_token != state.lock().access_token.to_string() {
 76                            Err(EstablishConnectionError::Unauthorized)?
 77                        }
 78
 79                        let (client_conn, server_conn, kill) =
 80                            Connection::in_memory(cx.background());
 81                        let (connection_id, io, incoming) =
 82                            peer.add_test_connection(server_conn, cx.background()).await;
 83                        cx.background().spawn(io).detach();
 84                        let mut state = state.lock();
 85                        state.connection_id = Some(connection_id);
 86                        state.incoming = Some(incoming);
 87                        state.connection_killer = Some(kill);
 88                        Ok(client_conn)
 89                    })
 90                }
 91            });
 92
 93        client
 94            .authenticate_and_connect(false, &cx.to_async())
 95            .await
 96            .unwrap();
 97        server
 98    }
 99
100    pub fn disconnect(&self) {
101        self.peer.disconnect(self.connection_id());
102        let mut state = self.state.lock();
103        state.connection_id.take();
104        state.incoming.take();
105    }
106
107    pub fn auth_count(&self) -> usize {
108        self.state.lock().auth_count
109    }
110
111    pub fn roll_access_token(&self) {
112        self.state.lock().access_token += 1;
113    }
114
115    pub fn forbid_connections(&self) {
116        self.state.lock().forbid_connections = true;
117    }
118
119    pub fn allow_connections(&self) {
120        self.state.lock().forbid_connections = false;
121    }
122
123    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
124        self.peer.send(self.connection_id(), message).unwrap();
125    }
126
127    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
128        self.executor.start_waiting();
129        let message = self
130            .state
131            .lock()
132            .incoming
133            .as_mut()
134            .expect("not connected")
135            .next()
136            .await
137            .ok_or_else(|| anyhow!("other half hung up"))?;
138        self.executor.finish_waiting();
139        let type_name = message.payload_type_name();
140        Ok(*message
141            .into_any()
142            .downcast::<TypedEnvelope<M>>()
143            .unwrap_or_else(|_| {
144                panic!(
145                    "fake server received unexpected message type: {:?}",
146                    type_name
147                );
148            }))
149    }
150
151    pub async fn respond<T: proto::RequestMessage>(
152        &self,
153        receipt: Receipt<T>,
154        response: T::Response,
155    ) {
156        self.peer.respond(receipt, response).unwrap()
157    }
158
159    fn connection_id(&self) -> ConnectionId {
160        self.state.lock().connection_id.expect("not connected")
161    }
162
163    pub async fn build_user_store(
164        &self,
165        client: Arc<Client>,
166        cx: &mut TestAppContext,
167    ) -> ModelHandle<UserStore> {
168        let http_client = FakeHttpClient::with_404_response();
169        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
170        assert_eq!(
171            self.receive::<proto::GetUsers>()
172                .await
173                .unwrap()
174                .payload
175                .user_ids,
176            &[self.user_id]
177        );
178        user_store
179    }
180}
181
182pub struct FakeHttpClient {
183    handler:
184        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
185}
186
187impl FakeHttpClient {
188    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
189    where
190        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
191        F: 'static + Send + Sync + Fn(Request) -> Fut,
192    {
193        Arc::new(Self {
194            handler: Box::new(move |req| Box::pin(handler(req))),
195        })
196    }
197
198    pub fn with_404_response() -> Arc<dyn HttpClient> {
199        Self::new(|_| async move { Ok(ServerResponse::new(404)) })
200    }
201}
202
203impl fmt::Debug for FakeHttpClient {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        f.debug_struct("FakeHttpClient").finish()
206    }
207}
208
209impl HttpClient for FakeHttpClient {
210    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
211        let future = (self.handler)(req);
212        Box::pin(async move { future.await.map(Into::into) })
213    }
214}