test_server.rs

  1use crate::{
  2    db::{tests::TestDb, NewUserParams, UserId},
  3    executor::Executor,
  4    rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  5    AppState,
  6};
  7use anyhow::anyhow;
  8use call::ActiveCall;
  9use channel::{ChannelBuffer, ChannelStore};
 10use client::{
 11    self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
 12};
 13use collections::{HashMap, HashSet};
 14use fs::FakeFs;
 15use futures::{channel::oneshot, StreamExt as _};
 16use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
 17use language::LanguageRegistry;
 18use node_runtime::FakeNodeRuntime;
 19use parking_lot::Mutex;
 20use project::{Project, WorktreeId};
 21use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT};
 22use settings::SettingsStore;
 23use std::{
 24    cell::{Ref, RefCell, RefMut},
 25    env,
 26    ops::{Deref, DerefMut},
 27    path::Path,
 28    sync::{
 29        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 30        Arc,
 31    },
 32};
 33use util::http::FakeHttpClient;
 34use workspace::{Workspace, WorkspaceStore};
 35
/// An in-process collaboration server for integration tests, bundled with the
/// fake LiveKit server and test database it runs against.
///
/// Dereferences to the wrapped [`Server`].
pub struct TestServer {
    /// Server-side application state (database handle, LiveKit client, config).
    pub app_state: Arc<AppState>,
    /// Fake LiveKit server whose API client is installed in `app_state`.
    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
    /// The RPC server that test clients connect to.
    server: Arc<Server>,
    /// Per-connection kill flags keyed by peer id; `disconnect_client` sets a
    /// flag to `true` to sever that connection.
    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
    /// While `true`, every new connection attempt is rejected.
    forbid_connections: Arc<AtomicBool>,
    /// Never read in this file; stored so the `TestDb` lives as long as the
    /// server does.
    _test_db: TestDb,
}
 44
/// A client created by `TestServer::create_client`, together with the
/// bookkeeping that tests accumulate while driving it.
///
/// Dereferences to the underlying `Arc<Client>`.
pub struct TestClient {
    /// The name the client was created with (also used as its GitHub login).
    pub username: String,
    /// Client-side application state (client, stores, languages, fake fs).
    pub app_state: Arc<workspace::AppState>,
    /// Handle to the global `ChannelStore` of the context the client was
    /// created in.
    channel_store: ModelHandle<ChannelStore>,
    /// Test bookkeeping, kept in a `RefCell` so `&self` accessors can hand out
    /// mutable views of it.
    state: RefCell<TestClientState>,
}
 51
/// Handles a test has registered for one client. Nothing in this file fills
/// these collections; they are exposed through the accessors on `TestClient`.
#[derive(Default)]
struct TestClientState {
    /// Projects exposed through `local_projects` / `local_projects_mut`.
    local_projects: Vec<ModelHandle<Project>>,
    /// Projects exposed through `remote_projects` / `remote_projects_mut`.
    remote_projects: Vec<ModelHandle<Project>>,
    /// Buffers grouped by the project they were registered under.
    buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
    /// Channel buffers exposed through `channel_buffers`.
    channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
}
 59
