test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{anyhow, Result};
  3use futures::{stream::BoxStream, StreamExt};
  4use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
  5use parking_lot::Mutex;
  6use rpc::{
  7    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
  8    ConnectionId, Peer, Receipt, TypedEnvelope,
  9};
 10use std::sync::Arc;
 11
/// An in-process stand-in for the collaboration server, used by client tests.
///
/// It installs authentication/connection overrides on a [`Client`] (see
/// [`FakeServer::for_client`]) and then lets the test send messages to, and receive messages
/// from, that client over an in-memory connection.
pub struct FakeServer {
    /// Server-side RPC peer used to send, respond to, and disconnect the client connection.
    peer: Arc<Peer>,
    /// Mutable connection/auth state. The client's override closures only hold weak
    /// references to it, so they stop working once the server is dropped.
    state: Arc<Mutex<FakeServerState>>,
    /// The user id the client authenticates as; `build_user_store` expects it in the
    /// initial `GetUsers` request.
    user_id: u64,
    /// Test executor, used by `receive` to mark when the test is blocked waiting for a message.
    executor: BackgroundExecutor,
}
 18
/// Mutable state shared between a [`FakeServer`] and the overrides it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages arriving from the client on the current connection; `None` when disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current client connection; `None` when disconnected.
    connection_id: Option<ConnectionId>,
    /// When `true`, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client's authenticate override has run.
    auth_count: usize,
    /// Counter whose string form is the currently valid access token; bumping it invalidates
    /// previously issued credentials.
    access_token: usize,
}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(0),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.executor(),
 39        };
 40
 41        client
 42            .override_authenticate({
 43                let state = Arc::downgrade(&server.state);
 44                move |cx| {
 45                    let state = state.clone();
 46                    cx.spawn(move |_| async move {
 47                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 48                        let mut state = state.lock();
 49                        state.auth_count += 1;
 50                        let access_token = state.access_token.to_string();
 51                        Ok(Credentials::User {
 52                            user_id: client_user_id,
 53                            access_token,
 54                        })
 55                    })
 56                }
 57            })
 58            .override_establish_connection({
 59                let peer = Arc::downgrade(&server.peer);
 60                let state = Arc::downgrade(&server.state);
 61                move |credentials, cx| {
 62                    let peer = peer.clone();
 63                    let state = state.clone();
 64                    let credentials = credentials.clone();
 65                    cx.spawn(move |cx| async move {
 66                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 67                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
 68                        if state.lock().forbid_connections {
 69                            Err(EstablishConnectionError::Other(anyhow!(
 70                                "server is forbidding connections"
 71                            )))?
 72                        }
 73
 74                        if credentials
 75                            != (Credentials::User {
 76                                user_id: client_user_id,
 77                                access_token: state.lock().access_token.to_string(),
 78                            })
 79                        {
 80                            Err(EstablishConnectionError::Unauthorized)?
 81                        }
 82
 83                        let (client_conn, server_conn, _) =
 84                            Connection::in_memory(cx.background_executor().clone());
 85                        let (connection_id, io, incoming) =
 86                            peer.add_test_connection(server_conn, cx.background_executor().clone());
 87                        cx.background_executor().spawn(io).detach();
 88                        {
 89                            let mut state = state.lock();
 90                            state.connection_id = Some(connection_id);
 91                            state.incoming = Some(incoming);
 92                        }
 93                        peer.send(
 94                            connection_id,
 95                            proto::Hello {
 96                                peer_id: Some(connection_id.into()),
 97                            },
 98                        )
 99                        .unwrap();
100
101                        Ok(client_conn)
102                    })
103                }
104            });
105
106        client
107            .authenticate_and_connect(false, &cx.to_async())
108            .await
109            .unwrap();
110
111        server
112    }
113
114    pub fn disconnect(&self) {
115        if self.state.lock().connection_id.is_some() {
116            self.peer.disconnect(self.connection_id());
117            let mut state = self.state.lock();
118            state.connection_id.take();
119            state.incoming.take();
120        }
121    }
122
123    pub fn auth_count(&self) -> usize {
124        self.state.lock().auth_count
125    }
126
127    pub fn roll_access_token(&self) {
128        self.state.lock().access_token += 1;
129    }
130
131    pub fn forbid_connections(&self) {
132        self.state.lock().forbid_connections = true;
133    }
134
135    pub fn allow_connections(&self) {
136        self.state.lock().forbid_connections = false;
137    }
138
139    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
140        self.peer.send(self.connection_id(), message).unwrap();
141    }
142
143    #[allow(clippy::await_holding_lock)]
144    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
145        self.executor.start_waiting();
146
147        loop {
148            let message = self
149                .state
150                .lock()
151                .incoming
152                .as_mut()
153                .expect("not connected")
154                .next()
155                .await
156                .ok_or_else(|| anyhow!("other half hung up"))?;
157            self.executor.finish_waiting();
158            let type_name = message.payload_type_name();
159            let message = message.into_any();
160
161            if message.is::<TypedEnvelope<M>>() {
162                return Ok(*message.downcast().unwrap());
163            }
164
165            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
166                self.respond(
167                    message
168                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
169                        .unwrap()
170                        .receipt(),
171                    GetPrivateUserInfoResponse {
172                        metrics_id: "the-metrics-id".into(),
173                        staff: false,
174                        flags: Default::default(),
175                    },
176                );
177                continue;
178            }
179
180            panic!(
181                "fake server received unexpected message type: {:?}",
182                type_name
183            );
184        }
185    }
186
187    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
188        self.peer.respond(receipt, response).unwrap()
189    }
190
191    fn connection_id(&self) -> ConnectionId {
192        self.state.lock().connection_id.expect("not connected")
193    }
194
195    pub async fn build_user_store(
196        &self,
197        client: Arc<Client>,
198        cx: &mut TestAppContext,
199    ) -> Model<UserStore> {
200        let user_store = cx.new_model(|cx| UserStore::new(client, cx));
201        assert_eq!(
202            self.receive::<proto::GetUsers>()
203                .await
204                .unwrap()
205                .payload
206                .user_ids,
207            &[self.user_id]
208        );
209        user_store
210    }
211}
212
213impl Drop for FakeServer {
214    fn drop(&mut self) {
215        self.disconnect();
216    }
217}