test_server.rs

  1use crate::{
  2    db::{tests::TestDb, NewUserParams, UserId},
  3    executor::Executor,
  4    rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  5    AppState,
  6};
  7use anyhow::anyhow;
  8use call::ActiveCall;
  9use channel::{ChannelBuffer, ChannelStore};
 10use client::{
 11    self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
 12};
 13use collections::{HashMap, HashSet};
 14use fs::FakeFs;
 15use futures::{channel::oneshot, StreamExt as _};
 16use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
 17use language::LanguageRegistry;
 18use parking_lot::Mutex;
 19use project::{Project, WorktreeId};
 20use rpc::RECEIVE_TIMEOUT;
 21use settings::SettingsStore;
 22use std::{
 23    cell::{Ref, RefCell, RefMut},
 24    env,
 25    ops::{Deref, DerefMut},
 26    path::Path,
 27    sync::{
 28        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 29        Arc,
 30    },
 31};
 32use util::http::FakeHttpClient;
 33use workspace::{Workspace, WorkspaceStore};
 34
/// In-process collab server for integration tests, bundled with the test
/// database and the fake LiveKit server it was started with.
pub struct TestServer {
    /// Server-side application state built by `build_app_state`.
    pub app_state: Arc<AppState>,
    /// Fake LiveKit server created for this instance in `start`.
    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
    /// The RPC server itself; also reachable through `Deref`.
    server: Arc<Server>,
    /// Kill flag for each live connection, keyed by peer id. `disconnect_client`
    /// sets the flag to `true`; presumably the in-memory connection observes it.
    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
    /// While `true`, connection attempts made by test clients are rejected.
    forbid_connections: Arc<AtomicBool>,
    /// Never read in this file; presumably held so the test database outlives
    /// the server.
    _test_db: TestDb,
}
 43
/// A test client created by `TestServer::create_client`, plus the handles
/// tests use to drive it.
pub struct TestClient {
    /// The name the client was created with (also used as its GitHub login).
    pub username: String,
    /// Client-side workspace state; the wrapped `Client` is exposed via `Deref`.
    pub app_state: Arc<workspace::AppState>,
    /// Handle to the global `ChannelStore`, captured when the client was created.
    channel_store: ModelHandle<ChannelStore>,
    /// Bookkeeping that accessor methods hand out through `&self`, hence the
    /// `RefCell`.
    state: RefCell<TestClientState>,
}
 50
/// Mutable per-client bookkeeping, exposed through the accessor methods on
/// `TestClient`. All collections start out empty.
#[derive(Default)]
struct TestClientState {
    /// Projects returned by `local_projects` / `local_projects_mut`.
    local_projects: Vec<ModelHandle<Project>>,
    /// Projects returned by `remote_projects` / `remote_projects_mut`.
    remote_projects: Vec<ModelHandle<Project>>,
    /// Buffers grouped by the project they belong to.
    buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
    /// Channel buffers tracked for this client.
    channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
}
 58
