1use crate::{
2 assets::Assets,
3 channel::ChannelList,
4 fs::FakeFs,
5 http::{HttpClient, Request, Response, ServerResponse},
6 language::LanguageRegistry,
7 rpc::{self, Client, Credentials, EstablishConnectionError},
8 settings::{self, ThemeRegistry},
9 user::UserStore,
10 AppState,
11};
12use anyhow::{anyhow, Result};
13use clock::ReplicaId;
14use futures::{future::BoxFuture, Future};
15use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
16use parking_lot::Mutex;
17use postage::{mpsc, prelude::Stream as _};
18use smol::channel;
19use std::{
20 fmt,
21 marker::PhantomData,
22 path::{Path, PathBuf},
23 sync::{
24 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
25 Arc,
26 },
27};
28use tempdir::TempDir;
29use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
30
/// Installs `env_logger` as the global logger for the test binary.
///
/// Registered through `#[ctor::ctor]`, so it runs once when the binary is
/// loaded (presumably before any test body executes) rather than being called
/// from individual tests.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    env_logger::init()
}
36
/// A message queued in a [`Network`] inbox, tagged with the replica that sent it.
///
/// Only the test-only `Network` uses this type, so it is compiled under the
/// same `cfg(test)` gate to avoid a dead-code warning in non-test builds.
/// The `Clone` bound lives on the derived impl, not on the struct definition.
#[cfg(test)]
#[derive(Clone)]
struct Envelope<T> {
    message: T,
    sender: ReplicaId,
}
42
/// A simulated broadcast network for randomized tests.
///
/// Every registered replica owns an inbox. Broadcast messages are duplicated a
/// random number of times and interleaved randomly with other senders'
/// messages. Copies of a sender's message are never placed ahead of that same
/// sender's earlier messages in any inbox. All randomness is drawn from `rng`,
/// so a seeded RNG driven by the same sequence of calls yields the same
/// delivery schedule.
#[cfg(test)]
pub(crate) struct Network<T: Clone, R: rand::Rng> {
    /// Pending (not yet received) envelopes for each registered replica.
    inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
    /// Every message ever passed to `broadcast`, in broadcast order.
    all_messages: Vec<T>,
    /// Source of all duplication, placement, and receive-count decisions.
    rng: R,
}

#[cfg(test)]
impl<T: Clone, R: rand::Rng> Network<T, R> {
    /// Creates a network with no peers that draws all randomness from `rng`.
    pub fn new(rng: R) -> Self {
        Network {
            inboxes: Default::default(),
            all_messages: Vec::new(),
            rng,
        }
    }

    /// Registers replica `id` with an empty inbox.
    ///
    /// Re-adding an existing id replaces its inbox, discarding anything pending.
    pub fn add_peer(&mut self, id: ReplicaId) {
        self.inboxes.insert(id, Vec::new());
    }

    /// Returns `true` when no inbox has pending messages.
    pub fn is_idle(&self) -> bool {
        self.inboxes.values().all(|inbox| inbox.is_empty())
    }

