test_server.rs

  1use crate::{
  2    db::{tests::TestDb, NewUserParams, UserId},
  3    executor::Executor,
  4    rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  5    AppState,
  6};
  7use anyhow::anyhow;
  8use call::ActiveCall;
  9use channel::{ChannelBuffer, ChannelStore};
 10use client::{
 11    self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
 12};
 13use collections::{HashMap, HashSet};
 14use fs::FakeFs;
 15use futures::{channel::oneshot, StreamExt as _};
 16use gpui::{BackgroundExecutor, Context, Model, TestAppContext, WindowHandle};
 17use language::LanguageRegistry;
 18use node_runtime::FakeNodeRuntime;
 19
 20use parking_lot::Mutex;
 21use project::{Project, WorktreeId};
 22use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT};
 23use settings::SettingsStore;
 24use std::{
 25    cell::{Ref, RefCell, RefMut},
 26    env,
 27    ops::{Deref, DerefMut},
 28    path::Path,
 29    sync::{
 30        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 31        Arc,
 32    },
 33};
 34use util::http::FakeHttpClient;
 35use workspace::{Workspace, WorkspaceStore};
 36
/// In-process collab server harness for integration tests.
///
/// Bundles the collab [`Server`], a fake LiveKit server, and the test database,
/// and lets tests sever individual client connections or refuse new ones.
pub struct TestServer {
    /// Server-side state (database handle, LiveKit API client, config).
    pub app_state: Arc<AppState>,
    /// Fake LiveKit server created for this test server in `start`.
    pub test_live_kit_server: Arc<live_kit_client::TestServer>,
    server: Arc<Server>,
    /// Kill flags returned by `Connection::in_memory`, keyed by peer;
    /// `disconnect_client` removes one and sets it to `true`.
    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
    /// While `true`, new connection attempts fail with "server is forbidding connections".
    forbid_connections: Arc<AtomicBool>,
    /// Never read directly; presumably held so the test database lives as long as the server.
    _test_db: TestDb,
}
 45
/// A client created by [`TestServer::create_client`], plus per-client bookkeeping used by tests.
pub struct TestClient {
    /// The GitHub login this client was created with.
    pub username: String,
    /// Client-side application state (client, stores, languages, fake fs, node runtime).
    pub app_state: Arc<workspace::AppState>,
    /// The `ChannelStore` global read from the client's app context at creation time.
    channel_store: Model<ChannelStore>,
    // todo!(notifications)
    // notification_store: Model<NotificationStore>,
    /// Projects and buffers recorded by tests; borrowed through the accessor methods.
    state: RefCell<TestClientState>,
}
 54
/// Mutable bookkeeping for a [`TestClient`]: the projects and buffers a test has
/// opened, exposed through `TestClient`'s `*_projects*` / `buffers*` accessors.
#[derive(Default)]
struct TestClientState {
    local_projects: Vec<Model<Project>>,
    remote_projects: Vec<Model<Project>>,
    /// Open buffers grouped by the project they belong to.
    buffers: HashMap<Model<Project>, HashSet<Model<language::Buffer>>>,
    channel_buffers: HashSet<Model<ChannelBuffer>>,
}
 62
