1use crate::{
2 db::{tests::TestDb, NewUserParams, UserId},
3 executor::Executor,
4 rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
5 AppState,
6};
7use anyhow::anyhow;
8use call::ActiveCall;
9use channel::{ChannelBuffer, ChannelStore};
10use client::{
11 self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
12};
13use collections::{HashMap, HashSet};
14use fs::FakeFs;
15use futures::{channel::oneshot, StreamExt as _};
16use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
17use language::LanguageRegistry;
18use parking_lot::Mutex;
19use project::{Project, WorktreeId};
20use rpc::RECEIVE_TIMEOUT;
21use settings::SettingsStore;
22use std::{
23 cell::{Ref, RefCell, RefMut},
24 env,
25 ops::{Deref, DerefMut},
26 path::Path,
27 sync::{
28 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
29 Arc,
30 },
31};
32use util::http::FakeHttpClient;
33use workspace::{Workspace, WorkspaceStore};
34
/// An in-process collab server used by integration tests.
///
/// Owns the server, its fake LiveKit server and the backing test database, plus the
/// knobs tests use to simulate network trouble (killing or forbidding connections).
pub struct TestServer {
    /// Server-side application state (database handle, LiveKit API client, config).
    pub app_state: Arc<AppState>,
    /// Fake LiveKit server whose API client is installed in `app_state`.
    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
    server: Arc<Server>,
    /// Per-peer "killed" flags returned by `Connection::in_memory`; storing `true` into one
    /// severs that client's in-memory connection (see `disconnect_client`).
    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
    /// When set, the overridden connection establishment in `create_client` refuses new
    /// connections with an error.
    forbid_connections: Arc<AtomicBool>,
    // Held (never read) so the test database stays alive as long as the server.
    _test_db: TestDb,
}
43
/// A client connected to a `TestServer`, together with the projects and buffers a test
/// has associated with it.
pub struct TestClient {
    /// The GitHub login this client was created with.
    pub username: String,
    /// Client-side application state (client, stores, language registry, fake fs).
    pub app_state: Arc<workspace::AppState>,
    /// Mutable bookkeeping, borrowed through the accessor methods on `TestClient`.
    state: RefCell<TestClientState>,
}
49
/// Projects and buffers a test has registered for a `TestClient`.
#[derive(Default)]
struct TestClientState {
    /// Projects registered by the test as locally hosted by this client.
    local_projects: Vec<ModelHandle<Project>>,
    /// Projects registered by the test as joined by this client from another host.
    remote_projects: Vec<ModelHandle<Project>>,
    /// Buffers the test has associated with each project.
    buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
    /// Channel buffers the test has associated with this client.
    channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
}
57
/// Snapshot of a client's contact state, expressed as GitHub logins.
///
/// Derives `Debug` and `PartialEq` so tests can compare whole summaries with
/// `assert_eq!` and get a readable diff on failure.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContactsSummary {
    /// Logins of accepted contacts.
    pub current: Vec<String>,
    /// Logins this client has sent contact requests to.
    pub outgoing_requests: Vec<String>,
    /// Logins that have sent contact requests to this client.
    pub incoming_requests: Vec<String>,
}
63
64impl TestServer {
65 pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
66 static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
67
68 let use_postgres = env::var("USE_POSTGRES").ok();
69 let use_postgres = use_postgres.as_deref();
70 let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
71 TestDb::postgres(deterministic.build_background())
72 } else {
73 TestDb::sqlite(deterministic.build_background())
74 };
75 let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
76 let live_kit_server = live_kit_client::TestServer::create(
77 format!("http://livekit.{}.test", live_kit_server_id),
78 format!("devkey-{}", live_kit_server_id),
79 format!("secret-{}", live_kit_server_id),
80 deterministic.build_background(),
81 )
82 .unwrap();
83 let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
84 let epoch = app_state
85 .db
86 .create_server(&app_state.config.zed_environment)
87 .await
88 .unwrap();
89 let server = Server::new(
90 epoch,
91 app_state.clone(),
92 Executor::Deterministic(deterministic.build_background()),
93 );
94 server.start().await.unwrap();
95 // Advance clock to ensure the server's cleanup task is finished.
96 deterministic.advance_clock(CLEANUP_TIMEOUT);
97 Self {
98 app_state,
99 server,
100 connection_killers: Default::default(),
101 forbid_connections: Default::default(),
102 _test_db: test_db,
103 test_live_kit_server: live_kit_server,
104 }
105 }
106
107 pub async fn reset(&self) {
108 self.app_state.db.reset();
109 let epoch = self
110 .app_state
111 .db
112 .create_server(&self.app_state.config.zed_environment)
113 .await
114 .unwrap();
115 self.server.reset(epoch);
116 }
117
118 pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
119 cx.update(|cx| {
120 if cx.has_global::<SettingsStore>() {
121 panic!("Same cx used to create two test clients")
122 }
123 cx.set_global(SettingsStore::test(cx));
124 });
125
126 let http = FakeHttpClient::with_404_response();
127 let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
128 {
129 user.id
130 } else {
131 self.app_state
132 .db
133 .create_user(
134 &format!("{name}@example.com"),
135 false,
136 NewUserParams {
137 github_login: name.into(),
138 github_user_id: 0,
139 invite_count: 0,
140 },
141 )
142 .await
143 .expect("creating user failed")
144 .user_id
145 };
146 let client_name = name.to_string();
147 let mut client = cx.read(|cx| Client::new(http.clone(), cx));
148 let server = self.server.clone();
149 let db = self.app_state.db.clone();
150 let connection_killers = self.connection_killers.clone();
151 let forbid_connections = self.forbid_connections.clone();
152
153 Arc::get_mut(&mut client)
154 .unwrap()
155 .set_id(user_id.to_proto())
156 .override_authenticate(move |cx| {
157 cx.spawn(|_| async move {
158 let access_token = "the-token".to_string();
159 Ok(Credentials {
160 user_id: user_id.to_proto(),
161 access_token,
162 })
163 })
164 })
165 .override_establish_connection(move |credentials, cx| {
166 assert_eq!(credentials.user_id, user_id.0 as u64);
167 assert_eq!(credentials.access_token, "the-token");
168
169 let server = server.clone();
170 let db = db.clone();
171 let connection_killers = connection_killers.clone();
172 let forbid_connections = forbid_connections.clone();
173 let client_name = client_name.clone();
174 cx.spawn(move |cx| async move {
175 if forbid_connections.load(SeqCst) {
176 Err(EstablishConnectionError::other(anyhow!(
177 "server is forbidding connections"
178 )))
179 } else {
180 let (client_conn, server_conn, killed) =
181 Connection::in_memory(cx.background());
182 let (connection_id_tx, connection_id_rx) = oneshot::channel();
183 let user = db
184 .get_user_by_id(user_id)
185 .await
186 .expect("retrieving user failed")
187 .unwrap();
188 cx.background()
189 .spawn(server.handle_connection(
190 server_conn,
191 client_name,
192 user,
193 Some(connection_id_tx),
194 Executor::Deterministic(cx.background()),
195 ))
196 .detach();
197 let connection_id = connection_id_rx.await.unwrap();
198 connection_killers
199 .lock()
200 .insert(connection_id.into(), killed);
201 Ok(client_conn)
202 }
203 })
204 });
205
206 let fs = FakeFs::new(cx.background());
207 let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
208 let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
209 let channel_store =
210 cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
211 let mut language_registry = LanguageRegistry::test();
212 language_registry.set_executor(cx.background());
213 let app_state = Arc::new(workspace::AppState {
214 client: client.clone(),
215 user_store: user_store.clone(),
216 workspace_store,
217 channel_store: channel_store.clone(),
218 languages: Arc::new(language_registry),
219 fs: fs.clone(),
220 build_window_options: |_, _, _| Default::default(),
221 initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
222 background_actions: || &[],
223 });
224
225 cx.update(|cx| {
226 theme::init((), cx);
227 Project::init(&client, cx);
228 client::init(&client, cx);
229 language::init(cx);
230 editor::init_settings(cx);
231 workspace::init(app_state.clone(), cx);
232 audio::init((), cx);
233 call::init(client.clone(), user_store.clone(), cx);
234 channel::init(&client);
235 });
236
237 client
238 .authenticate_and_connect(false, &cx.to_async())
239 .await
240 .unwrap();
241
242 let client = TestClient {
243 app_state,
244 username: name.to_string(),
245 state: Default::default(),
246 };
247 client.wait_for_current_user(cx).await;
248 client
249 }
250
251 pub fn disconnect_client(&self, peer_id: PeerId) {
252 self.connection_killers
253 .lock()
254 .remove(&peer_id)
255 .unwrap()
256 .store(true, SeqCst);
257 }
258
259 pub fn simulate_long_connection_interruption(
260 &self,
261 peer_id: PeerId,
262 deterministic: &Arc<Deterministic>,
263 ) {
264 self.forbid_connections();
265 self.disconnect_client(peer_id);
266 deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
267 self.allow_connections();
268 deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
269 deterministic.run_until_parked();
270 }
271
272 pub fn forbid_connections(&self) {
273 self.forbid_connections.store(true, SeqCst);
274 }
275
276 pub fn allow_connections(&self) {
277 self.forbid_connections.store(false, SeqCst);
278 }
279
280 pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
281 for ix in 1..clients.len() {
282 let (left, right) = clients.split_at_mut(ix);
283 let (client_a, cx_a) = left.last_mut().unwrap();
284 for (client_b, cx_b) in right {
285 client_a
286 .app_state
287 .user_store
288 .update(*cx_a, |store, cx| {
289 store.request_contact(client_b.user_id().unwrap(), cx)
290 })
291 .await
292 .unwrap();
293 cx_a.foreground().run_until_parked();
294 client_b
295 .app_state
296 .user_store
297 .update(*cx_b, |store, cx| {
298 store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
299 })
300 .await
301 .unwrap();
302 }
303 }
304 }
305
306 pub async fn make_channel(
307 &self,
308 channel: &str,
309 parent: Option<u64>,
310 admin: (&TestClient, &mut TestAppContext),
311 members: &mut [(&TestClient, &mut TestAppContext)],
312 ) -> u64 {
313 let (admin_client, admin_cx) = admin;
314 let channel_id = admin_client
315 .app_state
316 .channel_store
317 .update(admin_cx, |channel_store, cx| {
318 channel_store.create_channel(channel, parent, cx)
319 })
320 .await
321 .unwrap();
322
323 for (member_client, member_cx) in members {
324 admin_client
325 .app_state
326 .channel_store
327 .update(admin_cx, |channel_store, cx| {
328 channel_store.invite_member(
329 channel_id,
330 member_client.user_id().unwrap(),
331 false,
332 cx,
333 )
334 })
335 .await
336 .unwrap();
337
338 admin_cx.foreground().run_until_parked();
339
340 member_client
341 .app_state
342 .channel_store
343 .update(*member_cx, |channels, _| {
344 channels.respond_to_channel_invite(channel_id, true)
345 })
346 .await
347 .unwrap();
348 }
349
350 channel_id
351 }
352
353 pub async fn make_channel_tree(
354 &self,
355 channels: &[(&str, Option<&str>)],
356 creator: (&TestClient, &mut TestAppContext),
357 ) -> Vec<u64> {
358 let mut observed_channels = HashMap::default();
359 let mut result = Vec::new();
360 for (channel, parent) in channels {
361 let id;
362 if let Some(parent) = parent {
363 if let Some(parent_id) = observed_channels.get(parent) {
364 id = self
365 .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
366 .await;
367 } else {
368 panic!(
369 "Edge {}->{} referenced before {} was created",
370 parent, channel, parent
371 )
372 }
373 } else {
374 id = self
375 .make_channel(channel, None, (creator.0, creator.1), &mut [])
376 .await;
377 }
378
379 observed_channels.insert(channel, id);
380 result.push(id);
381 }
382
383 result
384 }
385
386 pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
387 self.make_contacts(clients).await;
388
389 let (left, right) = clients.split_at_mut(1);
390 let (_client_a, cx_a) = &mut left[0];
391 let active_call_a = cx_a.read(ActiveCall::global);
392
393 for (client_b, cx_b) in right {
394 let user_id_b = client_b.current_user_id(*cx_b).to_proto();
395 active_call_a
396 .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
397 .await
398 .unwrap();
399
400 cx_b.foreground().run_until_parked();
401 let active_call_b = cx_b.read(ActiveCall::global);
402 active_call_b
403 .update(*cx_b, |call, cx| call.accept_incoming(cx))
404 .await
405 .unwrap();
406 }
407 }
408
409 pub async fn build_app_state(
410 test_db: &TestDb,
411 fake_server: &live_kit_client::TestServer,
412 ) -> Arc<AppState> {
413 Arc::new(AppState {
414 db: test_db.db().clone(),
415 live_kit_client: Some(Arc::new(fake_server.create_api_client())),
416 config: Default::default(),
417 })
418 }
419}
420
421impl Deref for TestServer {
422 type Target = Server;
423
424 fn deref(&self) -> &Self::Target {
425 &self.server
426 }
427}
428
429impl Drop for TestServer {
430 fn drop(&mut self) {
431 self.server.teardown();
432 self.test_live_kit_server.teardown().unwrap();
433 }
434}
435
436impl Deref for TestClient {
437 type Target = Arc<Client>;
438
439 fn deref(&self) -> &Self::Target {
440 &self.app_state.client
441 }
442}
443
444impl TestClient {
445 pub fn fs(&self) -> &FakeFs {
446 self.app_state.fs.as_fake()
447 }
448
449 pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
450 &self.app_state.channel_store
451 }
452
453 pub fn user_store(&self) -> &ModelHandle<UserStore> {
454 &self.app_state.user_store
455 }
456
457 pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
458 &self.app_state.languages
459 }
460
461 pub fn client(&self) -> &Arc<Client> {
462 &self.app_state.client
463 }
464
465 pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
466 UserId::from_proto(
467 self.app_state
468 .user_store
469 .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
470 )
471 }
472
473 pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
474 let mut authed_user = self
475 .app_state
476 .user_store
477 .read_with(cx, |user_store, _| user_store.watch_current_user());
478 while authed_user.next().await.unwrap().is_none() {}
479 }
480
481 pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
482 self.app_state
483 .user_store
484 .update(cx, |store, _| store.clear_contacts())
485 .await;
486 }
487
488 pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
489 Ref::map(self.state.borrow(), |state| &state.local_projects)
490 }
491
492 pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
493 Ref::map(self.state.borrow(), |state| &state.remote_projects)
494 }
495
496 pub fn local_projects_mut<'a>(
497 &'a self,
498 ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
499 RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
500 }
501
502 pub fn remote_projects_mut<'a>(
503 &'a self,
504 ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
505 RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
506 }
507
508 pub fn buffers_for_project<'a>(
509 &'a self,
510 project: &ModelHandle<Project>,
511 ) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
512 RefMut::map(self.state.borrow_mut(), |state| {
513 state.buffers.entry(project.clone()).or_default()
514 })
515 }
516
517 pub fn buffers<'a>(
518 &'a self,
519 ) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
520 {
521 RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
522 }
523
524 pub fn channel_buffers<'a>(
525 &'a self,
526 ) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
527 RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
528 }
529
530 pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
531 self.app_state
532 .user_store
533 .read_with(cx, |store, _| ContactsSummary {
534 current: store
535 .contacts()
536 .iter()
537 .map(|contact| contact.user.github_login.clone())
538 .collect(),
539 outgoing_requests: store
540 .outgoing_contact_requests()
541 .iter()
542 .map(|user| user.github_login.clone())
543 .collect(),
544 incoming_requests: store
545 .incoming_contact_requests()
546 .iter()
547 .map(|user| user.github_login.clone())
548 .collect(),
549 })
550 }
551
552 pub async fn build_local_project(
553 &self,
554 root_path: impl AsRef<Path>,
555 cx: &mut TestAppContext,
556 ) -> (ModelHandle<Project>, WorktreeId) {
557 let project = self.build_empty_local_project(cx);
558 let (worktree, _) = project
559 .update(cx, |p, cx| {
560 p.find_or_create_local_worktree(root_path, true, cx)
561 })
562 .await
563 .unwrap();
564 worktree
565 .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
566 .await;
567 (project, worktree.read_with(cx, |tree, _| tree.id()))
568 }
569
570 pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
571 cx.update(|cx| {
572 Project::local(
573 self.client().clone(),
574 self.app_state.user_store.clone(),
575 self.app_state.languages.clone(),
576 self.app_state.fs.clone(),
577 cx,
578 )
579 })
580 }
581
582 pub async fn build_remote_project(
583 &self,
584 host_project_id: u64,
585 guest_cx: &mut TestAppContext,
586 ) -> ModelHandle<Project> {
587 let active_call = guest_cx.read(ActiveCall::global);
588 let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
589 room.update(guest_cx, |room, cx| {
590 room.join_project(
591 host_project_id,
592 self.app_state.languages.clone(),
593 self.app_state.fs.clone(),
594 cx,
595 )
596 })
597 .await
598 .unwrap()
599 }
600
601 pub fn build_workspace(
602 &self,
603 project: &ModelHandle<Project>,
604 cx: &mut TestAppContext,
605 ) -> WindowHandle<Workspace> {
606 cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
607 }
608
609 pub async fn add_admin_to_channel(
610 &self,
611 user: (&TestClient, &mut TestAppContext),
612 channel: u64,
613 cx_self: &mut TestAppContext,
614 ) {
615 let (other_client, other_cx) = user;
616
617 self.app_state
618 .channel_store
619 .update(cx_self, |channel_store, cx| {
620 channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
621 })
622 .await
623 .unwrap();
624
625 cx_self.foreground().run_until_parked();
626
627 other_client
628 .app_state
629 .channel_store
630 .update(other_cx, |channels, _| {
631 channels.respond_to_channel_invite(channel, true)
632 })
633 .await
634 .unwrap();
635 }
636}
637
638impl Drop for TestClient {
639 fn drop(&mut self) {
640 self.app_state.client.teardown();
641 }
642}