collab_tests.rs

  1use call::Room;
  2use client::ChannelId;
  3use gpui::{Entity, TestAppContext};
  4
  5mod agent_sharing_tests;
  6mod channel_buffer_tests;
  7mod channel_guest_tests;
  8mod channel_tests;
  9mod db_tests;
 10mod editor_tests;
 11mod following_tests;
 12mod git_tests;
 13mod integration_tests;
 14mod notification_tests;
 15mod random_channel_buffer_tests;
 16mod random_project_collaboration_tests;
 17mod randomized_test_helpers;
 18mod remote_editing_collaboration_tests;
 19mod test_server;
 20
 21pub use randomized_test_helpers::{
 22    RandomizedTest, TestError, UserTestPlan, run_randomized_test, save_randomized_test_plan,
 23};
 24pub use test_server::{TestClient, TestServer};
 25
 26#[derive(Debug, Eq, PartialEq)]
 27struct RoomParticipants {
 28    remote: Vec<String>,
 29    pending: Vec<String>,
 30}
 31
 32fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomParticipants {
 33    room.read_with(cx, |room, _| {
 34        let mut remote = room
 35            .remote_participants()
 36            .values()
 37            .map(|participant| participant.user.github_login.clone().to_string())
 38            .collect::<Vec<_>>();
 39        let mut pending = room
 40            .pending_participants()
 41            .iter()
 42            .map(|user| user.github_login.clone().to_string())
 43            .collect::<Vec<_>>();
 44        remote.sort();
 45        pending.sort();
 46        RoomParticipants { remote, pending }
 47    })
 48}
 49
 50fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
 51    cx.read(|cx| room.read(cx).channel_id())
 52}
 53
 54mod auth_token_tests {
 55    use collab::auth::{
 56        AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token,
 57    };
 58    use rand::prelude::*;
 59    use scrypt::Scrypt;
 60    use scrypt::password_hash::{PasswordHasher, SaltString};
 61    use sea_orm::EntityTrait;
 62
 63    use collab::db::{Database, NewUserParams, UserId, access_token};
 64    use collab::*;
 65
 66    const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 67
 68    async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
 69        const VERSION: usize = 1;
 70        let access_token = ::rpc::auth::random_token();
 71        let access_token_hash = hash_access_token(&access_token);
 72        let id = db
 73            .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
 74            .await?;
 75        Ok(serde_json::to_string(&AccessTokenJson {
 76            version: VERSION,
 77            id,
 78            token: access_token,
 79        })?)
 80    }
 81
 82    #[gpui::test]
 83    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
 84        let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
 85        let db = test_db.db();
 86
 87        let user = db
 88            .create_user(
 89                "example@example.com",
 90                None,
 91                false,
 92                NewUserParams {
 93                    github_login: "example".into(),
 94                    github_user_id: 1,
 95                },
 96            )
 97            .await
 98            .unwrap();
 99
100        let token = create_access_token(db, user.user_id).await.unwrap();
101        assert!(matches!(
102            verify_access_token(&token, user.user_id, db).await.unwrap(),
103            VerifyAccessTokenResult { is_valid: true }
104        ));
105
106        let old_token = create_previous_access_token(user.user_id, db)
107            .await
108            .unwrap();
109
110        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
111            .unwrap()
112            .id;
113
114        let hash = db
115            .transaction(|tx| async move {
116                Ok(access_token::Entity::find_by_id(old_token_id)
117                    .one(&*tx)
118                    .await?)
119            })
120            .await
121            .unwrap()
122            .unwrap()
123            .hash;
124        assert!(hash.starts_with("$scrypt$"));
125
126        assert!(matches!(
127            verify_access_token(&old_token, user.user_id, db)
128                .await
129                .unwrap(),
130            VerifyAccessTokenResult { is_valid: true }
131        ));
132
133        let hash = db
134            .transaction(|tx| async move {
135                Ok(access_token::Entity::find_by_id(old_token_id)
136                    .one(&*tx)
137                    .await?)
138            })
139            .await
140            .unwrap()
141            .unwrap()
142            .hash;
143        assert!(hash.starts_with("$sha256$"));
144
145        assert!(matches!(
146            verify_access_token(&old_token, user.user_id, db)
147                .await
148                .unwrap(),
149            VerifyAccessTokenResult { is_valid: true }
150        ));
151
152        assert!(matches!(
153            verify_access_token(&token, user.user_id, db).await.unwrap(),
154            VerifyAccessTokenResult { is_valid: true }
155        ));
156    }
157
158    async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result<String> {
159        let access_token = collab::auth::random_token();
160        let access_token_hash = previous_hash_access_token(&access_token)?;
161        let id = db
162            .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
163            .await?;
164        Ok(serde_json::to_string(&AccessTokenJson {
165            version: 1,
166            id,
167            token: access_token,
168        })?)
169    }
170
171    #[expect(clippy::result_large_err)]
172    fn previous_hash_access_token(token: &str) -> Result<String> {
173        // Avoid slow hashing in debug mode.
174        let params = if cfg!(debug_assertions) {
175            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
176        } else {
177            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
178        };
179
180        Ok(Scrypt
181            .hash_password_customized(
182                token.as_bytes(),
183                None,
184                None,
185                params,
186                &SaltString::generate(PasswordHashRngCompat::new()),
187            )
188            .map_err(anyhow::Error::new)?
189            .to_string())
190    }
191
192    // TODO: remove once we password_hash v0.6 is released.
193    struct PasswordHashRngCompat(rand::rngs::ThreadRng);
194
195    impl PasswordHashRngCompat {
196        fn new() -> Self {
197            Self(rand::rng())
198        }
199    }
200
201    impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
202        fn next_u32(&mut self) -> u32 {
203            self.0.next_u32()
204        }
205
206        fn next_u64(&mut self) -> u64 {
207            self.0.next_u64()
208        }
209
210        fn fill_bytes(&mut self, dest: &mut [u8]) {
211            self.0.fill_bytes(dest);
212        }
213
214        fn try_fill_bytes(
215            &mut self,
216            dest: &mut [u8],
217        ) -> Result<(), scrypt::password_hash::rand_core::Error> {
218            self.fill_bytes(dest);
219            Ok(())
220        }
221    }
222
223    impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
224}