rate_limiter.rs

  1use crate::{db::UserId, executor::Executor, Database, Error, Result};
  2use anyhow::anyhow;
  3use chrono::{DateTime, Duration, Utc};
  4use dashmap::{DashMap, DashSet};
  5use sea_orm::prelude::DateTimeUtc;
  6use std::sync::Arc;
  7use util::ResultExt;
  8
  9pub trait RateLimit: 'static {
 10    fn capacity() -> usize;
 11    fn refill_duration() -> Duration;
 12    fn db_name() -> &'static str;
 13}
 14
 15/// Used to enforce per-user rate limits
 16pub struct RateLimiter {
 17    buckets: DashMap<(UserId, String), RateBucket>,
 18    dirty_buckets: DashSet<(UserId, String)>,
 19    db: Arc<Database>,
 20}
 21
 22impl RateLimiter {
 23    pub fn new(db: Arc<Database>) -> Self {
 24        RateLimiter {
 25            buckets: DashMap::new(),
 26            dirty_buckets: DashSet::new(),
 27            db,
 28        }
 29    }
 30
 31    /// Spawns a new task that periodically saves rate limit data to the database.
 32    pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
 33        const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
 34
 35        executor.clone().spawn_detached(async move {
 36            loop {
 37                executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
 38                rate_limiter.save().await.log_err();
 39            }
 40        });
 41    }
 42
 43    /// Returns an error if the user has exceeded the specified `RateLimit`.
 44    /// Attempts to read the from the database if no cached RateBucket currently exists.
 45    pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
 46        self.check_internal::<T>(user_id, Utc::now()).await
 47    }
 48
 49    async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
 50        let bucket_key = (user_id, T::db_name().to_string());
 51
 52        // Attempt to fetch the bucket from the database if it hasn't been cached.
 53        // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
 54        // but this enforces limits across restarts so long as the database is reachable.
 55        if !self.buckets.contains_key(&bucket_key) {
 56            if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
 57                self.buckets.insert(bucket_key.clone(), bucket);
 58                self.dirty_buckets.insert(bucket_key.clone());
 59            }
 60        }
 61
 62        let mut bucket = self
 63            .buckets
 64            .entry(bucket_key.clone())
 65            .or_insert_with(|| RateBucket::new::<T>(now));
 66
 67        if bucket.value_mut().allow(now) {
 68            self.dirty_buckets.insert(bucket_key);
 69            Ok(())
 70        } else {
 71            Err(anyhow!("rate limit exceeded"))?
 72        }
 73    }
 74
 75    async fn load_bucket<T: RateLimit>(
 76        &self,
 77        user_id: UserId,
 78    ) -> Result<Option<RateBucket>, Error> {
 79        Ok(self
 80            .db
 81            .get_rate_bucket(user_id, T::db_name())
 82            .await?
 83            .map(|saved_bucket| {
 84                RateBucket::from_db::<T>(
 85                    saved_bucket.token_count as usize,
 86                    DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
 87                )
 88            }))
 89    }
 90
 91    pub async fn save(&self) -> Result<()> {
 92        let mut buckets = Vec::new();
 93        self.dirty_buckets.retain(|key| {
 94            if let Some(bucket) = self.buckets.get(&key) {
 95                buckets.push(crate::db::rate_buckets::Model {
 96                    user_id: key.0,
 97                    rate_limit_name: key.1.clone(),
 98                    token_count: bucket.token_count as i32,
 99                    last_refill: bucket.last_refill.naive_utc(),
100                });
101            }
102            false
103        });
104
105        match self.db.save_rate_buckets(&buckets).await {
106            Ok(()) => Ok(()),
107            Err(err) => {
108                for bucket in buckets {
109                    self.dirty_buckets
110                        .insert((bucket.user_id, bucket.rate_limit_name));
111                }
112                Err(err)
113            }
114        }
115    }
116}
117
118#[derive(Clone)]
119struct RateBucket {
120    capacity: usize,
121    token_count: usize,
122    refill_time_per_token: Duration,
123    last_refill: DateTimeUtc,
124}
125
126impl RateBucket {
127    fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
128        Self {
129            capacity: T::capacity(),
130            token_count: T::capacity(),
131            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
132            last_refill: now,
133        }
134    }
135
136    fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
137        Self {
138            capacity: T::capacity(),
139            token_count,
140            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
141            last_refill,
142        }
143    }
144
145    fn allow(&mut self, now: DateTimeUtc) -> bool {
146        self.refill(now);
147        if self.token_count > 0 {
148            self.token_count -= 1;
149            true
150        } else {
151            false
152        }
153    }
154
155    fn refill(&mut self, now: DateTimeUtc) {
156        let elapsed = now - self.last_refill;
157        if elapsed >= self.refill_time_per_token {
158            let new_tokens =
159                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
160            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
161
162            let unused_refill_time = Duration::milliseconds(
163                elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
164            );
165            self.last_refill = now - unused_refill_time;
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Exercises the limiter end to end: limits are tracked independently per
    /// user and per limit, tokens are earned as time passes, and state is
    /// still enforced after being saved and reloaded by a new limiter.
    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();
        let user_1 = db
            .create_user(
                "user-1@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let user_2 = db
            .create_user(
                "user-2@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        // Time is advanced manually so the test does not depend on the wall clock.
        let mut now = Utc::now();

        let rate_limiter = RateLimiter::new(db.clone());

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // User 2 can access resource A and user 1 can access resource B:
        // exhausting one user's bucket for one limit affects neither other
        // users nor other limits.
        rate_limiter
            .check_internal::<RateLimitA>(user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitB>(user_1, now)
            .await
            .unwrap();

        // After 1.5s, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(1500);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // After 500ms, user 1 can make another request before being rate-limited again.
        // (The unused 500ms from the previous step carry over.)
        now += Duration::milliseconds(500);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A.
        let rate_limiter = RateLimiter::new(db.clone());
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // After 1s, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();
    }

    /// Two requests per two seconds, i.e. one token per second.
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity() -> usize {
            2
        }

        fn refill_duration() -> Duration {
            Duration::seconds(2)
        }

        fn db_name() -> &'static str {
            "rate-limit-a"
        }
    }

    /// Ten requests per three seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity() -> usize {
            10
        }

        fn refill_duration() -> Duration {
            Duration::seconds(3)
        }

        fn db_name() -> &'static str {
            "rate-limit-b"
        }
    }
}