1use crate::{
2 assets::Assets,
3 channel::ChannelList,
4 fs::FakeFs,
5 http::{HttpClient, Request, Response, ServerResponse},
6 language::LanguageRegistry,
7 rpc::{self, Client, Credentials, EstablishConnectionError},
8 settings::{self, ThemeRegistry},
9 user::UserStore,
10 AppState,
11};
12use anyhow::{anyhow, Result};
13use futures::{future::BoxFuture, Future};
14use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
15use parking_lot::Mutex;
16use postage::{mpsc, prelude::Stream as _};
17use smol::channel;
18use std::{
19 fmt,
20 marker::PhantomData,
21 path::{Path, PathBuf},
22 sync::{
23 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
24 Arc,
25 },
26};
27use tempdir::TempDir;
28use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
29
/// Installs `env_logger` as the global logger for test builds of this crate.
///
/// The `#[ctor::ctor]` attribute (from the `ctor` crate) runs this function when the test
/// binary is loaded, so logging is configured before any test body executes and no test has
/// to set it up explicitly.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // NOTE(review): `env_logger::init` is believed to panic if a global logger is already
    // installed; if another constructor ever installs one, consider `env_logger::try_init`.
    env_logger::init();
}
35
/// Builds `rows` lines of sample text, each made of `cols` copies of one character.
///
/// Row 0 is filled with `'a'`, row 1 with `'b'`, and so on. The character code is truncated to
/// a byte (`as u8`), so rows beyond `'z'` continue through the following ASCII / Latin-1 code
/// points. Lines are separated by `'\n'`; there is no trailing newline, and `rows == 0` yields
/// an empty string.
pub fn sample_text(rows: usize, cols: usize) -> String {
    (0..rows)
        .map(|row| {
            let fill = ('a' as u32 + row as u32) as u8 as char;
            std::iter::repeat(fill).take(cols).collect::<String>()
        })
        .collect::<Vec<String>>()
        .join("\n")
}
48
49pub fn temp_tree(tree: serde_json::Value) -> TempDir {
50 let dir = TempDir::new("").unwrap();
51 write_tree(dir.path(), tree);
52 dir
53}
54
55fn write_tree(path: &Path, tree: serde_json::Value) {
56 use serde_json::Value;
57 use std::fs;
58
59 if let Value::Object(map) = tree {
60 for (name, contents) in map {
61 let mut path = PathBuf::from(path);
62 path.push(name);
63 match contents {
64 Value::Object(_) => {
65 fs::create_dir(&path).unwrap();
66 write_tree(&path, contents);
67 }
68 Value::Null => {
69 fs::create_dir(&path).unwrap();
70 }
71 Value::String(contents) => {
72 fs::write(&path, contents).unwrap();
73 }
74 _ => {
75 panic!("JSON object must contain only objects, strings, or null");
76 }
77 }
78 }
79 } else {
80 panic!("You must pass a JSON object to this helper")
81 }
82}
83
84pub fn test_app_state(cx: &mut MutableAppContext) -> Arc<AppState> {
85 let (settings_tx, settings) = settings::test(cx);
86 let languages = Arc::new(LanguageRegistry::new());
87 let themes = ThemeRegistry::new(Assets, cx.font_cache().clone());
88 let rpc = rpc::Client::new();
89 let http = FakeHttpClient::new(|_| async move { Ok(ServerResponse::new(404)) });
90 let user_store = cx.add_model(|cx| UserStore::new(rpc.clone(), http, cx));
91 Arc::new(AppState {
92 settings_tx: Arc::new(Mutex::new(settings_tx)),
93 settings,
94 themes,
95 languages: languages.clone(),
96 channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)),
97 rpc,
98 user_store,
99 fs: Arc::new(FakeFs::new()),
100 })
101}
102
103pub struct Observer<T>(PhantomData<T>);
104
105impl<T: 'static> Entity for Observer<T> {
106 type Event = ();
107}
108
109impl<T: Entity> Observer<T> {
110 pub fn new(
111 handle: &ModelHandle<T>,
112 cx: &mut gpui::TestAppContext,
113 ) -> (ModelHandle<Self>, channel::Receiver<()>) {
114 let (notify_tx, notify_rx) = channel::unbounded();
115 let observer = cx.add_model(|cx| {
116 cx.observe(handle, move |_, _, _| {
117 let _ = notify_tx.try_send(());
118 })
119 .detach();
120 Observer(PhantomData)
121 });
122 (observer, notify_rx)
123 }
124}
125
/// In-process fake RPC server used to drive a [`Client`] in tests.
///
/// `for_client` installs authentication and connection overrides on the client so that its
/// connections terminate in an in-memory [`Connection`] registered with this server's [`Peer`].
pub struct FakeServer {
    peer: Arc<Peer>,
    /// Messages arriving on the current connection; `None` until connected and after `disconnect`.
    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
    /// Id of the current server-side connection; `None` until connected and after `disconnect`.
    connection_id: Mutex<Option<ConnectionId>>,
    /// When `true`, `establish_connection` rejects every attempt.
    forbid_connections: AtomicBool,
    /// Number of times the client's authenticate override has been invoked.
    auth_count: AtomicUsize,
    /// The currently valid access token, compared against credentials in decimal string form.
    access_token: AtomicUsize,
    /// The only user id accepted by `establish_connection` (enforced with `assert_eq!`).
    user_id: u64,
}