/// Snapshot of a client's contact lists as GitHub logins, produced by
/// `TestClient::summarize_contacts`.
pub struct ContactsSummary {
    /// Logins of the client's current contacts.
    pub current: Vec<String>,
    /// Logins of users this client has sent a contact request to.
    pub outgoing_requests: Vec<String>,
    /// Logins of users who have sent this client a contact request.
    pub incoming_requests: Vec<String>,
}
 64
 65impl TestServer {
 66    pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
 67        static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
 68
 69        let use_postgres = env::var("USE_POSTGRES").ok();
 70        let use_postgres = use_postgres.as_deref();
 71        let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
 72            TestDb::postgres(deterministic.build_background())
 73        } else {
 74            TestDb::sqlite(deterministic.build_background())
 75        };
 76        let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
 77        let live_kit_server = live_kit_client::TestServer::create(
 78            format!("http://livekit.{}.test", live_kit_server_id),
 79            format!("devkey-{}", live_kit_server_id),
 80            format!("secret-{}", live_kit_server_id),
 81            deterministic.build_background(),
 82        )
 83        .unwrap();
 84        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
 85        let epoch = app_state
 86            .db
 87            .create_server(&app_state.config.zed_environment)
 88            .await
 89            .unwrap();
 90        let server = Server::new(
 91            epoch,
 92            app_state.clone(),
 93            Executor::Deterministic(deterministic.build_background()),
 94        );
 95        server.start().await.unwrap();
 96        // Advance clock to ensure the server's cleanup task is finished.
 97        deterministic.advance_clock(CLEANUP_TIMEOUT);
 98        Self {
 99            app_state,
100            server,
101            connection_killers: Default::default(),
102            forbid_connections: Default::default(),
103            _test_db: test_db,
104            test_live_kit_server: live_kit_server,
105        }
106    }
107
108    pub async fn reset(&self) {
109        self.app_state.db.reset();
110        let epoch = self
111            .app_state
112            .db
113            .create_server(&self.app_state.config.zed_environment)
114            .await
115            .unwrap();
116        self.server.reset(epoch);
117    }
118
    /// Creates a test client named `name`, connects it to this server over an
    /// in-memory connection, and waits until its current user is available.
    ///
    /// The database user with GitHub login `name` is reused when it exists and
    /// created otherwise. Authentication and connection establishment on the
    /// client are overridden so that no real network or credential store is
    /// involved.
    ///
    /// # Panics
    ///
    /// Panics if `cx` already has a `SettingsStore` global (i.e. it was already
    /// used to create a client), if creating or loading the user fails, or if
    /// the client fails to authenticate and connect.
    pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
        // A settings store already being present means this context was set up
        // for another client; each client needs a context of its own.
        cx.update(|cx| {
            if cx.has_global::<SettingsStore>() {
                panic!("Same cx used to create two test clients")
            }
            cx.set_global(SettingsStore::test(cx));
        });

        let http = FakeHttpClient::with_404_response();
        // Reuse the user with this GitHub login if there is one; otherwise
        // create it.
        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
        {
            user.id
        } else {
            self.app_state
                .db
                .create_user(
                    &format!("{name}@example.com"),
                    false,
                    NewUserParams {
                        github_login: name.into(),
                        github_user_id: 0,
                        invite_count: 0,
                    },
                )
                .await
                .expect("creating user failed")
                .user_id
        };
        let client_name = name.to_string();
        let mut client = cx.read(|cx| Client::new(http.clone(), cx));
        // Handles that get moved into the connection-establishing closure below.
        let server = self.server.clone();
        let db = self.app_state.db.clone();
        let connection_killers = self.connection_killers.clone();
        let forbid_connections = self.forbid_connections.clone();

        // `get_mut` only succeeds while no other `Arc` to the client exists, so
        // the overrides are installed before the client is shared.
        Arc::get_mut(&mut client)
            .unwrap()
            .set_id(user_id.to_proto())
            // Authentication always yields the same fixed token for this user.
            .override_authenticate(move |cx| {
                cx.spawn(|_| async move {
                    let access_token = "the-token".to_string();
                    Ok(Credentials {
                        user_id: user_id.to_proto(),
                        access_token,
                    })
                })
            })
            // Connections are made in memory and handed straight to the server.
            .override_establish_connection(move |credentials, cx| {
                assert_eq!(credentials.user_id, user_id.0 as u64);
                assert_eq!(credentials.access_token, "the-token");

                let server = server.clone();
                let db = db.clone();
                let connection_killers = connection_killers.clone();
                let forbid_connections = forbid_connections.clone();
                let client_name = client_name.clone();
                cx.spawn(move |cx| async move {
                    if forbid_connections.load(SeqCst) {
                        Err(EstablishConnectionError::other(anyhow!(
                            "server is forbidding connections"
                        )))
                    } else {
                        let (client_conn, server_conn, killed) =
                            Connection::in_memory(cx.background());
                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
                        let user = db
                            .get_user_by_id(user_id)
                            .await
                            .expect("retrieving user failed")
                            .unwrap();
                        // The server side of the connection runs in the
                        // background and sends the connection id back through
                        // the oneshot channel.
                        cx.background()
                            .spawn(server.handle_connection(
                                server_conn,
                                client_name,
                                user,
                                Some(connection_id_tx),
                                Executor::Deterministic(cx.background()),
                            ))
                            .detach();
                        let connection_id = connection_id_rx.await.unwrap();
                        // Remember the kill flag so `disconnect_client` can
                        // find it by peer id later.
                        connection_killers
                            .lock()
                            .insert(connection_id.into(), killed);
                        Ok(client_conn)
                    }
                })
            });

        let fs = FakeFs::new(cx.background());
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
        let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
        let mut language_registry = LanguageRegistry::test();
        language_registry.set_executor(cx.background());
        let app_state = Arc::new(workspace::AppState {
            client: client.clone(),
            user_store: user_store.clone(),
            workspace_store,
            languages: Arc::new(language_registry),
            fs: fs.clone(),
            build_window_options: |_, _, _| Default::default(),
            initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
            background_actions: || &[],
        });

        // Initialize the subsystems the client-side code under test relies on.
        cx.update(|cx| {
            theme::init((), cx);
            Project::init(&client, cx);
            client::init(&client, cx);
            language::init(cx);
            editor::init_settings(cx);
            workspace::init(app_state.clone(), cx);
            audio::init((), cx);
            call::init(client.clone(), user_store.clone(), cx);
            channel::init(&client, user_store, cx);
        });

        client
            .authenticate_and_connect(false, &cx.to_async())
            .await
            .unwrap();

        let client = TestClient {
            app_state,
            username: name.to_string(),
            channel_store: cx.read(ChannelStore::global).clone(),
            state: Default::default(),
        };
        // Only hand the client back once its current user has been loaded.
        client.wait_for_current_user(cx).await;
        client
    }
