test_server.rs

  1use crate::{
  2    db::{tests::TestDb, NewUserParams, UserId},
  3    executor::Executor,
  4    rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  5    AppState,
  6};
  7use anyhow::anyhow;
  8use call::ActiveCall;
  9use channel::{ChannelBuffer, ChannelStore};
 10use client::{
 11    self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
 12};
 13use collections::{HashMap, HashSet};
 14use fs::FakeFs;
 15use futures::{channel::oneshot, StreamExt as _};
 16use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
 17use language::LanguageRegistry;
 18use parking_lot::Mutex;
 19use project::{Project, WorktreeId};
 20use rpc::RECEIVE_TIMEOUT;
 21use settings::SettingsStore;
 22use std::{
 23    cell::{Ref, RefCell, RefMut},
 24    env,
 25    ops::{Deref, DerefMut},
 26    path::Path,
 27    sync::{
 28        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 29        Arc,
 30    },
 31};
 32use util::http::FakeHttpClient;
 33use workspace::{Workspace, WorkspaceStore};
 34
 35pub struct TestServer {
 36    pub app_state: Arc<AppState>,
 37    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
 38    server: Arc<Server>,
 39    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
 40    forbid_connections: Arc<AtomicBool>,
 41    _test_db: TestDb,
 42}
 43
 44pub struct TestClient {
 45    pub username: String,
 46    pub app_state: Arc<workspace::AppState>,
 47    state: RefCell<TestClientState>,
 48}
 49
 50#[derive(Default)]
 51struct TestClientState {
 52    local_projects: Vec<ModelHandle<Project>>,
 53    remote_projects: Vec<ModelHandle<Project>>,
 54    buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
 55    channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
 56}
 57
 58pub struct ContactsSummary {
 59    pub current: Vec<String>,
 60    pub outgoing_requests: Vec<String>,
 61    pub incoming_requests: Vec<String>,
 62}
 63
 64impl TestServer {
 65    pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
 66        static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
 67
 68        let use_postgres = env::var("USE_POSTGRES").ok();
 69        let use_postgres = use_postgres.as_deref();
 70        let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
 71            TestDb::postgres(deterministic.build_background())
 72        } else {
 73            TestDb::sqlite(deterministic.build_background())
 74        };
 75        let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
 76        let live_kit_server = live_kit_client::TestServer::create(
 77            format!("http://livekit.{}.test", live_kit_server_id),
 78            format!("devkey-{}", live_kit_server_id),
 79            format!("secret-{}", live_kit_server_id),
 80            deterministic.build_background(),
 81        )
 82        .unwrap();
 83        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
 84        let epoch = app_state
 85            .db
 86            .create_server(&app_state.config.zed_environment)
 87            .await
 88            .unwrap();
 89        let server = Server::new(
 90            epoch,
 91            app_state.clone(),
 92            Executor::Deterministic(deterministic.build_background()),
 93        );
 94        server.start().await.unwrap();
 95        // Advance clock to ensure the server's cleanup task is finished.
 96        deterministic.advance_clock(CLEANUP_TIMEOUT);
 97        Self {
 98            app_state,
 99            server,
100            connection_killers: Default::default(),
101            forbid_connections: Default::default(),
102            _test_db: test_db,
103            test_live_kit_server: live_kit_server,
104        }
105    }
106
107    pub async fn reset(&self) {
108        self.app_state.db.reset();
109        let epoch = self
110            .app_state
111            .db
112            .create_server(&self.app_state.config.zed_environment)
113            .await
114            .unwrap();
115        self.server.reset(epoch);
116    }
117
118    pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
119        cx.update(|cx| {
120            if cx.has_global::<SettingsStore>() {
121                panic!("Same cx used to create two test clients")
122            }
123            cx.set_global(SettingsStore::test(cx));
124        });
125
126        let http = FakeHttpClient::with_404_response();
127        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
128        {
129            user.id
130        } else {
131            self.app_state
132                .db
133                .create_user(
134                    &format!("{name}@example.com"),
135                    false,
136                    NewUserParams {
137                        github_login: name.into(),
138                        github_user_id: 0,
139                        invite_count: 0,
140                    },
141                )
142                .await
143                .expect("creating user failed")
144                .user_id
145        };
146        let client_name = name.to_string();
147        let mut client = cx.read(|cx| Client::new(http.clone(), cx));
148        let server = self.server.clone();
149        let db = self.app_state.db.clone();
150        let connection_killers = self.connection_killers.clone();
151        let forbid_connections = self.forbid_connections.clone();
152
153        Arc::get_mut(&mut client)
154            .unwrap()
155            .set_id(user_id.to_proto())
156            .override_authenticate(move |cx| {
157                cx.spawn(|_| async move {
158                    let access_token = "the-token".to_string();
159                    Ok(Credentials {
160                        user_id: user_id.to_proto(),
161                        access_token,
162                    })
163                })
164            })
165            .override_establish_connection(move |credentials, cx| {
166                assert_eq!(credentials.user_id, user_id.0 as u64);
167                assert_eq!(credentials.access_token, "the-token");
168
169                let server = server.clone();
170                let db = db.clone();
171                let connection_killers = connection_killers.clone();
172                let forbid_connections = forbid_connections.clone();
173                let client_name = client_name.clone();
174                cx.spawn(move |cx| async move {
175                    if forbid_connections.load(SeqCst) {
176                        Err(EstablishConnectionError::other(anyhow!(
177                            "server is forbidding connections"
178                        )))
179                    } else {
180                        let (client_conn, server_conn, killed) =
181                            Connection::in_memory(cx.background());
182                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
183                        let user = db
184                            .get_user_by_id(user_id)
185                            .await
186                            .expect("retrieving user failed")
187                            .unwrap();
188                        cx.background()
189                            .spawn(server.handle_connection(
190                                server_conn,
191                                client_name,
192                                user,
193                                Some(connection_id_tx),
194                                Executor::Deterministic(cx.background()),
195                            ))
196                            .detach();
197                        let connection_id = connection_id_rx.await.unwrap();
198                        connection_killers
199                            .lock()
200                            .insert(connection_id.into(), killed);
201                        Ok(client_conn)
202                    }
203                })
204            });
205
206        let fs = FakeFs::new(cx.background());
207        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
208        let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
209        let channel_store =
210            cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
211        let mut language_registry = LanguageRegistry::test();
212        language_registry.set_executor(cx.background());
213        let app_state = Arc::new(workspace::AppState {
214            client: client.clone(),
215            user_store: user_store.clone(),
216            workspace_store,
217            channel_store: channel_store.clone(),
218            languages: Arc::new(language_registry),
219            fs: fs.clone(),
220            build_window_options: |_, _, _| Default::default(),
221            initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
222            background_actions: || &[],
223        });
224
225        cx.update(|cx| {
226            theme::init((), cx);
227            Project::init(&client, cx);
228            client::init(&client, cx);
229            language::init(cx);
230            editor::init_settings(cx);
231            workspace::init(app_state.clone(), cx);
232            audio::init((), cx);
233            call::init(client.clone(), user_store.clone(), cx);
234            channel::init(&client);
235        });
236
237        client
238            .authenticate_and_connect(false, &cx.to_async())
239            .await
240            .unwrap();
241
242        let client = TestClient {
243            app_state,
244            username: name.to_string(),
245            state: Default::default(),
246        };
247        client.wait_for_current_user(cx).await;
248        client
249    }
250
251    pub fn disconnect_client(&self, peer_id: PeerId) {
252        self.connection_killers
253            .lock()
254            .remove(&peer_id)
255            .unwrap()
256            .store(true, SeqCst);
257    }
258
259    pub fn simulate_long_connection_interruption(
260        &self,
261        peer_id: PeerId,
262        deterministic: &Arc<Deterministic>,
263    ) {
264        self.forbid_connections();
265        self.disconnect_client(peer_id);
266        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
267        self.allow_connections();
268        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
269        deterministic.run_until_parked();
270    }
271
272    pub fn forbid_connections(&self) {
273        self.forbid_connections.store(true, SeqCst);
274    }
275
276    pub fn allow_connections(&self) {
277        self.forbid_connections.store(false, SeqCst);
278    }
279
280    pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
281        for ix in 1..clients.len() {
282            let (left, right) = clients.split_at_mut(ix);
283            let (client_a, cx_a) = left.last_mut().unwrap();
284            for (client_b, cx_b) in right {
285                client_a
286                    .app_state
287                    .user_store
288                    .update(*cx_a, |store, cx| {
289                        store.request_contact(client_b.user_id().unwrap(), cx)
290                    })
291                    .await
292                    .unwrap();
293                cx_a.foreground().run_until_parked();
294                client_b
295                    .app_state
296                    .user_store
297                    .update(*cx_b, |store, cx| {
298                        store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
299                    })
300                    .await
301                    .unwrap();
302            }
303        }
304    }
305
306    pub async fn make_channel(
307        &self,
308        channel: &str,
309        parent: Option<u64>,
310        admin: (&TestClient, &mut TestAppContext),
311        members: &mut [(&TestClient, &mut TestAppContext)],
312    ) -> u64 {
313        let (admin_client, admin_cx) = admin;
314        let channel_id = admin_client
315            .app_state
316            .channel_store
317            .update(admin_cx, |channel_store, cx| {
318                channel_store.create_channel(channel, parent, cx)
319            })
320            .await
321            .unwrap();
322
323        for (member_client, member_cx) in members {
324            admin_client
325                .app_state
326                .channel_store
327                .update(admin_cx, |channel_store, cx| {
328                    channel_store.invite_member(
329                        channel_id,
330                        member_client.user_id().unwrap(),
331                        false,
332                        cx,
333                    )
334                })
335                .await
336                .unwrap();
337
338            admin_cx.foreground().run_until_parked();
339
340            member_client
341                .app_state
342                .channel_store
343                .update(*member_cx, |channels, _| {
344                    channels.respond_to_channel_invite(channel_id, true)
345                })
346                .await
347                .unwrap();
348        }
349
350        channel_id
351    }
352
353    pub async fn make_channel_tree(
354        &self,
355        channels: &[(&str, Option<&str>)],
356        creator: (&TestClient, &mut TestAppContext),
357    ) -> Vec<u64> {
358        let mut observed_channels = HashMap::default();
359        let mut result = Vec::new();
360        for (channel, parent) in channels {
361            let id;
362            if let Some(parent) = parent {
363                if let Some(parent_id) = observed_channels.get(parent) {
364                    id = self
365                        .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
366                        .await;
367                } else {
368                    panic!(
369                        "Edge {}->{} referenced before {} was created",
370                        parent, channel, parent
371                    )
372                }
373            } else {
374                id = self
375                    .make_channel(channel, None, (creator.0, creator.1), &mut [])
376                    .await;
377            }
378
379            observed_channels.insert(channel, id);
380            result.push(id);
381        }
382
383        result
384    }
385
386    pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
387        self.make_contacts(clients).await;
388
389        let (left, right) = clients.split_at_mut(1);
390        let (_client_a, cx_a) = &mut left[0];
391        let active_call_a = cx_a.read(ActiveCall::global);
392
393        for (client_b, cx_b) in right {
394            let user_id_b = client_b.current_user_id(*cx_b).to_proto();
395            active_call_a
396                .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
397                .await
398                .unwrap();
399
400            cx_b.foreground().run_until_parked();
401            let active_call_b = cx_b.read(ActiveCall::global);
402            active_call_b
403                .update(*cx_b, |call, cx| call.accept_incoming(cx))
404                .await
405                .unwrap();
406        }
407    }
408
409    pub async fn build_app_state(
410        test_db: &TestDb,
411        fake_server: &live_kit_client::TestServer,
412    ) -> Arc<AppState> {
413        Arc::new(AppState {
414            db: test_db.db().clone(),
415            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
416            config: Default::default(),
417        })
418    }
419}
420
421impl Deref for TestServer {
422    type Target = Server;
423
424    fn deref(&self) -> &Self::Target {
425        &self.server
426    }
427}
428
429impl Drop for TestServer {
430    fn drop(&mut self) {
431        self.server.teardown();
432        self.test_live_kit_server.teardown().unwrap();
433    }
434}
435
436impl Deref for TestClient {
437    type Target = Arc<Client>;
438
439    fn deref(&self) -> &Self::Target {
440        &self.app_state.client
441    }
442}
443
444impl TestClient {
445    pub fn fs(&self) -> &FakeFs {
446        self.app_state.fs.as_fake()
447    }
448
449    pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
450        &self.app_state.channel_store
451    }
452
453    pub fn user_store(&self) -> &ModelHandle<UserStore> {
454        &self.app_state.user_store
455    }
456
457    pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
458        &self.app_state.languages
459    }
460
461    pub fn client(&self) -> &Arc<Client> {
462        &self.app_state.client
463    }
464
465    pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
466        UserId::from_proto(
467            self.app_state
468                .user_store
469                .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
470        )
471    }
472
473    pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
474        let mut authed_user = self
475            .app_state
476            .user_store
477            .read_with(cx, |user_store, _| user_store.watch_current_user());
478        while authed_user.next().await.unwrap().is_none() {}
479    }
480
481    pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
482        self.app_state
483            .user_store
484            .update(cx, |store, _| store.clear_contacts())
485            .await;
486    }
487
488    pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
489        Ref::map(self.state.borrow(), |state| &state.local_projects)
490    }
491
492    pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
493        Ref::map(self.state.borrow(), |state| &state.remote_projects)
494    }
495
496    pub fn local_projects_mut<'a>(
497        &'a self,
498    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
499        RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
500    }
501
502    pub fn remote_projects_mut<'a>(
503        &'a self,
504    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
505        RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
506    }
507
508    pub fn buffers_for_project<'a>(
509        &'a self,
510        project: &ModelHandle<Project>,
511    ) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
512        RefMut::map(self.state.borrow_mut(), |state| {
513            state.buffers.entry(project.clone()).or_default()
514        })
515    }
516
517    pub fn buffers<'a>(
518        &'a self,
519    ) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
520    {
521        RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
522    }
523
524    pub fn channel_buffers<'a>(
525        &'a self,
526    ) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
527        RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
528    }
529
530    pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
531        self.app_state
532            .user_store
533            .read_with(cx, |store, _| ContactsSummary {
534                current: store
535                    .contacts()
536                    .iter()
537                    .map(|contact| contact.user.github_login.clone())
538                    .collect(),
539                outgoing_requests: store
540                    .outgoing_contact_requests()
541                    .iter()
542                    .map(|user| user.github_login.clone())
543                    .collect(),
544                incoming_requests: store
545                    .incoming_contact_requests()
546                    .iter()
547                    .map(|user| user.github_login.clone())
548                    .collect(),
549            })
550    }
551
552    pub async fn build_local_project(
553        &self,
554        root_path: impl AsRef<Path>,
555        cx: &mut TestAppContext,
556    ) -> (ModelHandle<Project>, WorktreeId) {
557        let project = self.build_empty_local_project(cx);
558        let (worktree, _) = project
559            .update(cx, |p, cx| {
560                p.find_or_create_local_worktree(root_path, true, cx)
561            })
562            .await
563            .unwrap();
564        worktree
565            .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
566            .await;
567        (project, worktree.read_with(cx, |tree, _| tree.id()))
568    }
569
570    pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
571        cx.update(|cx| {
572            Project::local(
573                self.client().clone(),
574                self.app_state.user_store.clone(),
575                self.app_state.languages.clone(),
576                self.app_state.fs.clone(),
577                cx,
578            )
579        })
580    }
581
582    pub async fn build_remote_project(
583        &self,
584        host_project_id: u64,
585        guest_cx: &mut TestAppContext,
586    ) -> ModelHandle<Project> {
587        let active_call = guest_cx.read(ActiveCall::global);
588        let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
589        room.update(guest_cx, |room, cx| {
590            room.join_project(
591                host_project_id,
592                self.app_state.languages.clone(),
593                self.app_state.fs.clone(),
594                cx,
595            )
596        })
597        .await
598        .unwrap()
599    }
600
601    pub fn build_workspace(
602        &self,
603        project: &ModelHandle<Project>,
604        cx: &mut TestAppContext,
605    ) -> WindowHandle<Workspace> {
606        cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
607    }
608
609    pub async fn add_admin_to_channel(
610        &self,
611        user: (&TestClient, &mut TestAppContext),
612        channel: u64,
613        cx_self: &mut TestAppContext,
614    ) {
615        let (other_client, other_cx) = user;
616
617        self.app_state
618            .channel_store
619            .update(cx_self, |channel_store, cx| {
620                channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
621            })
622            .await
623            .unwrap();
624
625        cx_self.foreground().run_until_parked();
626
627        other_client
628            .app_state
629            .channel_store
630            .update(other_cx, |channels, _| {
631                channels.respond_to_channel_invite(channel, true)
632            })
633            .await
634            .unwrap();
635    }
636}
637
638impl Drop for TestClient {
639    fn drop(&mut self) {
640        self.app_state.client.teardown();
641    }
642}