impl FakeServer {
    /// Creates a fake server for `client_user_id`, routes `client`'s authentication and
    /// connection establishment through it, then authenticates and connects the client.
    ///
    /// # Panics
    ///
    /// Panics if `client` has other strong or weak references (`Arc::get_mut` fails), or if
    /// `authenticate_and_connect` returns an error.
    pub async fn for_client(
        client_user_id: u64,
        client: &mut Arc<Client>,
        cx: &TestAppContext,
    ) -> Arc<Self> {
        let server = Arc::new(Self {
            peer: Peer::new(),
            incoming: Default::default(),
            connection_id: Default::default(),
            forbid_connections: Default::default(),
            auth_count: Default::default(),
            access_token: Default::default(),
            user_id: client_user_id,
        });

        // `Arc::get_mut` only succeeds while `client` is uniquely referenced, so the overrides
        // must be installed before the client is shared.
        Arc::get_mut(client)
            .unwrap()
            .override_authenticate({
                let server = server.clone();
                move |cx| {
                    // Count each authentication attempt and hand out the server's current token.
                    server.auth_count.fetch_add(1, SeqCst);
                    let access_token = server.access_token.load(SeqCst).to_string();
                    cx.spawn(move |_| async move {
                        Ok(Credentials {
                            user_id: client_user_id,
                            access_token,
                        })
                    })
                }
            })
            .override_establish_connection({
                let server = server.clone();
                move |credentials, cx| {
                    let credentials = credentials.clone();
                    // Clone per invocation so the override closure keeps its own `server`
                    // handle while each spawned task gets another.
                    cx.spawn({
                        let server = server.clone();
                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
                    })
                }
            });

        client
            .authenticate_and_connect(&cx.to_async())
            .await
            .unwrap();
        server
    }

    /// Disconnects the current connection on the peer and clears the stored connection state.
    ///
    /// # Panics
    ///
    /// Panics with "not connected" if there is no current connection.
    pub async fn disconnect(&self) {
        self.peer.disconnect(self.connection_id()).await;
        self.connection_id.lock().take();
        self.incoming.lock().take();
    }