    /// Queues `messages` from `sender` into every other replica's inbox.
    ///
    /// Each message is inserted 1 to 3 times at random positions. Every copy is
    /// placed after the last envelope from `sender` already in that inbox, so
    /// per-sender order is preserved.
    pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
        for (replica, inbox) in self.inboxes.iter_mut() {
            if *replica != sender {
                for message in &messages {
                    // First position after this sender's most recent envelope
                    // (or the front of the inbox if it has none queued).
                    let min_index = inbox
                        .iter()
                        .rposition(|envelope| envelope.sender == sender)
                        .map_or(0, |index| index + 1);

                    // Insert one or more duplicates of this message *after* the
                    // previous message delivered by this replica. The RNG call
                    // order here is relied on for reproducible seeded runs.
                    let copies = self.rng.gen_range(1..4);
                    for _ in 0..copies {
                        let insertion_index = self.rng.gen_range(min_index..inbox.len() + 1);
                        inbox.insert(
                            insertion_index,
                            Envelope {
                                message: message.clone(),
                                sender,
                            },
                        );
                    }
                }
            }
        }
        self.all_messages.extend(messages);
    }

    /// Returns `true` if `receiver` has pending messages.
    ///
    /// # Panics
    ///
    /// Panics if `receiver` was never registered with `add_peer`.
    pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
        !self.inboxes[&receiver].is_empty()
    }

    /// Removes and returns a random-length prefix of `receiver`'s inbox.
    ///
    /// The prefix may be empty or the entire inbox.
    ///
    /// # Panics
    ///
    /// Panics if `receiver` was never registered with `add_peer`.
    pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
        let inbox = self
            .inboxes
            .get_mut(&receiver)
            .expect("receiver was never registered with add_peer");
        let count = self.rng.gen_range(0..inbox.len() + 1);
        inbox
            .drain(0..count)
            .map(|envelope| envelope.message)
            .collect()
    }
}
116
/// Builds a `rows` x `cols` block of text for tests.
///
/// Row `i` consists of `cols` copies of the character whose code is
/// `'a' + i` truncated to a byte. Rows are separated by `'\n'` with no
/// trailing newline. Rows past the 26th continue beyond `'z'` (`'{'`, `'|'`, ...).
pub fn sample_text(rows: usize, cols: usize) -> String {
    (0..rows)
        .map(|row_index| {
            let fill = ('a' as u32 + row_index as u32) as u8 as char;
            std::iter::repeat(fill).take(cols).collect::<String>()
        })
        .collect::<Vec<String>>()
        .join("\n")
}
129
130pub fn temp_tree(tree: serde_json::Value) -> TempDir {
131 let dir = TempDir::new("").unwrap();
132 write_tree(dir.path(), tree);
133 dir
134}
135
136fn write_tree(path: &Path, tree: serde_json::Value) {
137 use serde_json::Value;
138 use std::fs;
139
140 if let Value::Object(map) = tree {
141 for (name, contents) in map {
142 let mut path = PathBuf::from(path);
143 path.push(name);
144 match contents {
145 Value::Object(_) => {
146 fs::create_dir(&path).unwrap();
147 write_tree(&path, contents);
148 }
149 Value::Null => {
150 fs::create_dir(&path).unwrap();
151 }
152 Value::String(contents) => {
153 fs::write(&path, contents).unwrap();
154 }
155 _ => {
156 panic!("JSON object must contain only objects, strings, or null");
157 }
158 }
159 }
160 } else {
161 panic!("You must pass a JSON object to this helper")
162 }
163}
164
165pub fn test_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
166 let (settings_tx, settings) = settings::test(cx);
167 let languages = Arc::new(LanguageRegistry::new());
168 let themes = ThemeRegistry::new(Assets, cx.font_cache().clone());
169 let rpc = rpc::Client::new();
170 let http = FakeHttpClient::new(|_| async move { Ok(ServerResponse::new(404)) });
171 let user_store = cx.add_model(|cx| UserStore::new(rpc.clone(), http, cx));
172 Arc::new(AppState {
173 settings_tx: Arc::new(Mutex::new(settings_tx)),
174 settings,
175 themes,
176 languages: languages.clone(),
177 channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)),
178 rpc,
179 user_store,
180 fs: Arc::new(FakeFs::new()),
181 })
182}
183
184pub struct Observer<T>(PhantomData<T>);
185
186impl<T: 'static> Entity for Observer<T> {
187 type Event = ();
188}
189
190impl<T: Entity> Observer<T> {
191 pub fn new(
192 handle: &ModelHandle<T>,
193 cx: &mut gpui::TestAppContext,
194 ) -> (ModelHandle<Self>, channel::Receiver<()>) {
195 let (notify_tx, notify_rx) = channel::unbounded();
196 let observer = cx.add_model(|cx| {
197 cx.observe(handle, move |_, _, _| {
198 let _ = notify_tx.try_send(());
199 })
200 .detach();
201 Observer(PhantomData)
202 });
203 (observer, notify_rx)
204 }
205}
206
207pub struct FakeServer {
208 peer: Arc<Peer>,
209 incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
210 connection_id: Mutex<Option<ConnectionId>>,
211 forbid_connections: AtomicBool,
212 auth_count: AtomicUsize,
213 access_token: AtomicUsize,
214 user_id: u64,
215}
216
217impl FakeServer {
218 pub async fn for_client(
219 client_user_id: u64,
220 client: &mut Arc<Client>,
221 cx: &TestAppContext,
222 ) -> Arc<Self> {
223 let server = Arc::new(Self {
224 peer: Peer::new(),
225 incoming: Default::default(),
226 connection_id: Default::default(),
227 forbid_connections: Default::default(),
228 auth_count: Default::default(),
229 access_token: Default::default(),
230 user_id: client_user_id,
231 });
232
233 Arc::get_mut(client)
234 .unwrap()
235 .override_authenticate({
236 let server = server.clone();
237 move |cx| {
238 server.auth_count.fetch_add(1, SeqCst);
239 let access_token = server.access_token.load(SeqCst).to_string();
240 cx.spawn(move |_| async move {
241 Ok(Credentials {
242 user_id: client_user_id,
243 access_token,
244 })
245 })
246 }
247 })
248 .override_establish_connection({
249 let server = server.clone();
250 move |credentials, cx| {
251 let credentials = credentials.clone();
252 cx.spawn({
253 let server = server.clone();
254 move |cx| async move { server.establish_connection(&credentials, &cx).await }
255 })
256 }
257 });
258
259 client
260 .authenticate_and_connect(&cx.to_async())
261 .await
262 .unwrap();
263 server
264 }
265
266 pub async fn disconnect(&self) {
267 self.peer.disconnect(self.connection_id()).await;
268 self.connection_id.lock().take();
269 self.incoming.lock().take();
270 }
271
272 async fn establish_connection(
273 &self,
274 credentials: &Credentials,
275 cx: &AsyncAppContext,
276 ) -> Result<Connection, EstablishConnectionError> {
277 assert_eq!(credentials.user_id, self.user_id);
278
279 if self.forbid_connections.load(SeqCst) {
280 Err(EstablishConnectionError::Other(anyhow!(
281 "server is forbidding connections"
282 )))?
283 }
284
285 if credentials.access_token != self.access_token.load(SeqCst).to_string() {
286 Err(EstablishConnectionError::Unauthorized)?
287 }
288
289 let (client_conn, server_conn, _) = Connection::in_memory();
290 let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
291 cx.background().spawn(io).detach();
292 *self.incoming.lock() = Some(incoming);
293 *self.connection_id.lock() = Some(connection_id);
294 Ok(client_conn)
295 }
296
297 pub fn auth_count(&self) -> usize {
298 self.auth_count.load(SeqCst)
299 }
300
301 pub fn roll_access_token(&self) {
302 self.access_token.fetch_add(1, SeqCst);
303 }
304
305 pub fn forbid_connections(&self) {
306 self.forbid_connections.store(true, SeqCst);
307 }
308
309 pub fn allow_connections(&self) {
310 self.forbid_connections.store(false, SeqCst);
311 }
312
313 pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
314 self.peer.send(self.connection_id(), message).await.unwrap();
315 }
316
317 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
318 let message = self
319 .incoming
320 .lock()
321 .as_mut()
322 .expect("not connected")
323 .recv()
324 .await
325 .ok_or_else(|| anyhow!("other half hung up"))?;
326 let type_name = message.payload_type_name();
327 Ok(*message
328 .into_any()
329 .downcast::<TypedEnvelope<M>>()
330 .unwrap_or_else(|_| {
331 panic!(
332 "fake server received unexpected message type: {:?}",
333 type_name
334 );
335 }))
336 }
337
338 pub async fn respond<T: proto::RequestMessage>(
339 &self,
340 receipt: Receipt<T>,
341 response: T::Response,
342 ) {
343 self.peer.respond(receipt, response).await.unwrap()
344 }
345
346 fn connection_id(&self) -> ConnectionId {
347 self.connection_id.lock().expect("not connected")
348 }
349}
350
351pub struct FakeHttpClient {
352 handler:
353 Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
354}
355
356impl FakeHttpClient {
357 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
358 where
359 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
360 F: 'static + Send + Sync + Fn(Request) -> Fut,
361 {
362 Arc::new(Self {
363 handler: Box::new(move |req| Box::pin(handler(req))),
364 })
365 }
366}
367
368impl fmt::Debug for FakeHttpClient {
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 f.debug_struct("FakeHttpClient").finish()
371 }
372}
373
374impl HttpClient for FakeHttpClient {
375 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
376 let future = (self.handler)(req);
377 Box::pin(async move { future.await.map(Into::into) })
378 }
379}