1use crate::{db::UserId, executor::Executor, Database, Error, Result};
2use anyhow::anyhow;
3use chrono::{DateTime, Duration, Utc};
4use dashmap::{DashMap, DashSet};
5use sea_orm::prelude::DateTimeUtc;
6use std::sync::Arc;
7use util::ResultExt;
8
9pub trait RateLimit: 'static {
10 fn capacity() -> usize;
11 fn refill_duration() -> Duration;
12 fn db_name() -> &'static str;
13}
14
15/// Used to enforce per-user rate limits
16pub struct RateLimiter {
17 buckets: DashMap<(UserId, String), RateBucket>,
18 dirty_buckets: DashSet<(UserId, String)>,
19 db: Arc<Database>,
20}
21
22impl RateLimiter {
23 pub fn new(db: Arc<Database>) -> Self {
24 RateLimiter {
25 buckets: DashMap::new(),
26 dirty_buckets: DashSet::new(),
27 db,
28 }
29 }
30
31 /// Spawns a new task that periodically saves rate limit data to the database.
32 pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
33 const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
34
35 executor.clone().spawn_detached(async move {
36 loop {
37 executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
38 rate_limiter.save().await.log_err();
39 }
40 });
41 }
42
43 /// Returns an error if the user has exceeded the specified `RateLimit`.
44 /// Attempts to read the from the database if no cached RateBucket currently exists.
45 pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
46 self.check_internal::<T>(user_id, Utc::now()).await
47 }
48
49 async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
50 let bucket_key = (user_id, T::db_name().to_string());
51
52 // Attempt to fetch the bucket from the database if it hasn't been cached.
53 // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
54 // but this enforces limits across restarts so long as the database is reachable.
55 if !self.buckets.contains_key(&bucket_key) {
56 if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
57 self.buckets.insert(bucket_key.clone(), bucket);
58 self.dirty_buckets.insert(bucket_key.clone());
59 }
60 }
61
62 let mut bucket = self
63 .buckets
64 .entry(bucket_key.clone())
65 .or_insert_with(|| RateBucket::new::<T>(now));
66
67 if bucket.value_mut().allow(now) {
68 self.dirty_buckets.insert(bucket_key);
69 Ok(())
70 } else {
71 Err(anyhow!("rate limit exceeded"))?
72 }
73 }
74
75 async fn load_bucket<T: RateLimit>(
76 &self,
77 user_id: UserId,
78 ) -> Result<Option<RateBucket>, Error> {
79 Ok(self
80 .db
81 .get_rate_bucket(user_id, T::db_name())
82 .await?
83 .map(|saved_bucket| {
84 RateBucket::from_db::<T>(
85 saved_bucket.token_count as usize,
86 DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
87 )
88 }))
89 }
90
91 pub async fn save(&self) -> Result<()> {
92 let mut buckets = Vec::new();
93 self.dirty_buckets.retain(|key| {
94 if let Some(bucket) = self.buckets.get(&key) {
95 buckets.push(crate::db::rate_buckets::Model {
96 user_id: key.0,
97 rate_limit_name: key.1.clone(),
98 token_count: bucket.token_count as i32,
99 last_refill: bucket.last_refill.naive_utc(),
100 });
101 }
102 false
103 });
104
105 match self.db.save_rate_buckets(&buckets).await {
106 Ok(()) => Ok(()),
107 Err(err) => {
108 for bucket in buckets {
109 self.dirty_buckets
110 .insert((bucket.user_id, bucket.rate_limit_name));
111 }
112 Err(err)
113 }
114 }
115 }
116}
117
118#[derive(Clone)]
119struct RateBucket {
120 capacity: usize,
121 token_count: usize,
122 refill_time_per_token: Duration,
123 last_refill: DateTimeUtc,
124}
125
126impl RateBucket {
127 fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
128 Self {
129 capacity: T::capacity(),
130 token_count: T::capacity(),
131 refill_time_per_token: T::refill_duration() / T::capacity() as i32,
132 last_refill: now,
133 }
134 }
135
136 fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
137 Self {
138 capacity: T::capacity(),
139 token_count,
140 refill_time_per_token: T::refill_duration() / T::capacity() as i32,
141 last_refill,
142 }
143 }
144
145 fn allow(&mut self, now: DateTimeUtc) -> bool {
146 self.refill(now);
147 if self.token_count > 0 {
148 self.token_count -= 1;
149 true
150 } else {
151 false
152 }
153 }
154
155 fn refill(&mut self, now: DateTimeUtc) {
156 let elapsed = now - self.last_refill;
157 if elapsed >= self.refill_time_per_token {
158 let new_tokens =
159 elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
160 self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
161
162 let unused_refill_time = Duration::milliseconds(
163 elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
164 );
165 self.last_refill = now - unused_refill_time;
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Exercises the limiter end to end with a controlled clock: exhaustion,
    /// isolation between users and between limits, refill over time, and
    /// persistence across a fresh `RateLimiter` on the same database.
    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();
        let user_1 = db
            .create_user(
                "user-1@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let user_2 = db
            .create_user(
                "user-2@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        let mut now = Utc::now();

        let rate_limiter = RateLimiter::new(db.clone());

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // User 2 can access resource A and user 1 can access resource B.
        // Limits are tracked per user and per resource, so user 1 exhausting
        // resource A affects neither of these.
        rate_limiter
            .check_internal::<RateLimitA>(user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitB>(user_1, now)
            .await
            .unwrap();

        // After 1.5s, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(1500);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // After 500ms, user 1 can make another request before being rate-limited again.
        // Together with the 500ms left over from the previous step, this earns one token.
        now += Duration::milliseconds(500);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A.
        let rate_limiter = RateLimiter::new(db.clone());
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // After 1s, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();
    }

    /// Test limit: 2 tokens, refilled over 2 seconds (one token per second).
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity() -> usize {
            2
        }

        fn refill_duration() -> Duration {
            Duration::seconds(2)
        }

        fn db_name() -> &'static str {
            "rate-limit-a"
        }
    }

    /// Test limit: 10 tokens, refilled over 3 seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity() -> usize {
            10
        }

        fn refill_duration() -> Duration {
            Duration::seconds(3)
        }

        fn db_name() -> &'static str {
            "rate-limit-b"
        }
    }
}