test.rs

  1use crate::{
  2    assets::Assets,
  3    channel::ChannelList,
  4    fs::FakeFs,
  5    http::{HttpClient, Request, Response, ServerResponse},
  6    language::LanguageRegistry,
  7    rpc::{self, Client, Credentials, EstablishConnectionError},
  8    settings::{self, ThemeRegistry},
  9    user::UserStore,
 10    AppState,
 11};
 12use anyhow::{anyhow, Result};
 13use clock::ReplicaId;
 14use futures::{future::BoxFuture, Future};
 15use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
 16use parking_lot::Mutex;
 17use postage::{mpsc, prelude::Stream as _};
 18use smol::channel;
 19use std::{
 20    fmt,
 21    marker::PhantomData,
 22    path::{Path, PathBuf},
 23    sync::{
 24        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 25        Arc,
 26    },
 27};
 28use tempdir::TempDir;
 29use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
 30
 /// Installs `env_logger` as the global logger when the test binary is loaded.
///
/// This runs from a `#[ctor]`, i.e. before `main` and outside any test harness, so it must not
/// panic: `env_logger::init` panics when a global logger is already installed, and such a panic
/// cannot be handled there and would take down the whole test binary. `try_init` is used
/// instead, and an already-installed logger is simply kept.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let _ = env_logger::try_init();
}
 36
 37#[derive(Clone)]
 38struct Envelope<T: Clone> {
 39    message: T,
 40    sender: ReplicaId,
 41}
 42
 43#[cfg(test)]
 44pub(crate) struct Network<T: Clone, R: rand::Rng> {
 45    inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
 46    all_messages: Vec<T>,
 47    rng: R,
 48}
 49
 50#[cfg(test)]
 51impl<T: Clone, R: rand::Rng> Network<T, R> {
 52    pub fn new(rng: R) -> Self {
 53        Network {
 54            inboxes: Default::default(),
 55            all_messages: Vec::new(),
 56            rng,
 57        }
 58    }
 59
 60    pub fn add_peer(&mut self, id: ReplicaId) {
 61        self.inboxes.insert(id, Vec::new());
 62    }
 63
 64    pub fn is_idle(&self) -> bool {
 65        self.inboxes.values().all(|i| i.is_empty())
 66    }
 67
 68    pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
 69        for (replica, inbox) in self.inboxes.iter_mut() {
 70            if *replica != sender {
 71                for message in &messages {
 72                    let min_index = inbox
 73                        .iter()
 74                        .enumerate()
 75                        .rev()
 76                        .find_map(|(index, envelope)| {
 77                            if sender == envelope.sender {
 78                                Some(index + 1)
 79                            } else {
 80                                None
 81                            }
 82                        })
 83                        .unwrap_or(0);
 84
 85                    // Insert one or more duplicates of this message *after* the previous
 86                    // message delivered by this replica.
 87                    for _ in 0..self.rng.gen_range(1..4) {
 88                        let insertion_index = self.rng.gen_range(min_index..inbox.len() + 1);
 89                        inbox.insert(
 90                            insertion_index,
 91                            Envelope {
 92                                message: message.clone(),
 93                                sender,
 94                            },
 95                        );
 96                    }
 97                }
 98            }
 99        }
100        self.all_messages.extend(messages);
101    }
102
103    pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
104        !self.inboxes[&receiver].is_empty()
105    }
106
107    pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
108        let inbox = self.inboxes.get_mut(&receiver).unwrap();
109        let count = self.rng.gen_range(0..inbox.len() + 1);
110        inbox
111            .drain(0..count)
112            .map(|envelope| envelope.message)
113            .collect()
114    }
115}
116
/// Builds a `rows` x `cols` block of text for tests.
///
/// Row `i` is `cols` copies of the character whose code point is `'a' + i`, truncated to a
/// byte. The first 26 rows are therefore `a` through `z`, and later rows use the following
/// code points. Rows are separated by `'\n'` with no trailing newline, and `rows == 0` yields
/// an empty string.
pub fn sample_text(rows: usize, cols: usize) -> String {
    (0..rows)
        .map(|row| {
            let fill = ('a' as u32 + row as u32) as u8 as char;
            std::iter::repeat(fill).take(cols).collect::<String>()
        })
        .collect::<Vec<_>>()
        .join("\n")
}
129
130pub fn temp_tree(tree: serde_json::Value) -> TempDir {
131    let dir = TempDir::new("").unwrap();
132    write_tree(dir.path(), tree);
133    dir
134}
135
136fn write_tree(path: &Path, tree: serde_json::Value) {
137    use serde_json::Value;
138    use std::fs;
139
140    if let Value::Object(map) = tree {
141        for (name, contents) in map {
142            let mut path = PathBuf::from(path);
143            path.push(name);
144            match contents {
145                Value::Object(_) => {
146                    fs::create_dir(&path).unwrap();
147                    write_tree(&path, contents);
148                }
149                Value::Null => {
150                    fs::create_dir(&path).unwrap();
151                }
152                Value::String(contents) => {
153                    fs::write(&path, contents).unwrap();
154                }
155                _ => {
156                    panic!("JSON object must contain only objects, strings, or null");
157                }
158            }
159        }
160    } else {
161        panic!("You must pass a JSON object to this helper")
162    }
163}
164
165pub fn test_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
166    let (settings_tx, settings) = settings::test(cx);
167    let languages = Arc::new(LanguageRegistry::new());
168    let themes = ThemeRegistry::new(Assets, cx.font_cache().clone());
169    let rpc = rpc::Client::new();
170    let http = FakeHttpClient::new(|_| async move { Ok(ServerResponse::new(404)) });
171    let user_store = cx.add_model(|cx| UserStore::new(rpc.clone(), http, cx));
172    Arc::new(AppState {
173        settings_tx: Arc::new(Mutex::new(settings_tx)),
174        settings,
175        themes,
176        languages: languages.clone(),
177        channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)),
178        rpc,
179        user_store,
180        fs: Arc::new(FakeFs::new()),
181    })
182}
183
184pub struct Observer<T>(PhantomData<T>);
185
186impl<T: 'static> Entity for Observer<T> {
187    type Event = ();
188}
189
190impl<T: Entity> Observer<T> {
191    pub fn new(
192        handle: &ModelHandle<T>,
193        cx: &mut gpui::TestAppContext,
194    ) -> (ModelHandle<Self>, channel::Receiver<()>) {
195        let (notify_tx, notify_rx) = channel::unbounded();
196        let observer = cx.add_model(|cx| {
197            cx.observe(handle, move |_, _, _| {
198                let _ = notify_tx.try_send(());
199            })
200            .detach();
201            Observer(PhantomData)
202        });
203        (observer, notify_rx)
204    }
205}
206
207pub struct FakeServer {
208    peer: Arc<Peer>,
209    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
210    connection_id: Mutex<Option<ConnectionId>>,
211    forbid_connections: AtomicBool,
212    auth_count: AtomicUsize,
213    access_token: AtomicUsize,
214    user_id: u64,
215}
216
217impl FakeServer {
218    pub async fn for_client(
219        client_user_id: u64,
220        client: &mut Arc<Client>,
221        cx: &TestAppContext,
222    ) -> Arc<Self> {
223        let server = Arc::new(Self {
224            peer: Peer::new(),
225            incoming: Default::default(),
226            connection_id: Default::default(),
227            forbid_connections: Default::default(),
228            auth_count: Default::default(),
229            access_token: Default::default(),
230            user_id: client_user_id,
231        });
232
233        Arc::get_mut(client)
234            .unwrap()
235            .override_authenticate({
236                let server = server.clone();
237                move |cx| {
238                    server.auth_count.fetch_add(1, SeqCst);
239                    let access_token = server.access_token.load(SeqCst).to_string();
240                    cx.spawn(move |_| async move {
241                        Ok(Credentials {
242                            user_id: client_user_id,
243                            access_token,
244                        })
245                    })
246                }
247            })
248            .override_establish_connection({
249                let server = server.clone();
250                move |credentials, cx| {
251                    let credentials = credentials.clone();
252                    cx.spawn({
253                        let server = server.clone();
254                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
255                    })
256                }
257            });
258
259        client
260            .authenticate_and_connect(&cx.to_async())
261            .await
262            .unwrap();
263        server
264    }
265
266    pub async fn disconnect(&self) {
267        self.peer.disconnect(self.connection_id()).await;
268        self.connection_id.lock().take();
269        self.incoming.lock().take();
270    }
271
272    async fn establish_connection(
273        &self,
274        credentials: &Credentials,
275        cx: &AsyncAppContext,
276    ) -> Result<Connection, EstablishConnectionError> {
277        assert_eq!(credentials.user_id, self.user_id);
278
279        if self.forbid_connections.load(SeqCst) {
280            Err(EstablishConnectionError::Other(anyhow!(
281                "server is forbidding connections"
282            )))?
283        }
284
285        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
286            Err(EstablishConnectionError::Unauthorized)?
287        }
288
289        let (client_conn, server_conn, _) = Connection::in_memory();
290        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
291        cx.background().spawn(io).detach();
292        *self.incoming.lock() = Some(incoming);
293        *self.connection_id.lock() = Some(connection_id);
294        Ok(client_conn)
295    }
296
297    pub fn auth_count(&self) -> usize {
298        self.auth_count.load(SeqCst)
299    }
300
301    pub fn roll_access_token(&self) {
302        self.access_token.fetch_add(1, SeqCst);
303    }
304
305    pub fn forbid_connections(&self) {
306        self.forbid_connections.store(true, SeqCst);
307    }
308
309    pub fn allow_connections(&self) {
310        self.forbid_connections.store(false, SeqCst);
311    }
312
313    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
314        self.peer.send(self.connection_id(), message).await.unwrap();
315    }
316
317    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
318        let message = self
319            .incoming
320            .lock()
321            .as_mut()
322            .expect("not connected")
323            .recv()
324            .await
325            .ok_or_else(|| anyhow!("other half hung up"))?;
326        let type_name = message.payload_type_name();
327        Ok(*message
328            .into_any()
329            .downcast::<TypedEnvelope<M>>()
330            .unwrap_or_else(|_| {
331                panic!(
332                    "fake server received unexpected message type: {:?}",
333                    type_name
334                );
335            }))
336    }
337
338    pub async fn respond<T: proto::RequestMessage>(
339        &self,
340        receipt: Receipt<T>,
341        response: T::Response,
342    ) {
343        self.peer.respond(receipt, response).await.unwrap()
344    }
345
346    fn connection_id(&self) -> ConnectionId {
347        self.connection_id.lock().expect("not connected")
348    }
349}
350
351pub struct FakeHttpClient {
352    handler:
353        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
354}
355
356impl FakeHttpClient {
357    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
358    where
359        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
360        F: 'static + Send + Sync + Fn(Request) -> Fut,
361    {
362        Arc::new(Self {
363            handler: Box::new(move |req| Box::pin(handler(req))),
364        })
365    }
366}
367
368impl fmt::Debug for FakeHttpClient {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        f.debug_struct("FakeHttpClient").finish()
371    }
372}
373
374impl HttpClient for FakeHttpClient {
375    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
376        let future = (self.handler)(req);
377        Box::pin(async move { future.await.map(Into::into) })
378    }
379}