249
250    pub fn disconnect_client(&self, peer_id: PeerId) {
251        self.connection_killers
252            .lock()
253            .remove(&peer_id)
254            .unwrap()
255            .store(true, SeqCst);
256    }
257
258    pub fn simulate_long_connection_interruption(
259        &self,
260        peer_id: PeerId,
261        deterministic: &Arc<Deterministic>,
262    ) {
263        self.forbid_connections();
264        self.disconnect_client(peer_id);
265        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
266        self.allow_connections();
267        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
268        deterministic.run_until_parked();
269    }
270
    /// Makes subsequent connection attempts by test clients fail until
    /// `allow_connections` is called.
    pub fn forbid_connections(&self) {
        self.forbid_connections.store(true, SeqCst);
    }
274
    /// Lets test clients establish connections again after
    /// `forbid_connections`.
    pub fn allow_connections(&self) {
        self.forbid_connections.store(false, SeqCst);
    }
278
    /// Makes every pair of the given clients contacts of each other.
    ///
    /// For each pair, the client appearing earlier in `clients` sends the
    /// contact request and the later one accepts it.
    ///
    /// # Panics
    ///
    /// Panics if a client has no user id or if a request or response fails.
    pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
        for ix in 1..clients.len() {
            // Splitting yields disjoint mutable borrows: the requester is the
            // last element of `left`, the responders are everything in `right`.
            let (left, right) = clients.split_at_mut(ix);
            let (client_a, cx_a) = left.last_mut().unwrap();
            for (client_b, cx_b) in right {
                client_a
                    .app_state
                    .user_store
                    .update(*cx_a, |store, cx| {
                        store.request_contact(client_b.user_id().unwrap(), cx)
                    })
                    .await
                    .unwrap();
                // Let pending work settle before the other side responds.
                cx_a.foreground().run_until_parked();
                client_b
                    .app_state
                    .user_store
                    .update(*cx_b, |store, cx| {
                        store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
                    })
                    .await
                    .unwrap();
            }
        }
    }
