1use super::*;
2use crate::db::tables::rate_buckets;
3use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
4
5impl Database {
6 /// Saves the rate limit for the given user and rate limit name if the last_refill is later
7 /// than the currently saved timestamp.
8 pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
9 if buckets.is_empty() {
10 return Ok(());
11 }
12
13 self.transaction(|tx| async move {
14 rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
15 rate_buckets::ActiveModel {
16 user_id: ActiveValue::Set(bucket.user_id),
17 rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
18 token_count: ActiveValue::Set(bucket.token_count),
19 last_refill: ActiveValue::Set(bucket.last_refill),
20 }
21 }))
22 .on_conflict(
23 OnConflict::columns([
24 rate_buckets::Column::UserId,
25 rate_buckets::Column::RateLimitName,
26 ])
27 .update_columns([
28 rate_buckets::Column::TokenCount,
29 rate_buckets::Column::LastRefill,
30 ])
31 .to_owned(),
32 )
33 .exec(&*tx)
34 .await?;
35
36 Ok(())
37 })
38 .await
39 }
40
41 /// Retrieves the rate limit for the given user and rate limit name.
42 pub async fn get_rate_bucket(
43 &self,
44 user_id: UserId,
45 rate_limit_name: &str,
46 ) -> Result<Option<rate_buckets::Model>> {
47 self.transaction(|tx| async move {
48 let rate_limit = rate_buckets::Entity::find()
49 .filter(rate_buckets::Column::UserId.eq(user_id))
50 .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
51 .one(&*tx)
52 .await?;
53
54 Ok(rate_limit)
55 })
56 .await
57 }
58}