users.rs

  1use chrono::NaiveDateTime;
  2
  3use super::*;
  4
  5impl Database {
  6    /// Creates a new user.
  7    pub async fn create_user(
  8        &self,
  9        email_address: &str,
 10        admin: bool,
 11        params: NewUserParams,
 12    ) -> Result<NewUserResult> {
 13        self.transaction(|tx| async {
 14            let tx = tx;
 15            let user = user::Entity::insert(user::ActiveModel {
 16                email_address: ActiveValue::set(Some(email_address.into())),
 17                github_login: ActiveValue::set(params.github_login.clone()),
 18                github_user_id: ActiveValue::set(Some(params.github_user_id)),
 19                admin: ActiveValue::set(admin),
 20                metrics_id: ActiveValue::set(Uuid::new_v4()),
 21                ..Default::default()
 22            })
 23            .on_conflict(
 24                OnConflict::column(user::Column::GithubLogin)
 25                    .update_columns([
 26                        user::Column::Admin,
 27                        user::Column::EmailAddress,
 28                        user::Column::GithubUserId,
 29                    ])
 30                    .to_owned(),
 31            )
 32            .exec_with_returning(&*tx)
 33            .await?;
 34
 35            Ok(NewUserResult {
 36                user_id: user.id,
 37                metrics_id: user.metrics_id.to_string(),
 38                signup_device_id: None,
 39                inviting_user_id: None,
 40            })
 41        })
 42        .await
 43    }
 44
 45    /// Returns a user by ID. There are no access checks here, so this should only be used internally.
 46    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<user::Model>> {
 47        self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) })
 48            .await
 49    }
 50
 51    /// Returns all users by ID. There are no access checks here, so this should only be used internally.
 52    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
 53        if ids.len() >= 10000_usize {
 54            return Err(anyhow!("too many users"))?;
 55        }
 56        self.transaction(|tx| async {
 57            let tx = tx;
 58            Ok(user::Entity::find()
 59                .filter(user::Column::Id.is_in(ids.iter().copied()))
 60                .all(&*tx)
 61                .await?)
 62        })
 63        .await
 64    }
 65
 66    /// Returns a user by email address. There are no access checks here, so this should only be used internally.
 67    pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
 68        self.transaction(|tx| async move {
 69            Ok(user::Entity::find()
 70                .filter(user::Column::EmailAddress.eq(email))
 71                .one(&*tx)
 72                .await?)
 73        })
 74        .await
 75    }
 76
 77    /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally.
 78    pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result<Option<User>> {
 79        self.transaction(|tx| async move {
 80            Ok(user::Entity::find()
 81                .filter(user::Column::GithubUserId.eq(github_user_id))
 82                .one(&*tx)
 83                .await?)
 84        })
 85        .await
 86    }
 87
 88    /// Returns a user by GitHub login. There are no access checks here, so this should only be used internally.
 89    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 90        self.transaction(|tx| async move {
 91            Ok(user::Entity::find()
 92                .filter(user::Column::GithubLogin.eq(github_login))
 93                .one(&*tx)
 94                .await?)
 95        })
 96        .await
 97    }
 98
 99    pub async fn get_or_create_user_by_github_account(
100        &self,
101        github_login: &str,
102        github_user_id: Option<i32>,
103        github_email: Option<&str>,
104        github_user_created_at: Option<DateTimeUtc>,
105        initial_channel_id: Option<ChannelId>,
106    ) -> Result<User> {
107        self.transaction(|tx| async move {
108            self.get_or_create_user_by_github_account_tx(
109                github_login,
110                github_user_id,
111                github_email,
112                github_user_created_at.map(|created_at| created_at.naive_utc()),
113                initial_channel_id,
114                &tx,
115            )
116            .await
117        })
118        .await
119    }
120
121    pub async fn get_or_create_user_by_github_account_tx(
122        &self,
123        github_login: &str,
124        github_user_id: Option<i32>,
125        github_email: Option<&str>,
126        github_user_created_at: Option<NaiveDateTime>,
127        initial_channel_id: Option<ChannelId>,
128        tx: &DatabaseTransaction,
129    ) -> Result<User> {
130        if let Some(github_user_id) = github_user_id {
131            if let Some(user_by_github_user_id) = user::Entity::find()
132                .filter(user::Column::GithubUserId.eq(github_user_id))
133                .one(tx)
134                .await?
135            {
136                let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
137                user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
138                if github_user_created_at.is_some() {
139                    user_by_github_user_id.github_user_created_at =
140                        ActiveValue::set(github_user_created_at);
141                }
142                Ok(user_by_github_user_id.update(tx).await?)
143            } else if let Some(user_by_github_login) = user::Entity::find()
144                .filter(user::Column::GithubLogin.eq(github_login))
145                .one(tx)
146                .await?
147            {
148                let mut user_by_github_login = user_by_github_login.into_active_model();
149                user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
150                if github_user_created_at.is_some() {
151                    user_by_github_login.github_user_created_at =
152                        ActiveValue::set(github_user_created_at);
153                }
154                Ok(user_by_github_login.update(tx).await?)
155            } else {
156                let user = user::Entity::insert(user::ActiveModel {
157                    email_address: ActiveValue::set(github_email.map(|email| email.into())),
158                    github_login: ActiveValue::set(github_login.into()),
159                    github_user_id: ActiveValue::set(Some(github_user_id)),
160                    github_user_created_at: ActiveValue::set(github_user_created_at),
161                    admin: ActiveValue::set(false),
162                    invite_count: ActiveValue::set(0),
163                    invite_code: ActiveValue::set(None),
164                    metrics_id: ActiveValue::set(Uuid::new_v4()),
165                    ..Default::default()
166                })
167                .exec_with_returning(tx)
168                .await?;
169                if let Some(channel_id) = initial_channel_id {
170                    channel_member::Entity::insert(channel_member::ActiveModel {
171                        id: ActiveValue::NotSet,
172                        channel_id: ActiveValue::Set(channel_id),
173                        user_id: ActiveValue::Set(user.id),
174                        accepted: ActiveValue::Set(true),
175                        role: ActiveValue::Set(ChannelRole::Guest),
176                    })
177                    .exec(tx)
178                    .await?;
179                }
180                Ok(user)
181            }
182        } else {
183            let user = user::Entity::find()
184                .filter(user::Column::GithubLogin.eq(github_login))
185                .one(tx)
186                .await?
187                .ok_or_else(|| anyhow!("no such user {}", github_login))?;
188            Ok(user)
189        }
190    }
191
192    /// get_all_users returns the next page of users. To get more call again with
193    /// the same limit and the page incremented by 1.
194    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
195        self.transaction(|tx| async move {
196            Ok(user::Entity::find()
197                .order_by_asc(user::Column::GithubLogin)
198                .limit(limit as u64)
199                .offset(page as u64 * limit as u64)
200                .all(&*tx)
201                .await?)
202        })
203        .await
204    }
205
206    /// Returns the metrics id for the user.
207    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
208        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
209        enum QueryAs {
210            MetricsId,
211        }
212
213        self.transaction(|tx| async move {
214            let metrics_id: Uuid = user::Entity::find_by_id(id)
215                .select_only()
216                .column(user::Column::MetricsId)
217                .into_values::<_, QueryAs>()
218                .one(&*tx)
219                .await?
220                .ok_or_else(|| anyhow!("could not find user"))?;
221            Ok(metrics_id.to_string())
222        })
223        .await
224    }
225
226    /// Sets "connected_once" on the user for analytics.
227    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
228        self.transaction(|tx| async move {
229            user::Entity::update_many()
230                .filter(user::Column::Id.eq(id))
231                .set(user::ActiveModel {
232                    connected_once: ActiveValue::set(connected_once),
233                    ..Default::default()
234                })
235                .exec(&*tx)
236                .await?;
237            Ok(())
238        })
239        .await
240    }
241
242    /// Sets "accepted_tos_at" on the user to the given timestamp.
243    pub async fn set_user_accepted_tos_at(
244        &self,
245        id: UserId,
246        accepted_tos_at: Option<DateTime>,
247    ) -> Result<()> {
248        self.transaction(|tx| async move {
249            user::Entity::update_many()
250                .filter(user::Column::Id.eq(id))
251                .set(user::ActiveModel {
252                    accepted_tos_at: ActiveValue::set(accepted_tos_at),
253                    ..Default::default()
254                })
255                .exec(&*tx)
256                .await?;
257            Ok(())
258        })
259        .await
260    }
261
262    /// hard delete the user.
263    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
264        self.transaction(|tx| async move {
265            access_token::Entity::delete_many()
266                .filter(access_token::Column::UserId.eq(id))
267                .exec(&*tx)
268                .await?;
269            user::Entity::delete_by_id(id).exec(&*tx).await?;
270            Ok(())
271        })
272        .await
273    }
274
275    /// Find users where github_login ILIKE name_query.
276    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
277        self.transaction(|tx| async {
278            let tx = tx;
279            let like_string = Self::fuzzy_like_string(name_query);
280            let query = "
281                SELECT users.*
282                FROM users
283                WHERE github_login ILIKE $1
284                ORDER BY github_login <-> $2
285                LIMIT $3
286            ";
287
288            Ok(user::Entity::find()
289                .from_raw_sql(Statement::from_sql_and_values(
290                    self.pool.get_database_backend(),
291                    query,
292                    vec![like_string.into(), name_query.into(), limit.into()],
293                ))
294                .all(&*tx)
295                .await?)
296        })
297        .await
298    }
299
300    /// fuzzy_like_string creates a string for matching in-order using fuzzy_search_users.
301    /// e.g. "cir" would become "%c%i%r%"
302    pub fn fuzzy_like_string(string: &str) -> String {
303        let mut result = String::with_capacity(string.len() * 2 + 1);
304        for c in string.chars() {
305            if c.is_alphanumeric() {
306                result.push('%');
307                result.push(c);
308            }
309        }
310        result.push('%');
311        result
312    }
313
314    /// Creates a new feature flag.
315    pub async fn create_user_flag(&self, flag: &str, enabled_for_all: bool) -> Result<FlagId> {
316        self.transaction(|tx| async move {
317            let flag = feature_flag::Entity::insert(feature_flag::ActiveModel {
318                flag: ActiveValue::set(flag.to_string()),
319                enabled_for_all: ActiveValue::set(enabled_for_all),
320                ..Default::default()
321            })
322            .exec(&*tx)
323            .await?
324            .last_insert_id;
325
326            Ok(flag)
327        })
328        .await
329    }
330
331    /// Add the given user to the feature flag
332    pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> {
333        self.transaction(|tx| async move {
334            user_feature::Entity::insert(user_feature::ActiveModel {
335                user_id: ActiveValue::set(user),
336                feature_id: ActiveValue::set(flag),
337            })
338            .exec(&*tx)
339            .await?;
340
341            Ok(())
342        })
343        .await
344    }
345
346    /// Returns the active flags for the user.
347    pub async fn get_user_flags(&self, user: UserId) -> Result<Vec<String>> {
348        self.transaction(|tx| async move {
349            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
350            enum QueryAs {
351                Flag,
352            }
353
354            let flags_enabled_for_all = feature_flag::Entity::find()
355                .filter(feature_flag::Column::EnabledForAll.eq(true))
356                .select_only()
357                .column(feature_flag::Column::Flag)
358                .into_values::<_, QueryAs>()
359                .all(&*tx)
360                .await?;
361
362            let flags_enabled_for_user = user::Model {
363                id: user,
364                ..Default::default()
365            }
366            .find_linked(user::UserFlags)
367            .select_only()
368            .column(feature_flag::Column::Flag)
369            .into_values::<_, QueryAs>()
370            .all(&*tx)
371            .await?;
372
373            let mut all_flags = HashSet::from_iter(flags_enabled_for_all);
374            all_flags.extend(flags_enabled_for_user);
375
376            Ok(all_flags.into_iter().collect())
377        })
378        .await
379    }
380}