collab_tests.rs

  1use call::Room;
  2use client::ChannelId;
  3use gpui::{Entity, TestAppContext};
  4
  5mod agent_sharing_tests;
  6mod channel_buffer_tests;
  7mod channel_guest_tests;
  8mod channel_tests;
  9mod db_tests;
 10mod editor_tests;
 11mod following_tests;
 12mod git_tests;
 13mod integration_tests;
 14mod notification_tests;
 15mod random_channel_buffer_tests;
 16mod random_project_collaboration_tests;
 17mod randomized_test_helpers;
 18mod remote_editing_collaboration_tests;
 19mod test_server;
 20
 21pub use randomized_test_helpers::{
 22    RandomizedTest, TestError, UserTestPlan, run_randomized_test, save_randomized_test_plan,
 23};
 24pub use test_server::{TestClient, TestServer};
 25
 26#[derive(Debug, Eq, PartialEq)]
 27struct RoomParticipants {
 28    remote: Vec<String>,
 29    pending: Vec<String>,
 30}
 31
 32fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomParticipants {
 33    room.read_with(cx, |room, _| {
 34        let mut remote = room
 35            .remote_participants()
 36            .values()
 37            .map(|participant| participant.user.github_login.clone().to_string())
 38            .collect::<Vec<_>>();
 39        let mut pending = room
 40            .pending_participants()
 41            .iter()
 42            .map(|user| user.github_login.clone().to_string())
 43            .collect::<Vec<_>>();
 44        remote.sort();
 45        pending.sort();
 46        RoomParticipants { remote, pending }
 47    })
 48}
 49
 50fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
 51    cx.read(|cx| room.read(cx).channel_id())
 52}
 53
 54mod auth_token_tests {
 55    use collab::auth::{
 56        AccessTokenJson, MAX_ACCESS_TOKENS_TO_STORE, VerifyAccessTokenResult, create_access_token,
 57        verify_access_token,
 58    };
 59    use rand::prelude::*;
 60    use scrypt::Scrypt;
 61    use scrypt::password_hash::{PasswordHasher, SaltString};
 62    use sea_orm::EntityTrait;
 63
 64    use collab::db::{Database, NewUserParams, UserId, access_token};
 65    use collab::*;
 66
 67    #[gpui::test]
 68    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
 69        let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
 70        let db = test_db.db();
 71
 72        let user = db
 73            .create_user(
 74                "example@example.com",
 75                None,
 76                false,
 77                NewUserParams {
 78                    github_login: "example".into(),
 79                    github_user_id: 1,
 80                },
 81            )
 82            .await
 83            .unwrap();
 84
 85        let token = create_access_token(db, user.user_id, None).await.unwrap();
 86        assert!(matches!(
 87            verify_access_token(&token, user.user_id, db).await.unwrap(),
 88            VerifyAccessTokenResult {
 89                is_valid: true,
 90                impersonator_id: None,
 91            }
 92        ));
 93
 94        let old_token = create_previous_access_token(user.user_id, None, db)
 95            .await
 96            .unwrap();
 97
 98        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
 99            .unwrap()
100            .id;
101
102        let hash = db
103            .transaction(|tx| async move {
104                Ok(access_token::Entity::find_by_id(old_token_id)
105                    .one(&*tx)
106                    .await?)
107            })
108            .await
109            .unwrap()
110            .unwrap()
111            .hash;
112        assert!(hash.starts_with("$scrypt$"));
113
114        assert!(matches!(
115            verify_access_token(&old_token, user.user_id, db)
116                .await
117                .unwrap(),
118            VerifyAccessTokenResult {
119                is_valid: true,
120                impersonator_id: None,
121            }
122        ));
123
124        let hash = db
125            .transaction(|tx| async move {
126                Ok(access_token::Entity::find_by_id(old_token_id)
127                    .one(&*tx)
128                    .await?)
129            })
130            .await
131            .unwrap()
132            .unwrap()
133            .hash;
134        assert!(hash.starts_with("$sha256$"));
135
136        assert!(matches!(
137            verify_access_token(&old_token, user.user_id, db)
138                .await
139                .unwrap(),
140            VerifyAccessTokenResult {
141                is_valid: true,
142                impersonator_id: None,
143            }
144        ));
145
146        assert!(matches!(
147            verify_access_token(&token, user.user_id, db).await.unwrap(),
148            VerifyAccessTokenResult {
149                is_valid: true,
150                impersonator_id: None,
151            }
152        ));
153    }
154
155    async fn create_previous_access_token(
156        user_id: UserId,
157        impersonated_user_id: Option<UserId>,
158        db: &Database,
159    ) -> Result<String> {
160        let access_token = collab::auth::random_token();
161        let access_token_hash = previous_hash_access_token(&access_token)?;
162        let id = db
163            .create_access_token(
164                user_id,
165                impersonated_user_id,
166                &access_token_hash,
167                MAX_ACCESS_TOKENS_TO_STORE,
168            )
169            .await?;
170        Ok(serde_json::to_string(&AccessTokenJson {
171            version: 1,
172            id,
173            token: access_token,
174        })?)
175    }
176
177    #[expect(clippy::result_large_err)]
178    fn previous_hash_access_token(token: &str) -> Result<String> {
179        // Avoid slow hashing in debug mode.
180        let params = if cfg!(debug_assertions) {
181            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
182        } else {
183            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
184        };
185
186        Ok(Scrypt
187            .hash_password_customized(
188                token.as_bytes(),
189                None,
190                None,
191                params,
192                &SaltString::generate(PasswordHashRngCompat::new()),
193            )
194            .map_err(anyhow::Error::new)?
195            .to_string())
196    }
197
198    // TODO: remove once we password_hash v0.6 is released.
199    struct PasswordHashRngCompat(rand::rngs::ThreadRng);
200
201    impl PasswordHashRngCompat {
202        fn new() -> Self {
203            Self(rand::rng())
204        }
205    }
206
207    impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
208        fn next_u32(&mut self) -> u32 {
209            self.0.next_u32()
210        }
211
212        fn next_u64(&mut self) -> u64 {
213            self.0.next_u64()
214        }
215
216        fn fill_bytes(&mut self, dest: &mut [u8]) {
217            self.0.fill_bytes(dest);
218        }
219
220        fn try_fill_bytes(
221            &mut self,
222            dest: &mut [u8],
223        ) -> Result<(), scrypt::password_hash::rand_core::Error> {
224            self.fill_bytes(dest);
225            Ok(())
226        }
227    }
228
229    impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
230}