test_server.rs

  1use crate::{
  2    db::{tests::TestDb, NewUserParams, UserId},
  3    executor::Executor,
  4    rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  5    AppState,
  6};
  7use anyhow::anyhow;
  8use call::ActiveCall;
  9use channel::{ChannelBuffer, ChannelStore};
 10use client::{
 11    self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
 12};
 13use collections::{HashMap, HashSet};
 14use fs::FakeFs;
 15use futures::{channel::oneshot, StreamExt as _};
 16use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
 17use language::LanguageRegistry;
 18use node_runtime::FakeNodeRuntime;
 19use parking_lot::Mutex;
 20use project::{Project, WorktreeId};
 21use rpc::RECEIVE_TIMEOUT;
 22use settings::SettingsStore;
 23use std::{
 24    cell::{Ref, RefCell, RefMut},
 25    env,
 26    ops::{Deref, DerefMut},
 27    path::Path,
 28    sync::{
 29        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 30        Arc,
 31    },
 32};
 33use util::http::FakeHttpClient;
 34use workspace::{Workspace, WorkspaceStore};
 35
 36pub struct TestServer {
 37    pub app_state: Arc<AppState>,
 38    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
 39    server: Arc<Server>,
 40    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
 41    forbid_connections: Arc<AtomicBool>,
 42    _test_db: TestDb,
 43}
 44
 45pub struct TestClient {
 46    pub username: String,
 47    pub app_state: Arc<workspace::AppState>,
 48    channel_store: ModelHandle<ChannelStore>,
 49    state: RefCell<TestClientState>,
 50}
 51
 52#[derive(Default)]
 53struct TestClientState {
 54    local_projects: Vec<ModelHandle<Project>>,
 55    remote_projects: Vec<ModelHandle<Project>>,
 56    buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
 57    channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
 58}
 59
 60pub struct ContactsSummary {
 61    pub current: Vec<String>,
 62    pub outgoing_requests: Vec<String>,
 63    pub incoming_requests: Vec<String>,
 64}
 65
 66impl TestServer {
 67    pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
 68        static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
 69
 70        let use_postgres = env::var("USE_POSTGRES").ok();
 71        let use_postgres = use_postgres.as_deref();
 72        let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
 73            TestDb::postgres(deterministic.build_background())
 74        } else {
 75            TestDb::sqlite(deterministic.build_background())
 76        };
 77        let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
 78        let live_kit_server = live_kit_client::TestServer::create(
 79            format!("http://livekit.{}.test", live_kit_server_id),
 80            format!("devkey-{}", live_kit_server_id),
 81            format!("secret-{}", live_kit_server_id),
 82            deterministic.build_background(),
 83        )
 84        .unwrap();
 85        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
 86        let epoch = app_state
 87            .db
 88            .create_server(&app_state.config.zed_environment)
 89            .await
 90            .unwrap();
 91        let server = Server::new(
 92            epoch,
 93            app_state.clone(),
 94            Executor::Deterministic(deterministic.build_background()),
 95        );
 96        server.start().await.unwrap();
 97        // Advance clock to ensure the server's cleanup task is finished.
 98        deterministic.advance_clock(CLEANUP_TIMEOUT);
 99        Self {
100            app_state,
101            server,
102            connection_killers: Default::default(),
103            forbid_connections: Default::default(),
104            _test_db: test_db,
105            test_live_kit_server: live_kit_server,
106        }
107    }
108
109    pub async fn reset(&self) {
110        self.app_state.db.reset();
111        let epoch = self
112            .app_state
113            .db
114            .create_server(&self.app_state.config.zed_environment)
115            .await
116            .unwrap();
117        self.server.reset(epoch);
118    }
119
120    pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
121        cx.update(|cx| {
122            if cx.has_global::<SettingsStore>() {
123                panic!("Same cx used to create two test clients")
124            }
125            cx.set_global(SettingsStore::test(cx));
126        });
127
128        let http = FakeHttpClient::with_404_response();
129        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
130        {
131            user.id
132        } else {
133            self.app_state
134                .db
135                .create_user(
136                    &format!("{name}@example.com"),
137                    false,
138                    NewUserParams {
139                        github_login: name.into(),
140                        github_user_id: 0,
141                        invite_count: 0,
142                    },
143                )
144                .await
145                .expect("creating user failed")
146                .user_id
147        };
148        let client_name = name.to_string();
149        let mut client = cx.read(|cx| Client::new(http.clone(), cx));
150        let server = self.server.clone();
151        let db = self.app_state.db.clone();
152        let connection_killers = self.connection_killers.clone();
153        let forbid_connections = self.forbid_connections.clone();
154
155        Arc::get_mut(&mut client)
156            .unwrap()
157            .set_id(user_id.to_proto())
158            .override_authenticate(move |cx| {
159                cx.spawn(|_| async move {
160                    let access_token = "the-token".to_string();
161                    Ok(Credentials {
162                        user_id: user_id.to_proto(),
163                        access_token,
164                    })
165                })
166            })
167            .override_establish_connection(move |credentials, cx| {
168                assert_eq!(credentials.user_id, user_id.0 as u64);
169                assert_eq!(credentials.access_token, "the-token");
170
171                let server = server.clone();
172                let db = db.clone();
173                let connection_killers = connection_killers.clone();
174                let forbid_connections = forbid_connections.clone();
175                let client_name = client_name.clone();
176                cx.spawn(move |cx| async move {
177                    if forbid_connections.load(SeqCst) {
178                        Err(EstablishConnectionError::other(anyhow!(
179                            "server is forbidding connections"
180                        )))
181                    } else {
182                        let (client_conn, server_conn, killed) =
183                            Connection::in_memory(cx.background());
184                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
185                        let user = db
186                            .get_user_by_id(user_id)
187                            .await
188                            .expect("retrieving user failed")
189                            .unwrap();
190                        cx.background()
191                            .spawn(server.handle_connection(
192                                server_conn,
193                                client_name,
194                                user,
195                                Some(connection_id_tx),
196                                Executor::Deterministic(cx.background()),
197                            ))
198                            .detach();
199                        let connection_id = connection_id_rx.await.unwrap();
200                        connection_killers
201                            .lock()
202                            .insert(connection_id.into(), killed);
203                        Ok(client_conn)
204                    }
205                })
206            });
207
208        let fs = FakeFs::new(cx.background());
209        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
210        let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
211        let mut language_registry = LanguageRegistry::test();
212        language_registry.set_executor(cx.background());
213        let app_state = Arc::new(workspace::AppState {
214            client: client.clone(),
215            user_store: user_store.clone(),
216            workspace_store,
217            languages: Arc::new(language_registry),
218            fs: fs.clone(),
219            build_window_options: |_, _, _| Default::default(),
220            initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
221            background_actions: || &[],
222            node_runtime: FakeNodeRuntime::new(),
223        });
224
225        cx.update(|cx| {
226            theme::init((), cx);
227            Project::init(&client, cx);
228            client::init(&client, cx);
229            language::init(cx);
230            editor::init_settings(cx);
231            workspace::init(app_state.clone(), cx);
232            audio::init((), cx);
233            call::init(client.clone(), user_store.clone(), cx);
234            channel::init(&client, user_store.clone(), cx);
235            notifications::init(client.clone(), user_store, cx);
236        });
237
238        client
239            .authenticate_and_connect(false, &cx.to_async())
240            .await
241            .unwrap();
242
243        let client = TestClient {
244            app_state,
245            username: name.to_string(),
246            channel_store: cx.read(ChannelStore::global).clone(),
247            state: Default::default(),
248        };
249        client.wait_for_current_user(cx).await;
250        client
251    }
252
253    pub fn disconnect_client(&self, peer_id: PeerId) {
254        self.connection_killers
255            .lock()
256            .remove(&peer_id)
257            .unwrap()
258            .store(true, SeqCst);
259    }
260
261    pub fn simulate_long_connection_interruption(
262        &self,
263        peer_id: PeerId,
264        deterministic: &Arc<Deterministic>,
265    ) {
266        self.forbid_connections();
267        self.disconnect_client(peer_id);
268        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
269        self.allow_connections();
270        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
271        deterministic.run_until_parked();
272    }
273
274    pub fn forbid_connections(&self) {
275        self.forbid_connections.store(true, SeqCst);
276    }
277
278    pub fn allow_connections(&self) {
279        self.forbid_connections.store(false, SeqCst);
280    }
281
282    pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
283        for ix in 1..clients.len() {
284            let (left, right) = clients.split_at_mut(ix);
285            let (client_a, cx_a) = left.last_mut().unwrap();
286            for (client_b, cx_b) in right {
287                client_a
288                    .app_state
289                    .user_store
290                    .update(*cx_a, |store, cx| {
291                        store.request_contact(client_b.user_id().unwrap(), cx)
292                    })
293                    .await
294                    .unwrap();
295                cx_a.foreground().run_until_parked();
296                client_b
297                    .app_state
298                    .user_store
299                    .update(*cx_b, |store, cx| {
300                        store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
301                    })
302                    .await
303                    .unwrap();
304            }
305        }
306    }
307
308    pub async fn make_channel(
309        &self,
310        channel: &str,
311        parent: Option<u64>,
312        admin: (&TestClient, &mut TestAppContext),
313        members: &mut [(&TestClient, &mut TestAppContext)],
314    ) -> u64 {
315        let (_, admin_cx) = admin;
316        let channel_id = admin_cx
317            .read(ChannelStore::global)
318            .update(admin_cx, |channel_store, cx| {
319                channel_store.create_channel(channel, parent, cx)
320            })
321            .await
322            .unwrap();
323
324        for (member_client, member_cx) in members {
325            admin_cx
326                .read(ChannelStore::global)
327                .update(admin_cx, |channel_store, cx| {
328                    channel_store.invite_member(
329                        channel_id,
330                        member_client.user_id().unwrap(),
331                        false,
332                        cx,
333                    )
334                })
335                .await
336                .unwrap();
337
338            admin_cx.foreground().run_until_parked();
339
340            member_cx
341                .read(ChannelStore::global)
342                .update(*member_cx, |channels, _| {
343                    channels.respond_to_channel_invite(channel_id, true)
344                })
345                .await
346                .unwrap();
347        }
348
349        channel_id
350    }
351
352    pub async fn make_channel_tree(
353        &self,
354        channels: &[(&str, Option<&str>)],
355        creator: (&TestClient, &mut TestAppContext),
356    ) -> Vec<u64> {
357        let mut observed_channels = HashMap::default();
358        let mut result = Vec::new();
359        for (channel, parent) in channels {
360            let id;
361            if let Some(parent) = parent {
362                if let Some(parent_id) = observed_channels.get(parent) {
363                    id = self
364                        .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
365                        .await;
366                } else {
367                    panic!(
368                        "Edge {}->{} referenced before {} was created",
369                        parent, channel, parent
370                    )
371                }
372            } else {
373                id = self
374                    .make_channel(channel, None, (creator.0, creator.1), &mut [])
375                    .await;
376            }
377
378            observed_channels.insert(channel, id);
379            result.push(id);
380        }
381
382        result
383    }
384
385    pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
386        self.make_contacts(clients).await;
387
388        let (left, right) = clients.split_at_mut(1);
389        let (_client_a, cx_a) = &mut left[0];
390        let active_call_a = cx_a.read(ActiveCall::global);
391
392        for (client_b, cx_b) in right {
393            let user_id_b = client_b.current_user_id(*cx_b).to_proto();
394            active_call_a
395                .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
396                .await
397                .unwrap();
398
399            cx_b.foreground().run_until_parked();
400            let active_call_b = cx_b.read(ActiveCall::global);
401            active_call_b
402                .update(*cx_b, |call, cx| call.accept_incoming(cx))
403                .await
404                .unwrap();
405        }
406    }
407
408    pub async fn build_app_state(
409        test_db: &TestDb,
410        fake_server: &live_kit_client::TestServer,
411    ) -> Arc<AppState> {
412        Arc::new(AppState {
413            db: test_db.db().clone(),
414            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
415            config: Default::default(),
416        })
417    }
418}
419
420impl Deref for TestServer {
421    type Target = Server;
422
423    fn deref(&self) -> &Self::Target {
424        &self.server
425    }
426}
427
428impl Drop for TestServer {
429    fn drop(&mut self) {
430        self.server.teardown();
431        self.test_live_kit_server.teardown().unwrap();
432    }
433}
434
435impl Deref for TestClient {
436    type Target = Arc<Client>;
437
438    fn deref(&self) -> &Self::Target {
439        &self.app_state.client
440    }
441}
442
443impl TestClient {
444    pub fn fs(&self) -> &FakeFs {
445        self.app_state.fs.as_fake()
446    }
447
448    pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
449        &self.channel_store
450    }
451
452    pub fn user_store(&self) -> &ModelHandle<UserStore> {
453        &self.app_state.user_store
454    }
455
456    pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
457        &self.app_state.languages
458    }
459
460    pub fn client(&self) -> &Arc<Client> {
461        &self.app_state.client
462    }
463
464    pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
465        UserId::from_proto(
466            self.app_state
467                .user_store
468                .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
469        )
470    }
471
472    pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
473        let mut authed_user = self
474            .app_state
475            .user_store
476            .read_with(cx, |user_store, _| user_store.watch_current_user());
477        while authed_user.next().await.unwrap().is_none() {}
478    }
479
480    pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
481        self.app_state
482            .user_store
483            .update(cx, |store, _| store.clear_contacts())
484            .await;
485    }
486
487    pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
488        Ref::map(self.state.borrow(), |state| &state.local_projects)
489    }
490
491    pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
492        Ref::map(self.state.borrow(), |state| &state.remote_projects)
493    }
494
495    pub fn local_projects_mut<'a>(
496        &'a self,
497    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
498        RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
499    }
500
501    pub fn remote_projects_mut<'a>(
502        &'a self,
503    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
504        RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
505    }
506
507    pub fn buffers_for_project<'a>(
508        &'a self,
509        project: &ModelHandle<Project>,
510    ) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
511        RefMut::map(self.state.borrow_mut(), |state| {
512            state.buffers.entry(project.clone()).or_default()
513        })
514    }
515
516    pub fn buffers<'a>(
517        &'a self,
518    ) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
519    {
520        RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
521    }
522
523    pub fn channel_buffers<'a>(
524        &'a self,
525    ) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
526        RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
527    }
528
529    pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
530        self.app_state
531            .user_store
532            .read_with(cx, |store, _| ContactsSummary {
533                current: store
534                    .contacts()
535                    .iter()
536                    .map(|contact| contact.user.github_login.clone())
537                    .collect(),
538                outgoing_requests: store
539                    .outgoing_contact_requests()
540                    .iter()
541                    .map(|user| user.github_login.clone())
542                    .collect(),
543                incoming_requests: store
544                    .incoming_contact_requests()
545                    .iter()
546                    .map(|user| user.github_login.clone())
547                    .collect(),
548            })
549    }
550
551    pub async fn build_local_project(
552        &self,
553        root_path: impl AsRef<Path>,
554        cx: &mut TestAppContext,
555    ) -> (ModelHandle<Project>, WorktreeId) {
556        let project = self.build_empty_local_project(cx);
557        let (worktree, _) = project
558            .update(cx, |p, cx| {
559                p.find_or_create_local_worktree(root_path, true, cx)
560            })
561            .await
562            .unwrap();
563        worktree
564            .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
565            .await;
566        (project, worktree.read_with(cx, |tree, _| tree.id()))
567    }
568
569    pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
570        cx.update(|cx| {
571            Project::local(
572                self.client().clone(),
573                self.app_state.node_runtime.clone(),
574                self.app_state.user_store.clone(),
575                self.app_state.languages.clone(),
576                self.app_state.fs.clone(),
577                cx,
578            )
579        })
580    }
581
582    pub async fn build_remote_project(
583        &self,
584        host_project_id: u64,
585        guest_cx: &mut TestAppContext,
586    ) -> ModelHandle<Project> {
587        let active_call = guest_cx.read(ActiveCall::global);
588        let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
589        room.update(guest_cx, |room, cx| {
590            room.join_project(
591                host_project_id,
592                self.app_state.languages.clone(),
593                self.app_state.fs.clone(),
594                cx,
595            )
596        })
597        .await
598        .unwrap()
599    }
600
601    pub fn build_workspace(
602        &self,
603        project: &ModelHandle<Project>,
604        cx: &mut TestAppContext,
605    ) -> WindowHandle<Workspace> {
606        cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
607    }
608
609    pub async fn add_admin_to_channel(
610        &self,
611        user: (&TestClient, &mut TestAppContext),
612        channel: u64,
613        cx_self: &mut TestAppContext,
614    ) {
615        let (other_client, other_cx) = user;
616
617        cx_self
618            .read(ChannelStore::global)
619            .update(cx_self, |channel_store, cx| {
620                channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
621            })
622            .await
623            .unwrap();
624
625        cx_self.foreground().run_until_parked();
626
627        other_cx
628            .read(ChannelStore::global)
629            .update(other_cx, |channel_store, _| {
630                channel_store.respond_to_channel_invite(channel, true)
631            })
632            .await
633            .unwrap();
634    }
635}
636
637impl Drop for TestClient {
638    fn drop(&mut self) {
639        self.app_state.client.teardown();
640    }
641}