@@ -62,7 +62,7 @@ impl RateLimiter {
let mut bucket = self
.buckets
.entry(bucket_key.clone())
- .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
+ .or_insert_with(|| RateBucket::new::<T>(now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
@@ -72,19 +72,19 @@ impl RateLimiter {
}
}
- async fn load_bucket<K: RateLimit>(
+ async fn load_bucket<T: RateLimit>(
&self,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
- .get_rate_bucket(user_id, K::db_name())
+ .get_rate_bucket(user_id, T::db_name())
.await?
- .map(|saved_bucket| RateBucket {
- capacity: K::capacity(),
- refill_time_per_token: K::refill_duration(),
- token_count: saved_bucket.token_count as usize,
- last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+ .map(|saved_bucket| {
+ RateBucket::from_db::<T>(
+ saved_bucket.token_count as usize,
+ DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+ )
}))
}
@@ -124,15 +124,24 @@ struct RateBucket {
}
impl RateBucket {
- fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
- RateBucket {
- capacity,
- token_count: capacity,
- refill_time_per_token: refill_duration / capacity as i32,
+ fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
+ Self {
+ capacity: T::capacity(),
+ token_count: T::capacity(),
+ refill_time_per_token: T::refill_duration() / T::capacity() as i32,
last_refill: now,
}
}
+ fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
+ Self {
+ capacity: T::capacity(),
+ token_count,
+ refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+ last_refill,
+ }
+ }
+
fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now);
if self.token_count > 0 {
@@ -148,9 +157,12 @@ impl RateBucket {
if elapsed >= self.refill_time_per_token {
let new_tokens =
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
-
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
- self.last_refill = now;
+
+ let unused_refill_time = Duration::milliseconds(
+ elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
+ );
+ self.last_refill = now - unused_refill_time;
}
}
}
@@ -218,8 +230,19 @@ mod tests {
.await
.unwrap();
- // After one second, user 1 can make another request before being rate-limited again.
- now += Duration::seconds(1);
+ // After 1.5s, user 1 can make another request before being rate-limited again.
+ now += Duration::milliseconds(1500);
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap_err();
+
+ // After 500ms, user 1 can make another request before being rate-limited again.
+ now += Duration::milliseconds(500);
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
@@ -238,6 +261,17 @@ mod tests {
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
+
+ // After 1s, user 1 can make another request before being rate-limited again.
+ now += Duration::seconds(1);
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap_err();
}
struct RateLimitA;