304
305    pub async fn make_channel(
306        &self,
307        channel: &str,
308        parent: Option<u64>,
309        admin: (&TestClient, &mut TestAppContext),
310        members: &mut [(&TestClient, &mut TestAppContext)],
311    ) -> u64 {
312        let (_, admin_cx) = admin;
313        let channel_id = admin_cx
314            .read(ChannelStore::global)
315            .update(admin_cx, |channel_store, cx| {
316                channel_store.create_channel(channel, parent, cx)
317            })
318            .await
319            .unwrap();
320
321        for (member_client, member_cx) in members {
322            admin_cx
323                .read(ChannelStore::global)
324                .update(admin_cx, |channel_store, cx| {
325                    channel_store.invite_member(
326                        channel_id,
327                        member_client.user_id().unwrap(),
328                        false,
329                        cx,
330                    )
331                })
332                .await
333                .unwrap();
334
335            admin_cx.foreground().run_until_parked();
336
337            member_cx
338                .read(ChannelStore::global)
339                .update(*member_cx, |channels, _| {
340                    channels.respond_to_channel_invite(channel_id, true)
341                })
342                .await
343                .unwrap();
344        }
345
346        channel_id
347    }
348
349    pub async fn make_channel_tree(
350        &self,
351        channels: &[(&str, Option<&str>)],
352        creator: (&TestClient, &mut TestAppContext),
353    ) -> Vec<u64> {
354        let mut observed_channels = HashMap::default();
355        let mut result = Vec::new();
356        for (channel, parent) in channels {
357            let id;
358            if let Some(parent) = parent {
359                if let Some(parent_id) = observed_channels.get(parent) {
360                    id = self
361                        .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
362                        .await;
363                } else {
364                    panic!(
365                        "Edge {}->{} referenced before {} was created",
366                        parent, channel, parent
367                    )
368                }
369            } else {
370                id = self
371                    .make_channel(channel, None, (creator.0, creator.1), &mut [])
372                    .await;
373            }
374
375            observed_channels.insert(channel, id);
376            result.push(id);
377        }
378
379        result
380    }
381
382    pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
383        self.make_contacts(clients).await;
384
385        let (left, right) = clients.split_at_mut(1);
386        let (_client_a, cx_a) = &mut left[0];
387        let active_call_a = cx_a.read(ActiveCall::global);
388
389        for (client_b, cx_b) in right {
390            let user_id_b = client_b.current_user_id(*cx_b).to_proto();
391            active_call_a
392                .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
393                .await
394                .unwrap();
395
396            cx_b.foreground().run_until_parked();
397            let active_call_b = cx_b.read(ActiveCall::global);
398            active_call_b
399                .update(*cx_b, |call, cx| call.accept_incoming(cx))
400                .await
401                .unwrap();
402        }
403    }
404
405    pub async fn build_app_state(
406        test_db: &TestDb,
407        fake_server: &live_kit_client::TestServer,
408    ) -> Arc<AppState> {
409        Arc::new(AppState {
410            db: test_db.db().clone(),
411            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
412            config: Default::default(),
413        })
414    }
415}
416
/// Lets a `TestServer` be used wherever a `&Server` is expected by
/// dereferencing to the wrapped RPC server.
impl Deref for TestServer {
    type Target = Server;

    fn deref(&self) -> &Self::Target {
        &self.server
    }
}
424
/// Tears down the RPC server first and the fake LiveKit server second when
/// the test server is dropped.
impl Drop for TestServer {
    fn drop(&mut self) {
        self.server.teardown();
        // Panics if the fake LiveKit server reports an error on teardown.
        self.test_live_kit_server.teardown().unwrap();
    }
}
431
/// Lets a `TestClient` be used like its underlying `Arc<Client>`.
impl Deref for TestClient {
    type Target = Arc<Client>;

    fn deref(&self) -> &Self::Target {
        &self.app_state.client
    }
}
439
440impl TestClient {
    /// Returns the client's file system as a `FakeFs`.
    // NOTE(review): `as_fake` presumably panics if the file system is not a
    // `FakeFs`; `create_client` always installs one.
    pub fn fs(&self) -> &FakeFs {
        self.app_state.fs.as_fake()
    }
444
    /// Returns the handle to the channel store stored on this client.
    pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
        &self.channel_store
    }
448
    /// Returns the handle to this client's user store.
    pub fn user_store(&self) -> &ModelHandle<UserStore> {
        &self.app_state.user_store
    }
452
    /// Returns this client's language registry.
    pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
        &self.app_state.languages
    }
456
    /// Returns the underlying RPC client (the same value `Deref` yields).
    pub fn client(&self) -> &Arc<Client> {
        &self.app_state.client
    }
