rate_limiter.rs

  1use crate::{db::UserId, executor::Executor, Database, Error, Result};
  2use anyhow::anyhow;
  3use chrono::{DateTime, Duration, Utc};
  4use dashmap::{DashMap, DashSet};
  5use sea_orm::prelude::DateTimeUtc;
  6use std::sync::Arc;
  7use util::ResultExt;
  8
  9pub trait RateLimit: 'static {
 10    fn capacity() -> usize;
 11    fn refill_duration() -> Duration;
 12    fn db_name() -> &'static str;
 13}
 14
 15/// Used to enforce per-user rate limits
 16pub struct RateLimiter {
 17    buckets: DashMap<(UserId, String), RateBucket>,
 18    dirty_buckets: DashSet<(UserId, String)>,
 19    db: Arc<Database>,
 20}
 21
 22impl RateLimiter {
 23    pub fn new(db: Arc<Database>) -> Self {
 24        RateLimiter {
 25            buckets: DashMap::new(),
 26            dirty_buckets: DashSet::new(),
 27            db,
 28        }
 29    }
 30
 31    /// Spawns a new task that periodically saves rate limit data to the database.
 32    pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
 33        const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
 34
 35        executor.clone().spawn_detached(async move {
 36            loop {
 37                executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
 38                rate_limiter.save().await.log_err();
 39            }
 40        });
 41    }
 42
 43    /// Returns an error if the user has exceeded the specified `RateLimit`.
 44    /// Attempts to read the from the database if no cached RateBucket currently exists.
 45    pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
 46        self.check_internal::<T>(user_id, Utc::now()).await
 47    }
 48
 49    async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
 50        let bucket_key = (user_id, T::db_name().to_string());
 51
 52        // Attempt to fetch the bucket from the database if it hasn't been cached.
 53        // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
 54        // but this enforces limits across restarts so long as the database is reachable.
 55        if !self.buckets.contains_key(&bucket_key) {
 56            if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
 57                self.buckets.insert(bucket_key.clone(), bucket);
 58                self.dirty_buckets.insert(bucket_key.clone());
 59            }
 60        }
 61
 62        let mut bucket = self
 63            .buckets
 64            .entry(bucket_key.clone())
 65            .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
 66
 67        if bucket.value_mut().allow(now) {
 68            self.dirty_buckets.insert(bucket_key);
 69            Ok(())
 70        } else {
 71            Err(anyhow!("rate limit exceeded"))?
 72        }
 73    }
 74
 75    async fn load_bucket<K: RateLimit>(
 76        &self,
 77        user_id: UserId,
 78    ) -> Result<Option<RateBucket>, Error> {
 79        Ok(self
 80            .db
 81            .get_rate_bucket(user_id, K::db_name())
 82            .await?
 83            .map(|saved_bucket| RateBucket {
 84                capacity: K::capacity(),
 85                refill_time_per_token: K::refill_duration(),
 86                token_count: saved_bucket.token_count as usize,
 87                last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
 88            }))
 89    }
 90
 91    pub async fn save(&self) -> Result<()> {
 92        let mut buckets = Vec::new();
 93        self.dirty_buckets.retain(|key| {
 94            if let Some(bucket) = self.buckets.get(&key) {
 95                buckets.push(crate::db::rate_buckets::Model {
 96                    user_id: key.0,
 97                    rate_limit_name: key.1.clone(),
 98                    token_count: bucket.token_count as i32,
 99                    last_refill: bucket.last_refill.naive_utc(),
100                });
101            }
102            false
103        });
104
105        match self.db.save_rate_buckets(&buckets).await {
106            Ok(()) => Ok(()),
107            Err(err) => {
108                for bucket in buckets {
109                    self.dirty_buckets
110                        .insert((bucket.user_id, bucket.rate_limit_name));
111                }
112                Err(err)
113            }
114        }
115    }
116}
117
118#[derive(Clone)]
119struct RateBucket {
120    capacity: usize,
121    token_count: usize,
122    refill_time_per_token: Duration,
123    last_refill: DateTimeUtc,
124}
125
126impl RateBucket {
127    fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
128        RateBucket {
129            capacity,
130            token_count: capacity,
131            refill_time_per_token: refill_duration / capacity as i32,
132            last_refill: now,
133        }
134    }
135
136    fn allow(&mut self, now: DateTimeUtc) -> bool {
137        self.refill(now);
138        if self.token_count > 0 {
139            self.token_count -= 1;
140            true
141        } else {
142            false
143        }
144    }
145
146    fn refill(&mut self, now: DateTimeUtc) {
147        let elapsed = now - self.last_refill;
148        if elapsed >= self.refill_time_per_token {
149            let new_tokens =
150                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
151
152            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
153            self.last_refill = now;
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Exercises per-user and per-limit isolation, refill over time, and persistence of
    /// bucket state through `save` and a fresh `RateLimiter`.
    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();
        let user_1 = db
            .create_user(
                "user-1@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let user_2 = db
            .create_user(
                "user-2@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        let mut now = Utc::now();

        let rate_limiter = RateLimiter::new(db.clone());

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // User 2 can access resource A and user 1 can access resource B: buckets are tracked
        // separately per user and per rate limit.
        rate_limiter
            .check_internal::<RateLimitA>(user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitB>(user_1, now)
            .await
            .unwrap();

        // After one second, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A.
        let rate_limiter = RateLimiter::new(db.clone());
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();
    }

    /// Two requests per two seconds (one token per second).
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity() -> usize {
            2
        }

        fn refill_duration() -> Duration {
            Duration::seconds(2)
        }

        fn db_name() -> &'static str {
            "rate-limit-a"
        }
    }

    /// Ten requests per three seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity() -> usize {
            10
        }

        fn refill_duration() -> Duration {
            Duration::seconds(3)
        }

        fn db_name() -> &'static str {
            "rate-limit-b"
        }
    }
}