Fix logic errors in `RateLimiter` (#12421)

Antonio Scandurra created

This pull request fixes two issues in `RateLimiter` that caused
excessive rate-limiting to take place:

- c19083a35c89a22395595f8934c117a14943ed24 fixes a mistake in how buckets
were loaded from the database: `refill_time_per_token` was set to the full
`refill_duration` instead of `refill_duration / capacity`, so a loaded
bucket refilled `capacity` times more slowly than a freshly created one.
This was the primary reason why rate limiting was acting oddly.
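- 34b88d14f6d9fde4d967554fc2e81498c1be3e26 fixes another slight logic
error that caused tokens to be underprovisioned: the refill logic reset
`last_refill` to `now`, discarding any elapsed time left over after
granting whole tokens. That leftover time is now carried over to the next
refill. This was minor compared to the bug above.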

Release Notes:

- N/A

Change summary

crates/collab/src/rate_limiter.rs | 68 ++++++++++++++++++++++++--------
1 file changed, 51 insertions(+), 17 deletions(-)

Detailed changes

crates/collab/src/rate_limiter.rs 🔗

@@ -62,7 +62,7 @@ impl RateLimiter {
         let mut bucket = self
             .buckets
             .entry(bucket_key.clone())
-            .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
+            .or_insert_with(|| RateBucket::new::<T>(now));
 
         if bucket.value_mut().allow(now) {
             self.dirty_buckets.insert(bucket_key);
@@ -72,19 +72,19 @@ impl RateLimiter {
         }
     }
 
-    async fn load_bucket<K: RateLimit>(
+    async fn load_bucket<T: RateLimit>(
         &self,
         user_id: UserId,
     ) -> Result<Option<RateBucket>, Error> {
         Ok(self
             .db
-            .get_rate_bucket(user_id, K::db_name())
+            .get_rate_bucket(user_id, T::db_name())
             .await?
-            .map(|saved_bucket| RateBucket {
-                capacity: K::capacity(),
-                refill_time_per_token: K::refill_duration(),
-                token_count: saved_bucket.token_count as usize,
-                last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+            .map(|saved_bucket| {
+                RateBucket::from_db::<T>(
+                    saved_bucket.token_count as usize,
+                    DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+                )
             }))
     }
 
@@ -124,15 +124,24 @@ struct RateBucket {
 }
 
 impl RateBucket {
-    fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
-        RateBucket {
-            capacity,
-            token_count: capacity,
-            refill_time_per_token: refill_duration / capacity as i32,
+    fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
+        Self {
+            capacity: T::capacity(),
+            token_count: T::capacity(),
+            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
             last_refill: now,
         }
     }
 
+    fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
+        Self {
+            capacity: T::capacity(),
+            token_count,
+            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+            last_refill,
+        }
+    }
+
     fn allow(&mut self, now: DateTimeUtc) -> bool {
         self.refill(now);
         if self.token_count > 0 {
@@ -148,9 +157,12 @@ impl RateBucket {
         if elapsed >= self.refill_time_per_token {
             let new_tokens =
                 elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
-
             self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
-            self.last_refill = now;
+
+            let unused_refill_time = Duration::milliseconds(
+                elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
+            );
+            self.last_refill = now - unused_refill_time;
         }
     }
 }
@@ -218,8 +230,19 @@ mod tests {
             .await
             .unwrap();
 
-        // After one second, user 1 can make another request before being rate-limited again.
-        now += Duration::seconds(1);
+        // After 1.5s, user 1 can make another request before being rate-limited again.
+        now += Duration::milliseconds(1500);
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap_err();
+
+        // After 500ms, user 1 can make another request before being rate-limited again.
+        now += Duration::milliseconds(500);
         rate_limiter
             .check_internal::<RateLimitA>(user_1, now)
             .await
@@ -238,6 +261,17 @@ mod tests {
             .check_internal::<RateLimitA>(user_1, now)
             .await
             .unwrap_err();
+
+        // After 1s, user 1 can make another request before being rate-limited again.
+        now += Duration::seconds(1);
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap_err();
     }
 
     struct RateLimitA;