460
461    pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
462        UserId::from_proto(
463            self.app_state
464                .user_store
465                .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
466        )
467    }
468
469    pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
470        let mut authed_user = self
471            .app_state
472            .user_store
473            .read_with(cx, |user_store, _| user_store.watch_current_user());
474        while authed_user.next().await.unwrap().is_none() {}
475    }
476
477    pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
478        self.app_state
479            .user_store
480            .update(cx, |store, _| store.clear_contacts())
481            .await;
482    }
483
484    pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
485        Ref::map(self.state.borrow(), |state| &state.local_projects)
486    }
487
488    pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
489        Ref::map(self.state.borrow(), |state| &state.remote_projects)
490    }
491
492    pub fn local_projects_mut<'a>(
493        &'a self,
494    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
495        RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
496    }
497
498    pub fn remote_projects_mut<'a>(
499        &'a self,
500    ) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
501        RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
502    }
503
504    pub fn buffers_for_project<'a>(
505        &'a self,
506        project: &ModelHandle<Project>,
507    ) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
508        RefMut::map(self.state.borrow_mut(), |state| {
509            state.buffers.entry(project.clone()).or_default()
510        })
511    }
512
513    pub fn buffers<'a>(
514        &'a self,
515    ) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
516    {
517        RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
518    }
519
520    pub fn channel_buffers<'a>(
521        &'a self,
522    ) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
523        RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
524    }
525
526    pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
527        self.app_state
528            .user_store
529            .read_with(cx, |store, _| ContactsSummary {
530                current: store
531                    .contacts()
532                    .iter()
533                    .map(|contact| contact.user.github_login.clone())
534                    .collect(),
535                outgoing_requests: store
536                    .outgoing_contact_requests()
537                    .iter()
538                    .map(|user| user.github_login.clone())
539                    .collect(),
540                incoming_requests: store
541                    .incoming_contact_requests()
542                    .iter()
543                    .map(|user| user.github_login.clone())
544                    .collect(),
545            })
546    }
547
548    pub async fn build_local_project(
549        &self,
550        root_path: impl AsRef<Path>,
551        cx: &mut TestAppContext,
552    ) -> (ModelHandle<Project>, WorktreeId) {
553        let project = self.build_empty_local_project(cx);
554        let (worktree, _) = project
555            .update(cx, |p, cx| {
556                p.find_or_create_local_worktree(root_path, true, cx)
557            })
558            .await
559            .unwrap();
560        worktree
561            .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
562            .await;
563        (project, worktree.read_with(cx, |tree, _| tree.id()))
564    }
565
566    pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
567        cx.update(|cx| {
568            Project::local(
569                self.client().clone(),
570                self.app_state.user_store.clone(),
571                self.app_state.languages.clone(),
572                self.app_state.fs.clone(),
573                cx,
574            )
575        })
576    }
577
578    pub async fn build_remote_project(
579        &self,
580        host_project_id: u64,
581        guest_cx: &mut TestAppContext,
582    ) -> ModelHandle<Project> {
583        let active_call = guest_cx.read(ActiveCall::global);
584        let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
585        room.update(guest_cx, |room, cx| {
586            room.join_project(
587                host_project_id,
588                self.app_state.languages.clone(),
589                self.app_state.fs.clone(),
590                cx,
591            )
592        })
593        .await
594        .unwrap()
595    }
596
597    pub fn build_workspace(
598        &self,
599        project: &ModelHandle<Project>,
600        cx: &mut TestAppContext,
601    ) -> WindowHandle<Workspace> {
602        cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
603    }
604
605    pub async fn add_admin_to_channel(
606        &self,
607        user: (&TestClient, &mut TestAppContext),
608        channel: u64,
609        cx_self: &mut TestAppContext,
610    ) {
611        let (other_client, other_cx) = user;
612
613        cx_self
614            .read(ChannelStore::global)
615            .update(cx_self, |channel_store, cx| {
616                channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
617            })
618            .await
619            .unwrap();
620
621        cx_self.foreground().run_until_parked();
622
623        other_cx
624            .read(ChannelStore::global)
625            .update(other_cx, |channel_store, _| {
626                channel_store.respond_to_channel_invite(channel, true)
627            })
628            .await
629            .unwrap();
630    }
631}
632
/// Tears down the underlying RPC client when the test client is dropped.
impl Drop for TestClient {
    fn drop(&mut self) {
        self.app_state.client.teardown();
    }
}