/// GitHub logins of a user's contacts, grouped by relationship state, as
/// produced by `TestClient::summarize_contacts`.
pub struct ContactsSummary {
    /// Logins of established contacts.
    pub current: Vec<String>,
    /// Logins of users this user has sent a contact request to.
    pub outgoing_requests: Vec<String>,
    /// Logins of users who have sent this user a contact request.
    pub incoming_requests: Vec<String>,
}
 65
 66impl TestServer {
 67    pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
 68        static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
 69
 70        let use_postgres = env::var("USE_POSTGRES").ok();
 71        let use_postgres = use_postgres.as_deref();
 72        let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
 73            TestDb::postgres(deterministic.build_background())
 74        } else {
 75            TestDb::sqlite(deterministic.build_background())
 76        };
 77        let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
 78        let live_kit_server = live_kit_client::TestServer::create(
 79            format!("http://livekit.{}.test", live_kit_server_id),
 80            format!("devkey-{}", live_kit_server_id),
 81            format!("secret-{}", live_kit_server_id),
 82            deterministic.build_background(),
 83        )
 84        .unwrap();
 85        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
 86        let epoch = app_state
 87            .db
 88            .create_server(&app_state.config.zed_environment)
 89            .await
 90            .unwrap();
 91        let server = Server::new(
 92            epoch,
 93            app_state.clone(),
 94            Executor::Deterministic(deterministic.build_background()),
 95        );
 96        server.start().await.unwrap();
 97        // Advance clock to ensure the server's cleanup task is finished.
 98        deterministic.advance_clock(CLEANUP_TIMEOUT);
 99        Self {
100            app_state,
101            server,
102            connection_killers: Default::default(),
103            forbid_connections: Default::default(),
104            _test_db: test_db,
105            test_live_kit_server: live_kit_server,
106        }
107    }
108
109    pub async fn reset(&self) {
110        self.app_state.db.reset();
111        let epoch = self
112            .app_state
113            .db
114            .create_server(&self.app_state.config.zed_environment)
115            .await
116            .unwrap();
117        self.server.reset(epoch);
118    }
119
120    pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
121        cx.update(|cx| {
122            if cx.has_global::<SettingsStore>() {
123                panic!("Same cx used to create two test clients")
124            }
125            cx.set_global(SettingsStore::test(cx));
126        });
127
128        let http = FakeHttpClient::with_404_response();
129        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
130        {
131            user.id
132        } else {
133            self.app_state
134                .db
135                .create_user(
136                    &format!("{name}@example.com"),
137                    false,
138                    NewUserParams {
139                        github_login: name.into(),
140                        github_user_id: 0,
141                        invite_count: 0,
142                    },
143                )
144                .await
145                .expect("creating user failed")
146                .user_id
147        };
148        let client_name = name.to_string();
149        let mut client = cx.read(|cx| Client::new(http.clone(), cx));
150        let server = self.server.clone();
151        let db = self.app_state.db.clone();
152        let connection_killers = self.connection_killers.clone();
153        let forbid_connections = self.forbid_connections.clone();
154
155        Arc::get_mut(&mut client)
156            .unwrap()
157            .set_id(user_id.to_proto())
158            .override_authenticate(move |cx| {
159                cx.spawn(|_| async move {
160                    let access_token = "the-token".to_string();
161                    Ok(Credentials {
162                        user_id: user_id.to_proto(),
163                        access_token,
164                    })
165                })
166            })
167            .override_establish_connection(move |credentials, cx| {
168                assert_eq!(credentials.user_id, user_id.0 as u64);
169                assert_eq!(credentials.access_token, "the-token");
170
171                let server = server.clone();
172                let db = db.clone();
173                let connection_killers = connection_killers.clone();
174                let forbid_connections = forbid_connections.clone();
175                let client_name = client_name.clone();
176                cx.spawn(move |cx| async move {
177                    if forbid_connections.load(SeqCst) {
178                        Err(EstablishConnectionError::other(anyhow!(
179                            "server is forbidding connections"
180                        )))
181                    } else {
182                        let (client_conn, server_conn, killed) =
183                            Connection::in_memory(cx.background());
184                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
185                        let user = db
186                            .get_user_by_id(user_id)
187                            .await
188                            .expect("retrieving user failed")
189                            .unwrap();
190                        cx.background()
191                            .spawn(server.handle_connection(
192                                server_conn,
193                                client_name,
194                                user,
195                                Some(connection_id_tx),
196                                Executor::Deterministic(cx.background()),
197                            ))
198                            .detach();
199                        let connection_id = connection_id_rx.await.unwrap();
200                        connection_killers
201                            .lock()
202                            .insert(connection_id.into(), killed);
203                        Ok(client_conn)
204                    }
205                })
206            });
207
208        let fs = FakeFs::new(cx.background());
209        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
210        let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
211        let mut language_registry = LanguageRegistry::test();
212        language_registry.set_executor(cx.background());
213        let app_state = Arc::new(workspace::AppState {
214            client: client.clone(),
215            user_store: user_store.clone(),
216            workspace_store,
217            languages: Arc::new(language_registry),
218            fs: fs.clone(),
219            build_window_options: |_, _, _| Default::default(),
220            initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
221            background_actions: || &[],
222            node_runtime: FakeNodeRuntime::new(),
223        });
224
225        cx.update(|cx| {
226            theme::init((), cx);
227            Project::init(&client, cx);
228            client::init(&client, cx);
229            language::init(cx);
230            editor::init_settings(cx);
231            workspace::init(app_state.clone(), cx);
232            audio::init((), cx);
233            call::init(client.clone(), user_store.clone(), cx);
234            channel::init(&client, user_store, cx);
235        });
236
237        client
238            .authenticate_and_connect(false, &cx.to_async())
239            .await
240            .unwrap();
241
242        let client = TestClient {
243            app_state,
244            username: name.to_string(),
245            channel_store: cx.read(ChannelStore::global).clone(),
246            state: Default::default(),
247        };
248        client.wait_for_current_user(cx).await;
249        client
250    }
251
252    pub fn disconnect_client(&self, peer_id: PeerId) {
253        self.connection_killers
254            .lock()
255            .remove(&peer_id)
256            .unwrap()
257            .store(true, SeqCst);
258    }
259
260    pub fn simulate_long_connection_interruption(
261        &self,
262        peer_id: PeerId,
263        deterministic: &Arc<Deterministic>,
264    ) {
265        self.forbid_connections();
266        self.disconnect_client(peer_id);
267        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
268        self.allow_connections();
269        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
270        deterministic.run_until_parked();
271    }
272
273    pub fn forbid_connections(&self) {
274        self.forbid_connections.store(true, SeqCst);
275    }
276
277    pub fn allow_connections(&self) {
278        self.forbid_connections.store(false, SeqCst);
279    }
280
281    pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
282        for ix in 1..clients.len() {
283            let (left, right) = clients.split_at_mut(ix);
284            let (client_a, cx_a) = left.last_mut().unwrap();
285            for (client_b, cx_b) in right {
286                client_a
287                    .app_state
288                    .user_store
289                    .update(*cx_a, |store, cx| {
290                        store.request_contact(client_b.user_id().unwrap(), cx)
291                    })
292                    .await
293                    .unwrap();
294                cx_a.foreground().run_until_parked();
295                client_b
296                    .app_state
297                    .user_store
298                    .update(*cx_b, |store, cx| {
299                        store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
300                    })
301                    .await
302                    .unwrap();
303            }
304        }
305    }
306
307    pub async fn make_channel(
308        &self,
309        channel: &str,
310        parent: Option<u64>,
311        admin: (&TestClient, &mut TestAppContext),
312        members: &mut [(&TestClient, &mut TestAppContext)],
313    ) -> u64 {
314        let (_, admin_cx) = admin;
315        let channel_id = admin_cx
316            .read(ChannelStore::global)
317            .update(admin_cx, |channel_store, cx| {
318                channel_store.create_channel(channel, parent, cx)
319            })
320            .await
321            .unwrap();
322
323        for (member_client, member_cx) in members {
324            admin_cx
325                .read(ChannelStore::global)
326                .update(admin_cx, |channel_store, cx| {
327                    channel_store.invite_member(
328                        channel_id,
329                        member_client.user_id().unwrap(),
330                        ChannelRole::Member,
331                        cx,
332                    )
333                })
334                .await
335                .unwrap();
336
337            admin_cx.foreground().run_until_parked();
338
339            member_cx
340                .read(ChannelStore::global)
341                .update(*member_cx, |channels, _| {
342                    channels.respond_to_channel_invite(channel_id, true)
343                })
344                .await
345                .unwrap();
346        }
347
348        channel_id
349    }
350
351    pub async fn make_channel_tree(
352        &self,
353        channels: &[(&str, Option<&str>)],
354        creator: (&TestClient, &mut TestAppContext),
355    ) -> Vec<u64> {
356        let mut observed_channels = HashMap::default();
357        let mut result = Vec::new();
358        for (channel, parent) in channels {
359            let id;
360            if let Some(parent) = parent {
361                if let Some(parent_id) = observed_channels.get(parent) {
362                    id = self
363                        .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
364                        .await;
365                } else {
366                    panic!(
367                        "Edge {}->{} referenced before {} was created",
368                        parent, channel, parent
369                    )
370                }
371            } else {
372                id = self
373                    .make_channel(channel, None, (creator.0, creator.1), &mut [])
374                    .await;
375            }
376
377            observed_channels.insert(channel, id);
378            result.push(id);
379        }
380
381        result
382    }
383
384    pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
385        self.make_contacts(clients).await;
386
387        let (left, right) = clients.split_at_mut(1);
388        let (_client_a, cx_a) = &mut left[0];
389        let active_call_a = cx_a.read(ActiveCall::global);
390
391        for (client_b, cx_b) in right {
392            let user_id_b = client_b.current_user_id(*cx_b).to_proto();
393            active_call_a
394                .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
395                .await
396                .unwrap();
397
398            cx_b.foreground().run_until_parked();
399            let active_call_b = cx_b.read(ActiveCall::global);
400            active_call_b
401                .update(*cx_b, |call, cx| call.accept_incoming(cx))
402                .await
403                .unwrap();
404        }
405    }
406
407    pub async fn build_app_state(
408        test_db: &TestDb,
409        fake_server: &live_kit_client::TestServer,
410    ) -> Arc<AppState> {
411        Arc::new(AppState {
412            db: test_db.db().clone(),
413            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
414            config: Default::default(),
415        })
416    }
417}
418
419impl Deref for TestServer {
420    type Target = Server;
421
422    fn deref(&self) -> &Self::Target {
423        &self.server
424    }
425}
426
427impl Drop for TestServer {
428    fn drop(&mut self) {
429        self.server.teardown();
430        self.test_live_kit_server.teardown().unwrap();
431    }
432}
433
434impl Deref for TestClient {
435    type Target = Arc<Client>;
436
437    fn deref(&self) -> &Self::Target {
438        &self.app_state.client
439    }
440}
441
442impl TestClient {
443    pub fn fs(&self) -> &FakeFs {
444        self.app_state.fs.as_fake()
445    }
446
447    pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
448        &self.channel_store
449    }
450
451    pub fn user_store(&self) -> &ModelHandle<UserStore> {
452        &self.app_state.user_store
453    }
454
455    pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
456        &self.app_state.languages
457    }
458
459    pub fn client(&self) -> &Arc<Client> {
460        &self.app_state.client
461    }
462
463    pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
464        UserId::from_proto(
465            self.app_state
466                .user_store
467                .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
468        )
469    }
470
471    pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
472        let mut authed_user = self
473            .app_state
474            .user_store
475            .read_with(cx, |user_store, _| user_store.watch_current_user());
476        while authed_user.next().await.unwrap().is_none() {}
477    }
478
479    pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
480        self.app_state
481            .user_store
482            .update(cx, |store, _| store.clear_contacts())
483            .await;
484    }
485
486    pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
487        Ref::map(self.state.borrow(), |state| &state.local_projects)
488    }
489
490    pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
491        Ref::map(self.state.borrow(), |state| &state.remote_projects)
492    }
493
494    pub fn local_projects_mut<'a>(
495        &'a self,
496    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
497        RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
498    }
499
500    pub fn remote_projects_mut<'a>(
501        &'a self,
502    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
503        RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
504    }
505
506    pub fn buffers_for_project<'a>(
507        &'a self,
508        project: &ModelHandle<Project>,
509    ) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
510        RefMut::map(self.state.borrow_mut(), |state| {
511            state.buffers.entry(project.clone()).or_default()
512        })
513    }
514
515    pub fn buffers<'a>(
516        &'a self,
517    ) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
518    {
519        RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
520    }
521
522    pub fn channel_buffers<'a>(
523        &'a self,
524    ) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
525        RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
526    }
527
528    pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
529        self.app_state
530            .user_store
531            .read_with(cx, |store, _| ContactsSummary {
532                current: store
533                    .contacts()
534                    .iter()
535                    .map(|contact| contact.user.github_login.clone())
536                    .collect(),
537                outgoing_requests: store
538                    .outgoing_contact_requests()
539                    .iter()
540                    .map(|user| user.github_login.clone())
541                    .collect(),
542                incoming_requests: store
543                    .incoming_contact_requests()
544                    .iter()
545                    .map(|user| user.github_login.clone())
546                    .collect(),
547            })
548    }
549
550    pub async fn build_local_project(
551        &self,
552        root_path: impl AsRef<Path>,
553        cx: &mut TestAppContext,
554    ) -> (ModelHandle<Project>, WorktreeId) {
555        let project = self.build_empty_local_project(cx);
556        let (worktree, _) = project
557            .update(cx, |p, cx| {
558                p.find_or_create_local_worktree(root_path, true, cx)
559            })
560            .await
561            .unwrap();
562        worktree
563            .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
564            .await;
565        (project, worktree.read_with(cx, |tree, _| tree.id()))
566    }
567
568    pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
569        cx.update(|cx| {
570            Project::local(
571                self.client().clone(),
572                self.app_state.node_runtime.clone(),
573                self.app_state.user_store.clone(),
574                self.app_state.languages.clone(),
575                self.app_state.fs.clone(),
576                cx,
577            )
578        })
579    }
580
581    pub async fn build_remote_project(
582        &self,
583        host_project_id: u64,
584        guest_cx: &mut TestAppContext,
585    ) -> ModelHandle<Project> {
586        let active_call = guest_cx.read(ActiveCall::global);
587        let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
588        room.update(guest_cx, |room, cx| {
589            room.join_project(
590                host_project_id,
591                self.app_state.languages.clone(),
592                self.app_state.fs.clone(),
593                cx,
594            )
595        })
596        .await
597        .unwrap()
598    }
599
600    pub fn build_workspace(
601        &self,
602        project: &ModelHandle<Project>,
603        cx: &mut TestAppContext,
604    ) -> WindowHandle<Workspace> {
605        cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
606    }
607
608    pub async fn add_admin_to_channel(
609        &self,
610        user: (&TestClient, &mut TestAppContext),
611        channel: u64,
612        cx_self: &mut TestAppContext,
613    ) {
614        let (other_client, other_cx) = user;
615
616        cx_self
617            .read(ChannelStore::global)
618            .update(cx_self, |channel_store, cx| {
619                channel_store.invite_member(
620                    channel,
621                    other_client.user_id().unwrap(),
622                    ChannelRole::Admin,
623                    cx,
624                )
625            })
626            .await
627            .unwrap();
628
629        cx_self.foreground().run_until_parked();
630
631        other_cx
632            .read(ChannelStore::global)
633            .update(other_cx, |channel_store, _| {
634                channel_store.respond_to_channel_invite(channel, true)
635            })
636            .await
637            .unwrap();
638    }
639}
640
641impl Drop for TestClient {
642    fn drop(&mut self) {
643        self.app_state.client.teardown();
644    }
645}