rate_limiter.rs

  1use crate::{Database, Error, Result, db::UserId, executor::Executor};
  2use chrono::{DateTime, Duration, Utc};
  3use dashmap::{DashMap, DashSet};
  4use rpc::ErrorCodeExt;
  5use sea_orm::prelude::DateTimeUtc;
  6use std::sync::Arc;
  7use util::ResultExt;
  8
  9pub trait RateLimit: Send + Sync {
 10    fn capacity(&self) -> usize;
 11    fn refill_duration(&self) -> Duration;
 12    fn db_name(&self) -> &'static str;
 13}
 14
 15/// Used to enforce per-user rate limits
 16pub struct RateLimiter {
 17    buckets: DashMap<(UserId, String), RateBucket>,
 18    dirty_buckets: DashSet<(UserId, String)>,
 19    db: Arc<Database>,
 20}
 21
 22impl RateLimiter {
 23    pub fn new(db: Arc<Database>) -> Self {
 24        RateLimiter {
 25            buckets: DashMap::new(),
 26            dirty_buckets: DashSet::new(),
 27            db,
 28        }
 29    }
 30
 31    /// Spawns a new task that periodically saves rate limit data to the database.
 32    pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
 33        const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
 34
 35        executor.clone().spawn_detached(async move {
 36            loop {
 37                executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
 38                rate_limiter.save().await.log_err();
 39            }
 40        });
 41    }
 42
 43    /// Returns an error if the user has exceeded the specified `RateLimit`.
 44    /// Attempts to read the from the database if no cached RateBucket currently exists.
 45    pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
 46        self.check_internal(limit, user_id, Utc::now()).await
 47    }
 48
 49    async fn check_internal(
 50        &self,
 51        limit: &dyn RateLimit,
 52        user_id: UserId,
 53        now: DateTimeUtc,
 54    ) -> Result<()> {
 55        let bucket_key = (user_id, limit.db_name().to_string());
 56
 57        // Attempt to fetch the bucket from the database if it hasn't been cached.
 58        // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
 59        // but this enforces limits across restarts so long as the database is reachable.
 60        if !self.buckets.contains_key(&bucket_key) {
 61            if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
 62                self.buckets.insert(bucket_key.clone(), bucket);
 63                self.dirty_buckets.insert(bucket_key.clone());
 64            }
 65        }
 66
 67        let mut bucket = self
 68            .buckets
 69            .entry(bucket_key.clone())
 70            .or_insert_with(|| RateBucket::new(limit, now));
 71
 72        if bucket.value_mut().allow(now) {
 73            self.dirty_buckets.insert(bucket_key);
 74            Ok(())
 75        } else {
 76            Err(rpc::proto::ErrorCode::RateLimitExceeded
 77                .message("rate limit exceeded".into())
 78                .anyhow())?
 79        }
 80    }
 81
 82    async fn load_bucket(
 83        &self,
 84        limit: &dyn RateLimit,
 85        user_id: UserId,
 86    ) -> Result<Option<RateBucket>, Error> {
 87        Ok(self
 88            .db
 89            .get_rate_bucket(user_id, limit.db_name())
 90            .await?
 91            .map(|saved_bucket| {
 92                RateBucket::from_db(
 93                    limit,
 94                    saved_bucket.token_count as usize,
 95                    DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
 96                )
 97            }))
 98    }
 99
100    pub async fn save(&self) -> Result<()> {
101        let mut buckets = Vec::new();
102        self.dirty_buckets.retain(|key| {
103            if let Some(bucket) = self.buckets.get(key) {
104                buckets.push(crate::db::rate_buckets::Model {
105                    user_id: key.0,
106                    rate_limit_name: key.1.clone(),
107                    token_count: bucket.token_count as i32,
108                    last_refill: bucket.last_refill.naive_utc(),
109                });
110            }
111            false
112        });
113
114        match self.db.save_rate_buckets(&buckets).await {
115            Ok(()) => Ok(()),
116            Err(err) => {
117                for bucket in buckets {
118                    self.dirty_buckets
119                        .insert((bucket.user_id, bucket.rate_limit_name));
120                }
121                Err(err)
122            }
123        }
124    }
125}
126
127#[derive(Clone, Debug)]
128struct RateBucket {
129    capacity: usize,
130    token_count: usize,
131    refill_time_per_token: Duration,
132    last_refill: DateTimeUtc,
133}
134
135impl RateBucket {
136    fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
137        Self {
138            capacity: limit.capacity(),
139            token_count: limit.capacity(),
140            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
141            last_refill: now,
142        }
143    }
144
145    fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
146        Self {
147            capacity: limit.capacity(),
148            token_count,
149            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
150            last_refill,
151        }
152    }
153
154    fn allow(&mut self, now: DateTimeUtc) -> bool {
155        self.refill(now);
156        if self.token_count > 0 {
157            self.token_count -= 1;
158            true
159        } else {
160            false
161        }
162    }
163
164    fn refill(&mut self, now: DateTimeUtc) {
165        let elapsed = now - self.last_refill;
166        if elapsed >= self.refill_time_per_token {
167            let new_tokens =
168                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
169            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
170
171            let unused_refill_time = Duration::milliseconds(
172                elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
173            );
174            self.last_refill = now - unused_refill_time;
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();
        let user_1 = db
            .create_user(
                "user-1@zed.dev",
                None,
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let user_2 = db
            .create_user(
                "user-2@zed.dev",
                None,
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        let mut now = Utc::now();

        let rate_limiter = RateLimiter::new(db.clone());
        let rate_limit_a = Box::new(RateLimitA);
        let rate_limit_b = Box::new(RateLimitB);

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // User 1 exhausting resource A affects neither user 2's bucket for resource A
        // nor any bucket for resource B.
        rate_limiter
            .check_internal(&*rate_limit_a, user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_b, user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_b, user_1, now)
            .await
            .unwrap();

        // After 1.5s, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(1500);
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // After 500ms, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(500);
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A.
        let rate_limiter = RateLimiter::new(db.clone());
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // After 1s, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&*rate_limit_a, user_1, now)
            .await
            .unwrap_err();
    }

    /// Test limit: 2 tokens, one regained every second.
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity(&self) -> usize {
            2
        }

        fn refill_duration(&self) -> Duration {
            Duration::seconds(2)
        }

        fn db_name(&self) -> &'static str {
            "rate-limit-a"
        }
    }

    /// Test limit: 10 tokens, fully refilled over 3 seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity(&self) -> usize {
            10
        }

        fn refill_duration(&self) -> Duration {
            Duration::seconds(3)
        }

        fn db_name(&self) -> &'static str {
            "rate-limit-b"
        }
    }
}