test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{anyhow, Result};
  3use futures::{stream::BoxStream, StreamExt};
  4use gpui::{executor, ModelHandle, TestAppContext};
  5use parking_lot::Mutex;
  6use rpc::{
  7    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
  8    ConnectionId, Peer, Receipt, TypedEnvelope,
  9};
 10use std::{rc::Rc, sync::Arc};
 11use util::http::FakeHttpClient;
 12
 13pub struct FakeServer {
 14    peer: Arc<Peer>,
 15    state: Arc<Mutex<FakeServerState>>,
 16    user_id: u64,
 17    executor: Rc<executor::Foreground>,
 18}
 19
 20#[derive(Default)]
 21struct FakeServerState {
 22    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 23    connection_id: Option<ConnectionId>,
 24    forbid_connections: bool,
 25    auth_count: usize,
 26    access_token: usize,
 27}
 28
 29impl FakeServer {
 30    pub async fn for_client(
 31        client_user_id: u64,
 32        client: &Arc<Client>,
 33        cx: &TestAppContext,
 34    ) -> Self {
 35        let server = Self {
 36            peer: Peer::new(0),
 37            state: Default::default(),
 38            user_id: client_user_id,
 39            executor: cx.foreground(),
 40        };
 41
 42        client
 43            .override_authenticate({
 44                let state = Arc::downgrade(&server.state);
 45                move |cx| {
 46                    let state = state.clone();
 47                    cx.spawn(move |_| async move {
 48                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 49                        let mut state = state.lock();
 50                        state.auth_count += 1;
 51                        let access_token = state.access_token.to_string();
 52                        Ok(Credentials {
 53                            user_id: client_user_id,
 54                            access_token,
 55                        })
 56                    })
 57                }
 58            })
 59            .override_establish_connection({
 60                let peer = Arc::downgrade(&server.peer);
 61                let state = Arc::downgrade(&server.state);
 62                move |credentials, cx| {
 63                    let peer = peer.clone();
 64                    let state = state.clone();
 65                    let credentials = credentials.clone();
 66                    cx.spawn(move |cx| async move {
 67                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 68                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 69                        if state.lock().forbid_connections {
 70                            Err(EstablishConnectionError::Other(anyhow!(
 71                                "server is forbidding connections"
 72                            )))?
 73                        }
 74
 75                        assert_eq!(credentials.user_id, client_user_id);
 76
 77                        if credentials.access_token != state.lock().access_token.to_string() {
 78                            Err(EstablishConnectionError::Unauthorized)?
 79                        }
 80
 81                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 82                        let (connection_id, io, incoming) =
 83                            peer.add_test_connection(server_conn, cx.background());
 84                        cx.background().spawn(io).detach();
 85                        {
 86                            let mut state = state.lock();
 87                            state.connection_id = Some(connection_id);
 88                            state.incoming = Some(incoming);
 89                        }
 90                        peer.send(
 91                            connection_id,
 92                            proto::Hello {
 93                                peer_id: Some(connection_id.into()),
 94                            },
 95                        )
 96                        .unwrap();
 97
 98                        Ok(client_conn)
 99                    })
100                }
101            });
102
103        client
104            .authenticate_and_connect(false, &cx.to_async())
105            .await
106            .unwrap();
107
108        server
109    }
110
111    pub fn disconnect(&self) {
112        if self.state.lock().connection_id.is_some() {
113            self.peer.disconnect(self.connection_id());
114            let mut state = self.state.lock();
115            state.connection_id.take();
116            state.incoming.take();
117        }
118    }
119
120    pub fn auth_count(&self) -> usize {
121        self.state.lock().auth_count
122    }
123
124    pub fn roll_access_token(&self) {
125        self.state.lock().access_token += 1;
126    }
127
128    pub fn forbid_connections(&self) {
129        self.state.lock().forbid_connections = true;
130    }
131
132    pub fn allow_connections(&self) {
133        self.state.lock().forbid_connections = false;
134    }
135
136    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
137        self.peer.send(self.connection_id(), message).unwrap();
138    }
139
140    #[allow(clippy::await_holding_lock)]
141    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
142        self.executor.start_waiting();
143
144        loop {
145            let message = self
146                .state
147                .lock()
148                .incoming
149                .as_mut()
150                .expect("not connected")
151                .next()
152                .await
153                .ok_or_else(|| anyhow!("other half hung up"))?;
154            self.executor.finish_waiting();
155            let type_name = message.payload_type_name();
156            let message = message.into_any();
157
158            if message.is::<TypedEnvelope<M>>() {
159                return Ok(*message.downcast().unwrap());
160            }
161
162            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
163                self.respond(
164                    message
165                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
166                        .unwrap()
167                        .receipt(),
168                    GetPrivateUserInfoResponse {
169                        metrics_id: "the-metrics-id".into(),
170                        staff: false,
171                        flags: Default::default(),
172                    },
173                );
174                continue;
175            }
176
177            panic!(
178                "fake server received unexpected message type: {:?}",
179                type_name
180            );
181        }
182    }
183
184    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
185        self.peer.respond(receipt, response).unwrap()
186    }
187
188    fn connection_id(&self) -> ConnectionId {
189        self.state.lock().connection_id.expect("not connected")
190    }
191
192    pub async fn build_user_store(
193        &self,
194        client: Arc<Client>,
195        cx: &mut TestAppContext,
196    ) -> ModelHandle<UserStore> {
197        let http_client = FakeHttpClient::with_404_response();
198        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
199        assert_eq!(
200            self.receive::<proto::GetUsers>()
201                .await
202                .unwrap()
203                .payload
204                .user_ids,
205            &[self.user_id]
206        );
207        user_store
208    }
209}
210
211impl Drop for FakeServer {
212    fn drop(&mut self) {
213        self.disconnect();
214    }
215}