/// Snapshot of a user's contacts and pending contact requests, each listed by
/// GitHub login (built by `TestClient::summarize_contacts`).
pub struct ContactsSummary {
    /// Accepted contacts.
    pub current: Vec<String>,
    /// Requests this user has sent that are still pending.
    pub outgoing_requests: Vec<String>,
    /// Requests other users have sent to this user that are still pending.
    pub incoming_requests: Vec<String>,
}
 68
 69impl TestServer {
 70    pub async fn start(deterministic: BackgroundExecutor) -> Self {
 71        static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
 72
 73        let use_postgres = env::var("USE_POSTGRES").ok();
 74        let use_postgres = use_postgres.as_deref();
 75        let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
 76            TestDb::postgres(deterministic.clone())
 77        } else {
 78            TestDb::sqlite(deterministic.clone())
 79        };
 80        let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
 81        let live_kit_server = live_kit_client::TestServer::create(
 82            format!("http://livekit.{}.test", live_kit_server_id),
 83            format!("devkey-{}", live_kit_server_id),
 84            format!("secret-{}", live_kit_server_id),
 85            deterministic.clone(),
 86        )
 87        .unwrap();
 88        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
 89        let epoch = app_state
 90            .db
 91            .create_server(&app_state.config.zed_environment)
 92            .await
 93            .unwrap();
 94        let server = Server::new(
 95            epoch,
 96            app_state.clone(),
 97            Executor::Deterministic(deterministic.clone()),
 98        );
 99        server.start().await.unwrap();
100        // Advance clock to ensure the server's cleanup task is finished.
101        deterministic.advance_clock(CLEANUP_TIMEOUT);
102        Self {
103            app_state,
104            server,
105            connection_killers: Default::default(),
106            forbid_connections: Default::default(),
107            _test_db: test_db,
108            test_live_kit_server: live_kit_server,
109        }
110    }
111
112    pub async fn reset(&self) {
113        self.app_state.db.reset();
114        let epoch = self
115            .app_state
116            .db
117            .create_server(&self.app_state.config.zed_environment)
118            .await
119            .unwrap();
120        self.server.reset(epoch);
121    }
122
123    pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
124        cx.update(|cx| {
125            println!("PLS SHOW UP");
126            if cx.has_global::<SettingsStore>() {
127                panic!("Same cx used to create two test clients")
128            }
129            let settings = SettingsStore::test(cx);
130            cx.set_global(settings);
131        });
132
133        let http = FakeHttpClient::with_404_response();
134        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
135        {
136            user.id
137        } else {
138            self.app_state
139                .db
140                .create_user(
141                    &format!("{name}@example.com"),
142                    false,
143                    NewUserParams {
144                        github_login: name.into(),
145                        github_user_id: 0,
146                    },
147                )
148                .await
149                .expect("creating user failed")
150                .user_id
151        };
152        let client_name = name.to_string();
153        let mut client = cx.read(|cx| Client::new(http.clone(), cx));
154        let server = self.server.clone();
155        let db = self.app_state.db.clone();
156        let connection_killers = self.connection_killers.clone();
157        let forbid_connections = self.forbid_connections.clone();
158
159        Arc::get_mut(&mut client)
160            .unwrap()
161            .set_id(user_id.to_proto())
162            .override_authenticate(move |cx| {
163                cx.spawn(|_| async move {
164                    let access_token = "the-token".to_string();
165                    Ok(Credentials {
166                        user_id: user_id.to_proto(),
167                        access_token,
168                    })
169                })
170            })
171            .override_establish_connection(move |credentials, cx| {
172                assert_eq!(credentials.user_id, user_id.0 as u64);
173                assert_eq!(credentials.access_token, "the-token");
174
175                let server = server.clone();
176                let db = db.clone();
177                let connection_killers = connection_killers.clone();
178                let forbid_connections = forbid_connections.clone();
179                let client_name = client_name.clone();
180                cx.spawn(move |cx| async move {
181                    if forbid_connections.load(SeqCst) {
182                        Err(EstablishConnectionError::other(anyhow!(
183                            "server is forbidding connections"
184                        )))
185                    } else {
186                        let (client_conn, server_conn, killed) =
187                            Connection::in_memory(cx.background_executor().clone());
188                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
189                        let user = db
190                            .get_user_by_id(user_id)
191                            .await
192                            .expect("retrieving user failed")
193                            .unwrap();
194                        cx.background_executor()
195                            .spawn(server.handle_connection(
196                                server_conn,
197                                client_name,
198                                user,
199                                Some(connection_id_tx),
200                                Executor::Deterministic(cx.background_executor().clone()),
201                            ))
202                            .detach();
203                        let connection_id = connection_id_rx.await.unwrap();
204                        connection_killers
205                            .lock()
206                            .insert(connection_id.into(), killed);
207                        Ok(client_conn)
208                    }
209                })
210            });
211
212        let fs = FakeFs::new(cx.executor().clone());
213        let user_store = cx.build_model(|cx| UserStore::new(client.clone(), http, cx));
214        let workspace_store = cx.build_model(|cx| WorkspaceStore::new(client.clone(), cx));
215        let mut language_registry = LanguageRegistry::test();
216        language_registry.set_executor(cx.executor().clone());
217        let app_state = Arc::new(workspace::AppState {
218            client: client.clone(),
219            user_store: user_store.clone(),
220            workspace_store,
221            languages: Arc::new(language_registry),
222            fs: fs.clone(),
223            build_window_options: |_, _, _| Default::default(),
224            initialize_workspace: |_, _, _, _| gpui::Task::ready(Ok(())),
225            node_runtime: FakeNodeRuntime::new(),
226        });
227
228        cx.update(|cx| {
229            theme::init(cx);
230            Project::init(&client, cx);
231            client::init(&client, cx);
232            language::init(cx);
233            editor::init_settings(cx);
234            workspace::init(app_state.clone(), cx);
235            audio::init((), cx);
236            call::init(client.clone(), user_store.clone(), cx);
237            channel::init(&client, user_store.clone(), cx);
238            //todo(notifications)
239            // notifications::init(client.clone(), user_store, cx);
240        });
241
242        client
243            .authenticate_and_connect(false, &cx.to_async())
244            .await
245            .unwrap();
246
247        let client = TestClient {
248            app_state,
249            username: name.to_string(),
250            channel_store: cx.read(ChannelStore::global).clone(),
251            // todo!(notifications)
252            // notification_store: cx.read(NotificationStore::global).clone(),
253            state: Default::default(),
254        };
255        client.wait_for_current_user(cx).await;
256        client
257    }
258
259    pub fn disconnect_client(&self, peer_id: PeerId) {
260        self.connection_killers
261            .lock()
262            .remove(&peer_id)
263            .unwrap()
264            .store(true, SeqCst);
265    }
266
267    //todo!(workspace)
268    #[allow(dead_code)]
269    pub fn simulate_long_connection_interruption(
270        &self,
271        peer_id: PeerId,
272        deterministic: BackgroundExecutor,
273    ) {
274        self.forbid_connections();
275        self.disconnect_client(peer_id);
276        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
277        self.allow_connections();
278        deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
279        deterministic.run_until_parked();
280    }
281
282    pub fn forbid_connections(&self) {
283        self.forbid_connections.store(true, SeqCst);
284    }
285
286    pub fn allow_connections(&self) {
287        self.forbid_connections.store(false, SeqCst);
288    }
289
290    pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
291        for ix in 1..clients.len() {
292            let (left, right) = clients.split_at_mut(ix);
293            let (client_a, cx_a) = left.last_mut().unwrap();
294            for (client_b, cx_b) in right {
295                client_a
296                    .app_state
297                    .user_store
298                    .update(*cx_a, |store, cx| {
299                        store.request_contact(client_b.user_id().unwrap(), cx)
300                    })
301                    .await
302                    .unwrap();
303                cx_a.executor().run_until_parked();
304                client_b
305                    .app_state
306                    .user_store
307                    .update(*cx_b, |store, cx| {
308                        store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
309                    })
310                    .await
311                    .unwrap();
312            }
313        }
314    }
315
316    pub async fn make_channel(
317        &self,
318        channel: &str,
319        parent: Option<u64>,
320        admin: (&TestClient, &mut TestAppContext),
321        members: &mut [(&TestClient, &mut TestAppContext)],
322    ) -> u64 {
323        let (_, admin_cx) = admin;
324        let channel_id = admin_cx
325            .read(ChannelStore::global)
326            .update(admin_cx, |channel_store, cx| {
327                channel_store.create_channel(channel, parent, cx)
328            })
329            .await
330            .unwrap();
331
332        for (member_client, member_cx) in members {
333            admin_cx
334                .read(ChannelStore::global)
335                .update(admin_cx, |channel_store, cx| {
336                    channel_store.invite_member(
337                        channel_id,
338                        member_client.user_id().unwrap(),
339                        ChannelRole::Member,
340                        cx,
341                    )
342                })
343                .await
344                .unwrap();
345
346            admin_cx.executor().run_until_parked();
347
348            member_cx
349                .read(ChannelStore::global)
350                .update(*member_cx, |channels, cx| {
351                    channels.respond_to_channel_invite(channel_id, true, cx)
352                })
353                .await
354                .unwrap();
355        }
356
357        channel_id
358    }
359
360    pub async fn make_channel_tree(
361        &self,
362        channels: &[(&str, Option<&str>)],
363        creator: (&TestClient, &mut TestAppContext),
364    ) -> Vec<u64> {
365        let mut observed_channels = HashMap::default();
366        let mut result = Vec::new();
367        for (channel, parent) in channels {
368            let id;
369            if let Some(parent) = parent {
370                if let Some(parent_id) = observed_channels.get(parent) {
371                    id = self
372                        .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut [])
373                        .await;
374                } else {
375                    panic!(
376                        "Edge {}->{} referenced before {} was created",
377                        parent, channel, parent
378                    )
379                }
380            } else {
381                id = self
382                    .make_channel(channel, None, (creator.0, creator.1), &mut [])
383                    .await;
384            }
385
386            observed_channels.insert(channel, id);
387            result.push(id);
388        }
389
390        result
391    }
392
393    pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
394        self.make_contacts(clients).await;
395
396        let (left, right) = clients.split_at_mut(1);
397        let (_client_a, cx_a) = &mut left[0];
398        let active_call_a = cx_a.read(ActiveCall::global);
399
400        for (client_b, cx_b) in right {
401            let user_id_b = client_b.current_user_id(*cx_b).to_proto();
402            active_call_a
403                .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
404                .await
405                .unwrap();
406
407            cx_b.executor().run_until_parked();
408            let active_call_b = cx_b.read(ActiveCall::global);
409            active_call_b
410                .update(*cx_b, |call, cx| call.accept_incoming(cx))
411                .await
412                .unwrap();
413        }
414    }
415
416    pub async fn build_app_state(
417        test_db: &TestDb,
418        fake_server: &live_kit_client::TestServer,
419    ) -> Arc<AppState> {
420        Arc::new(AppState {
421            db: test_db.db().clone(),
422            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
423            config: Default::default(),
424        })
425    }
426}
427
428impl Deref for TestServer {
429    type Target = Server;
430
431    fn deref(&self) -> &Self::Target {
432        &self.server
433    }
434}
435
impl Drop for TestServer {
    /// Tears down the collab server first, then the fake LiveKit server.
    ///
    /// NOTE(review): `unwrap()` panics if the LiveKit teardown returns an error.
    /// A panic inside `drop` while the thread is already unwinding (for example,
    /// after a failed assertion) aborts the process. Consider whether that is acceptable.
    fn drop(&mut self) {
        self.server.teardown();
        self.test_live_kit_server.teardown().unwrap();
    }
}
442
443impl Deref for TestClient {
444    type Target = Arc<Client>;
445
446    fn deref(&self) -> &Self::Target {
447        &self.app_state.client
448    }
449}
450
451impl TestClient {
452    pub fn fs(&self) -> &FakeFs {
453        self.app_state.fs.as_fake()
454    }
455
456    pub fn channel_store(&self) -> &Model<ChannelStore> {
457        &self.channel_store
458    }
459
460    // todo!(notifications)
461    // pub fn notification_store(&self) -> &Model<NotificationStore> {
462    //     &self.notification_store
463    // }
464
465    pub fn user_store(&self) -> &Model<UserStore> {
466        &self.app_state.user_store
467    }
468
469    pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
470        &self.app_state.languages
471    }
472
473    pub fn client(&self) -> &Arc<Client> {
474        &self.app_state.client
475    }
476
477    pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
478        UserId::from_proto(
479            self.app_state
480                .user_store
481                .read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
482        )
483    }
484
485    pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
486        let mut authed_user = self
487            .app_state
488            .user_store
489            .read_with(cx, |user_store, _| user_store.watch_current_user());
490        while authed_user.next().await.unwrap().is_none() {}
491    }
492
493    pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
494        self.app_state
495            .user_store
496            .update(cx, |store, _| store.clear_contacts())
497            .await;
498    }
499
500    pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<Model<Project>>> + 'a {
501        Ref::map(self.state.borrow(), |state| &state.local_projects)
502    }
503
504    pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<Model<Project>>> + 'a {
505        Ref::map(self.state.borrow(), |state| &state.remote_projects)
506    }
507
508    pub fn local_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<Model<Project>>> + 'a {
509        RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
510    }
511
512    pub fn remote_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<Model<Project>>> + 'a {
513        RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
514    }
515
516    pub fn buffers_for_project<'a>(
517        &'a self,
518        project: &Model<Project>,
519    ) -> impl DerefMut<Target = HashSet<Model<language::Buffer>>> + 'a {
520        RefMut::map(self.state.borrow_mut(), |state| {
521            state.buffers.entry(project.clone()).or_default()
522        })
523    }
524
525    pub fn buffers<'a>(
526        &'a self,
527    ) -> impl DerefMut<Target = HashMap<Model<Project>, HashSet<Model<language::Buffer>>>> + 'a
528    {
529        RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
530    }
531
532    pub fn channel_buffers<'a>(
533        &'a self,
534    ) -> impl DerefMut<Target = HashSet<Model<ChannelBuffer>>> + 'a {
535        RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
536    }
537
538    pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
539        self.app_state
540            .user_store
541            .read_with(cx, |store, _| ContactsSummary {
542                current: store
543                    .contacts()
544                    .iter()
545                    .map(|contact| contact.user.github_login.clone())
546                    .collect(),
547                outgoing_requests: store
548                    .outgoing_contact_requests()
549                    .iter()
550                    .map(|user| user.github_login.clone())
551                    .collect(),
552                incoming_requests: store
553                    .incoming_contact_requests()
554                    .iter()
555                    .map(|user| user.github_login.clone())
556                    .collect(),
557            })
558    }
559
560    pub async fn build_local_project(
561        &self,
562        root_path: impl AsRef<Path>,
563        cx: &mut TestAppContext,
564    ) -> (Model<Project>, WorktreeId) {
565        let project = self.build_empty_local_project(cx);
566        let (worktree, _) = project
567            .update(cx, |p, cx| {
568                p.find_or_create_local_worktree(root_path, true, cx)
569            })
570            .await
571            .unwrap();
572        worktree
573            .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
574            .await;
575        (project, worktree.read_with(cx, |tree, _| tree.id()))
576    }
577
578    pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> Model<Project> {
579        cx.update(|cx| {
580            Project::local(
581                self.client().clone(),
582                self.app_state.node_runtime.clone(),
583                self.app_state.user_store.clone(),
584                self.app_state.languages.clone(),
585                self.app_state.fs.clone(),
586                cx,
587            )
588        })
589    }
590
591    pub async fn build_remote_project(
592        &self,
593        host_project_id: u64,
594        guest_cx: &mut TestAppContext,
595    ) -> Model<Project> {
596        let active_call = guest_cx.read(ActiveCall::global);
597        let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
598        room.update(guest_cx, |room, cx| {
599            room.join_project(
600                host_project_id,
601                self.app_state.languages.clone(),
602                self.app_state.fs.clone(),
603                cx,
604            )
605        })
606        .await
607        .unwrap()
608    }
609
610    //todo(workspace)
611    #[allow(dead_code)]
612    pub fn build_workspace(
613        &self,
614        project: &Model<Project>,
615        cx: &mut TestAppContext,
616    ) -> WindowHandle<Workspace> {
617        cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
618    }
619}
620
621impl Drop for TestClient {
622    fn drop(&mut self) {
623        self.app_state.client.teardown();
624    }
625}