test.rs

  1use crate::{
  2    assets::Assets,
  3    channel::ChannelList,
  4    fs::FakeFs,
  5    http::{HttpClient, Request, Response, ServerResponse},
  6    language::LanguageRegistry,
  7    rpc::{self, Client, Credentials, EstablishConnectionError},
  8    settings::{self, ThemeRegistry},
  9    user::UserStore,
 10    AppState,
 11};
 12use anyhow::{anyhow, Result};
 13use futures::{future::BoxFuture, Future};
 14use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
 15use parking_lot::Mutex;
 16use postage::{mpsc, prelude::Stream as _};
 17use smol::channel;
 18use std::{
 19    fmt,
 20    marker::PhantomData,
 21    path::{Path, PathBuf},
 22    sync::{
 23        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 24        Arc,
 25    },
 26};
 27use tempdir::TempDir;
 28use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
 29
 /// Installs an `env_logger` global logger once, before any test in this binary runs.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // `env_logger::init` panics when a global logger is already installed, and a
    // panic raised from a `#[ctor]` function happens before `main`, where the test
    // harness cannot report it. `try_init` reports that case as an error instead,
    // which is deliberately ignored: whichever logger was installed first stays.
    let _ = env_logger::try_init();
}
 35
 /// Builds `rows` lines of text, each made of `cols` copies of one character.
///
/// Lines are separated by `'\n'` with no trailing newline, and zero rows yields an
/// empty string. Row `i` uses the character `'a' + i`, computed through a `u8` cast.
/// Rows 26 and above therefore continue past `'z'` into other single-byte characters.
pub fn sample_text(rows: usize, cols: usize) -> String {
    (0..rows)
        .map(|row_index| {
            let fill = ('a' as u32 + row_index as u32) as u8 as char;
            fill.to_string().repeat(cols)
        })
        .collect::<Vec<String>>()
        .join("\n")
}
 48
 49pub fn temp_tree(tree: serde_json::Value) -> TempDir {
 50    let dir = TempDir::new("").unwrap();
 51    write_tree(dir.path(), tree);
 52    dir
 53}
 54
 55fn write_tree(path: &Path, tree: serde_json::Value) {
 56    use serde_json::Value;
 57    use std::fs;
 58
 59    if let Value::Object(map) = tree {
 60        for (name, contents) in map {
 61            let mut path = PathBuf::from(path);
 62            path.push(name);
 63            match contents {
 64                Value::Object(_) => {
 65                    fs::create_dir(&path).unwrap();
 66                    write_tree(&path, contents);
 67                }
 68                Value::Null => {
 69                    fs::create_dir(&path).unwrap();
 70                }
 71                Value::String(contents) => {
 72                    fs::write(&path, contents).unwrap();
 73                }
 74                _ => {
 75                    panic!("JSON object must contain only objects, strings, or null");
 76                }
 77            }
 78        }
 79    } else {
 80        panic!("You must pass a JSON object to this helper")
 81    }
 82}
 83
 84pub fn test_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
 85    let (settings_tx, settings) = settings::test(cx);
 86    let languages = Arc::new(LanguageRegistry::new());
 87    let themes = ThemeRegistry::new(Assets, cx.font_cache().clone());
 88    let rpc = rpc::Client::new();
 89    let http = FakeHttpClient::new(|_| async move { Ok(ServerResponse::new(404)) });
 90    let user_store = cx.add_model(|cx| UserStore::new(rpc.clone(), http, cx));
 91    Arc::new(AppState {
 92        settings_tx: Arc::new(Mutex::new(settings_tx)),
 93        settings,
 94        themes,
 95        languages: languages.clone(),
 96        channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)),
 97        rpc,
 98        user_store,
 99        fs: Arc::new(FakeFs::new()),