    /// Handles a connection attempt made through the client's establish-connection override.
    ///
    /// Fails with `EstablishConnectionError::Other` while connections are forbidden, and with
    /// `EstablishConnectionError::Unauthorized` if the credentials' token differs from the
    /// current one. Otherwise creates an in-memory connection pair, registers the server half
    /// with the peer, spawns its I/O future on the background executor, records the incoming
    /// stream and connection id, and returns the client half.
    ///
    /// # Panics
    ///
    /// Panics if `credentials.user_id` is not the user id this server was created for.
    async fn establish_connection(
        &self,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Result<Connection, EstablishConnectionError> {
        assert_eq!(credentials.user_id, self.user_id);

        // `Err(..)?` returns early with that error (same as `return Err(..)`).
        if self.forbid_connections.load(SeqCst) {
            Err(EstablishConnectionError::Other(anyhow!(
                "server is forbidding connections"
            )))?
        }

        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
            Err(EstablishConnectionError::Unauthorized)?
        }

        let (client_conn, server_conn, _) = Connection::in_memory();
        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
        cx.background().spawn(io).detach();
        *self.incoming.lock() = Some(incoming);
        *self.connection_id.lock() = Some(connection_id);
        Ok(client_conn)
    }

    /// Returns how many times the client's authenticate override has run.
    pub fn auth_count(&self) -> usize {
        self.auth_count.load(SeqCst)
    }

    /// Advances the access token, so credentials carrying the previous token are rejected as
    /// unauthorized by later connection attempts.
    pub fn roll_access_token(&self) {
        self.access_token.fetch_add(1, SeqCst);
    }

    /// Makes subsequent connection attempts fail until `allow_connections` is called.
    pub fn forbid_connections(&self) {
        self.forbid_connections.store(true, SeqCst);
    }

    /// Re-enables connection attempts after `forbid_connections`.
    pub fn allow_connections(&self) {
        self.forbid_connections.store(false, SeqCst);
    }

    /// Sends `message` to the client over the current connection.
    ///
    /// # Panics
    ///
    /// Panics if not connected or if the peer fails to send.
    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
        self.peer.send(self.connection_id(), message).await.unwrap();
    }

    /// Waits for the next incoming message and downcasts it to `TypedEnvelope<M>`.
    ///
    /// Returns an error ("other half hung up") if the incoming stream has ended.
    ///
    /// # Panics
    ///
    /// Panics with "not connected" if there is no current connection, or if the received
    /// message is not a `TypedEnvelope<M>` (the panic message names the actual payload type).
    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
        // NOTE(review): the guard returned by `lock()` is a temporary that lives until the end
        // of this `let` statement, i.e. across the `.await` below. Anything that locks
        // `incoming` while a `receive` is pending (`disconnect`, or a reconnect through
        // `establish_connection`) would block; on a single-threaded test executor that is a
        // deadlock. Confirm callers never overlap these, or restructure.
        let message = self
            .incoming
            .lock()
            .as_mut()
            .expect("not connected")
            .recv()
            .await
            .ok_or_else(|| anyhow!("other half hung up"))?;
        let type_name = message.payload_type_name();
        Ok(*message
            .into_any()
            .downcast::<TypedEnvelope<M>>()
            .unwrap_or_else(|_| {
                panic!(
                    "fake server received unexpected message type: {:?}",
                    type_name
                );
            }))
    }

    /// Sends `response` as the reply to the request identified by `receipt`.
    ///
    /// # Panics
    ///
    /// Panics if the peer fails to send the response.
    pub async fn respond<T: proto::RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) {
        self.peer.respond(receipt, response).await.unwrap()
    }

    /// Returns the id of the current connection, panicking with "not connected" if there is none.
    fn connection_id(&self) -> ConnectionId {
        self.connection_id.lock().expect("not connected")
    }
}
269
270pub struct FakeHttpClient {
271 handler:
272 Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
273}
274
275impl FakeHttpClient {
276 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
277 where
278 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
279 F: 'static + Send + Sync + Fn(Request) -> Fut,
280 {
281 Arc::new(Self {
282 handler: Box::new(move |req| Box::pin(handler(req))),
283 })
284 }
285}
286
287impl fmt::Debug for FakeHttpClient {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 f.debug_struct("FakeHttpClient").finish()
290 }
291}
292
293impl HttpClient for FakeHttpClient {
294 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
295 let future = (self.handler)(req);
296 Box::pin(async move { future.await.map(Into::into) })
297 }
298}