test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use chrono::Duration;
  4use futures::{StreamExt, stream::BoxStream};
  5use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
  6use parking_lot::Mutex;
  7use rpc::{
  8    ConnectionId, Peer, Receipt, TypedEnvelope,
  9    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 10};
 11use std::sync::Arc;
 12
/// In-memory fake of the RPC server a [`Client`] connects to, for use in tests.
///
/// Build one with [`FakeServer::for_client`], which overrides the client's
/// authentication and connection establishment so the client talks to this
/// server over an in-memory [`Connection`].
pub struct FakeServer {
    /// RPC peer used to send messages and responses to the connected client.
    peer: Arc<Peer>,
    /// State shared (through weak references) with the client's auth/connect overrides.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client authenticates as; also the id `build_user_store` expects
    /// in the initial `GetUsers` request.
    user_id: u64,
    /// Test executor; `receive` brackets its await with `start_waiting`/`finish_waiting` on it.
    executor: BackgroundExecutor,
}
 19
 20#[derive(Default)]
 21struct FakeServerState {
 22    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 23    connection_id: Option<ConnectionId>,
 24    forbid_connections: bool,
 25    auth_count: usize,
 26    access_token: usize,
 27}
 28
 29impl FakeServer {
 30    pub async fn for_client(
 31        client_user_id: u64,
 32        client: &Arc<Client>,
 33        cx: &TestAppContext,
 34    ) -> Self {
 35        let server = Self {
 36            peer: Peer::new(0),
 37            state: Default::default(),
 38            user_id: client_user_id,
 39            executor: cx.executor(),
 40        };
 41
 42        client
 43            .override_authenticate({
 44                let state = Arc::downgrade(&server.state);
 45                move |cx| {
 46                    let state = state.clone();
 47                    cx.spawn(async move |_| {
 48                        let state = state.upgrade().context("server dropped")?;
 49                        let mut state = state.lock();
 50                        state.auth_count += 1;
 51                        let access_token = state.access_token.to_string();
 52                        Ok(Credentials {
 53                            user_id: client_user_id,
 54                            access_token,
 55                        })
 56                    })
 57                }
 58            })
 59            .override_establish_connection({
 60                let peer = Arc::downgrade(&server.peer);
 61                let state = Arc::downgrade(&server.state);
 62                move |credentials, cx| {
 63                    let peer = peer.clone();
 64                    let state = state.clone();
 65                    let credentials = credentials.clone();
 66                    cx.spawn(async move |cx| {
 67                        let state = state.upgrade().context("server dropped")?;
 68                        let peer = peer.upgrade().context("server dropped")?;
 69                        if state.lock().forbid_connections {
 70                            Err(EstablishConnectionError::Other(anyhow!(
 71                                "server is forbidding connections"
 72                            )))?
 73                        }
 74
 75                        if credentials
 76                            != (Credentials {
 77                                user_id: client_user_id,
 78                                access_token: state.lock().access_token.to_string(),
 79                            })
 80                        {
 81                            Err(EstablishConnectionError::Unauthorized)?
 82                        }
 83
 84                        let (client_conn, server_conn, _) =
 85                            Connection::in_memory(cx.background_executor().clone());
 86                        let (connection_id, io, incoming) =
 87                            peer.add_test_connection(server_conn, cx.background_executor().clone());
 88                        cx.background_spawn(io).detach();
 89                        {
 90                            let mut state = state.lock();
 91                            state.connection_id = Some(connection_id);
 92                            state.incoming = Some(incoming);
 93                        }
 94                        peer.send(
 95                            connection_id,
 96                            proto::Hello {
 97                                peer_id: Some(connection_id.into()),
 98                            },
 99                        )
100                        .unwrap();
101
102                        Ok(client_conn)
103                    })
104                }
105            });
106
107        client
108            .authenticate_and_connect(false, &cx.to_async())
109            .await
110            .into_response()
111            .unwrap();
112
113        server
114    }
115
116    pub fn disconnect(&self) {
117        if self.state.lock().connection_id.is_some() {
118            self.peer.disconnect(self.connection_id());
119            let mut state = self.state.lock();
120            state.connection_id.take();
121            state.incoming.take();
122        }
123    }
124
125    pub fn auth_count(&self) -> usize {
126        self.state.lock().auth_count
127    }
128
129    pub fn roll_access_token(&self) {
130        self.state.lock().access_token += 1;
131    }
132
133    pub fn forbid_connections(&self) {
134        self.state.lock().forbid_connections = true;
135    }
136
137    pub fn allow_connections(&self) {
138        self.state.lock().forbid_connections = false;
139    }
140
141    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
142        self.peer.send(self.connection_id(), message).unwrap();
143    }
144
145    #[allow(clippy::await_holding_lock)]
146    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
147        self.executor.start_waiting();
148
149        loop {
150            let message = self
151                .state
152                .lock()
153                .incoming
154                .as_mut()
155                .expect("not connected")
156                .next()
157                .await
158                .context("other half hung up")?;
159            self.executor.finish_waiting();
160            let type_name = message.payload_type_name();
161            let message = message.into_any();
162
163            if message.is::<TypedEnvelope<M>>() {
164                return Ok(*message.downcast().unwrap());
165            }
166
167            let accepted_tos_at = chrono::Utc::now()
168                .checked_sub_signed(Duration::hours(5))
169                .expect("failed to build accepted_tos_at")
170                .timestamp() as u64;
171
172            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
173                self.respond(
174                    message
175                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
176                        .unwrap()
177                        .receipt(),
178                    GetPrivateUserInfoResponse {
179                        metrics_id: "the-metrics-id".into(),
180                        staff: false,
181                        flags: Default::default(),
182                        accepted_tos_at: Some(accepted_tos_at),
183                    },
184                );
185                continue;
186            }
187
188            panic!(
189                "fake server received unexpected message type: {:?}",
190                type_name
191            );
192        }
193    }
194
195    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
196        self.peer.respond(receipt, response).unwrap()
197    }
198
199    fn connection_id(&self) -> ConnectionId {
200        self.state.lock().connection_id.expect("not connected")
201    }
202
203    pub async fn build_user_store(
204        &self,
205        client: Arc<Client>,
206        cx: &mut TestAppContext,
207    ) -> Entity<UserStore> {
208        let user_store = cx.new(|cx| UserStore::new(client, cx));
209        assert_eq!(
210            self.receive::<proto::GetUsers>()
211                .await
212                .unwrap()
213                .payload
214                .user_ids,
215            &[self.user_id]
216        );
217        user_store
218    }
219}
220
221impl Drop for FakeServer {
222    fn drop(&mut self) {
223        self.disconnect();
224    }
225}