100    })
101}
102
103pub struct Observer<T>(PhantomData<T>);
104
105impl<T: 'static> Entity for Observer<T> {
106    type Event = ();
107}
108
109impl<T: Entity> Observer<T> {
110    pub fn new(
111        handle: &ModelHandle<T>,
112        cx: &mut gpui::TestAppContext,
113    ) -> (ModelHandle<Self>, channel::Receiver<()>) {
114        let (notify_tx, notify_rx) = channel::unbounded();
115        let observer = cx.add_model(|cx| {
116            cx.observe(handle, move |_, _, _| {
117                let _ = notify_tx.try_send(());
118            })
119            .detach();
120            Observer(PhantomData)
121        });
122        (observer, notify_rx)
123    }
124}
125
126pub struct FakeServer {
127    peer: Arc<Peer>,
128    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
129    connection_id: Mutex<Option<ConnectionId>>,
130    forbid_connections: AtomicBool,
131    auth_count: AtomicUsize,
132    access_token: AtomicUsize,
133    user_id: u64,
134}
135
136impl FakeServer {
137    pub async fn for_client(
138        client_user_id: u64,
139        client: &mut Arc<Client>,
140        cx: &TestAppContext,
141    ) -> Arc<Self> {
142        let server = Arc::new(Self {
143            peer: Peer::new(),
144            incoming: Default::default(),
145            connection_id: Default::default(),
146            forbid_connections: Default::default(),
147            auth_count: Default::default(),
148            access_token: Default::default(),
149            user_id: client_user_id,
150        });
151
152        Arc::get_mut(client)
153            .unwrap()
154            .override_authenticate({
155                let server = server.clone();
156                move |cx| {
157                    server.auth_count.fetch_add(1, SeqCst);
158                    let access_token = server.access_token.load(SeqCst).to_string();
159                    cx.spawn(move |_| async move {
160                        Ok(Credentials {
161                            user_id: client_user_id,
162                            access_token,
163                        })
164                    })
165                }
166            })
167            .override_establish_connection({
168                let server = server.clone();
169                move |credentials, cx| {
170                    let credentials = credentials.clone();
171                    cx.spawn({
172                        let server = server.clone();
173                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
174                    })
175                }
176            });
177
178        client
179            .authenticate_and_connect(&cx.to_async())
180            .await
181            .unwrap();
182        server
183    }
184
185    pub async fn disconnect(&self) {
186        self.peer.disconnect(self.connection_id()).await;
187        self.connection_id.lock().take();
188        self.incoming.lock().take();
189    }
190
191    async fn establish_connection(
192        &self,
193        credentials: &Credentials,
194        cx: &AsyncAppContext,
195    ) -> Result<Connection, EstablishConnectionError> {
196        assert_eq!(credentials.user_id, self.user_id);
197
198        if self.forbid_connections.load(SeqCst) {
199            Err(EstablishConnectionError::Other(anyhow!(
200                "server is forbidding connections"
201            )))?
202        }
203
204        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
205            Err(EstablishConnectionError::Unauthorized)?
206        }
207
208        let (client_conn, server_conn, _) = Connection::in_memory();
209        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
210        cx.background().spawn(io).detach();
211        *self.incoming.lock() = Some(incoming);
212        *self.connection_id.lock() = Some(connection_id);
213        Ok(client_conn)
214    }
215
216    pub fn auth_count(&self) -> usize {
217        self.auth_count.load(SeqCst)
218    }
219
220    pub fn roll_access_token(&self) {
221        self.access_token.fetch_add(1, SeqCst);
222    }
223
224    pub fn forbid_connections(&self) {
225        self.forbid_connections.store(true, SeqCst);
226    }
227
228    pub fn allow_connections(&self) {
229        self.forbid_connections.store(false, SeqCst);
230    }
231
232    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
233        self.peer.send(self.connection_id(), message).await.unwrap();
234    }
235
236    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
237        let message = self
238            .incoming
239            .lock()
240            .as_mut()
241            .expect("not connected")
242            .recv()
243            .await
244            .ok_or_else(|| anyhow!("other half hung up"))?;
245        let type_name = message.payload_type_name();
246        Ok(*message
247            .into_any()
248            .downcast::<TypedEnvelope<M>>()
249            .unwrap_or_else(|_| {
250                panic!(
251                    "fake server received unexpected message type: {:?}",
252                    type_name
253                );
254            }))
255    }
256
257    pub async fn respond<T: proto::RequestMessage>(
258        &self,
259        receipt: Receipt<T>,
260        response: T::Response,
261    ) {
262        self.peer.respond(receipt, response).await.unwrap()
263    }
264
265    fn connection_id(&self) -> ConnectionId {
266        self.connection_id.lock().expect("not connected")
267    }
268}
269
270pub struct FakeHttpClient {
271    handler:
272        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
273}
274
275impl FakeHttpClient {
276    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
277    where
278        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
279        F: 'static + Send + Sync + Fn(Request) -> Fut,
280    {
281        Arc::new(Self {
282            handler: Box::new(move |req| Box::pin(handler(req))),
283        })
284    }
285}
286
287impl fmt::Debug for FakeHttpClient {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        f.debug_struct("FakeHttpClient").finish()
290    }
291}
292
293impl HttpClient for FakeHttpClient {
294    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
295        let future = (self.handler)(req);
296        Box::pin(async move { future.await.map(Into::into) })
297    }
298}