1use crate::{Database, Error, Result, db::UserId, executor::Executor};
2use chrono::{DateTime, Duration, Utc};
3use dashmap::{DashMap, DashSet};
4use rpc::ErrorCodeExt;
5use sea_orm::prelude::DateTimeUtc;
6use std::sync::Arc;
7use util::ResultExt;
8
9pub trait RateLimit: Send + Sync {
10 fn capacity(&self) -> usize;
11 fn refill_duration(&self) -> Duration;
12 fn db_name(&self) -> &'static str;
13}
14
15/// Used to enforce per-user rate limits
16pub struct RateLimiter {
17 buckets: DashMap<(UserId, String), RateBucket>,
18 dirty_buckets: DashSet<(UserId, String)>,
19 db: Arc<Database>,
20}
21
22impl RateLimiter {
23 pub fn new(db: Arc<Database>) -> Self {
24 RateLimiter {
25 buckets: DashMap::new(),
26 dirty_buckets: DashSet::new(),
27 db,
28 }
29 }
30
31 /// Spawns a new task that periodically saves rate limit data to the database.
32 pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
33 const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
34
35 executor.clone().spawn_detached(async move {
36 loop {
37 executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
38 rate_limiter.save().await.log_err();
39 }
40 });
41 }
42
43 /// Returns an error if the user has exceeded the specified `RateLimit`.
44 /// Attempts to read the from the database if no cached RateBucket currently exists.
45 pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
46 self.check_internal(limit, user_id, Utc::now()).await
47 }
48
49 async fn check_internal(
50 &self,
51 limit: &dyn RateLimit,
52 user_id: UserId,
53 now: DateTimeUtc,
54 ) -> Result<()> {
55 let bucket_key = (user_id, limit.db_name().to_string());
56
57 // Attempt to fetch the bucket from the database if it hasn't been cached.
58 // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
59 // but this enforces limits across restarts so long as the database is reachable.
60 if !self.buckets.contains_key(&bucket_key) {
61 if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
62 self.buckets.insert(bucket_key.clone(), bucket);
63 self.dirty_buckets.insert(bucket_key.clone());
64 }
65 }
66
67 let mut bucket = self
68 .buckets
69 .entry(bucket_key.clone())
70 .or_insert_with(|| RateBucket::new(limit, now));
71
72 if bucket.value_mut().allow(now) {
73 self.dirty_buckets.insert(bucket_key);
74 Ok(())
75 } else {
76 Err(rpc::proto::ErrorCode::RateLimitExceeded
77 .message("rate limit exceeded".into())
78 .anyhow())?
79 }
80 }
81
82 async fn load_bucket(
83 &self,
84 limit: &dyn RateLimit,
85 user_id: UserId,
86 ) -> Result<Option<RateBucket>, Error> {
87 Ok(self
88 .db
89 .get_rate_bucket(user_id, limit.db_name())
90 .await?
91 .map(|saved_bucket| {
92 RateBucket::from_db(
93 limit,
94 saved_bucket.token_count as usize,
95 DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
96 )
97 }))
98 }
99
100 pub async fn save(&self) -> Result<()> {
101 let mut buckets = Vec::new();
102 self.dirty_buckets.retain(|key| {
103 if let Some(bucket) = self.buckets.get(key) {
104 buckets.push(crate::db::rate_buckets::Model {
105 user_id: key.0,
106 rate_limit_name: key.1.clone(),
107 token_count: bucket.token_count as i32,
108 last_refill: bucket.last_refill.naive_utc(),
109 });
110 }
111 false
112 });
113
114 match self.db.save_rate_buckets(&buckets).await {
115 Ok(()) => Ok(()),
116 Err(err) => {
117 for bucket in buckets {
118 self.dirty_buckets
119 .insert((bucket.user_id, bucket.rate_limit_name));
120 }
121 Err(err)
122 }
123 }
124 }
125}
126
127#[derive(Clone, Debug)]
128struct RateBucket {
129 capacity: usize,
130 token_count: usize,
131 refill_time_per_token: Duration,
132 last_refill: DateTimeUtc,
133}
134
135impl RateBucket {
136 fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
137 Self {
138 capacity: limit.capacity(),
139 token_count: limit.capacity(),
140 refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
141 last_refill: now,
142 }
143 }
144
145 fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
146 Self {
147 capacity: limit.capacity(),
148 token_count,
149 refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
150 last_refill,
151 }
152 }
153
154 fn allow(&mut self, now: DateTimeUtc) -> bool {
155 self.refill(now);
156 if self.token_count > 0 {
157 self.token_count -= 1;
158 true
159 } else {
160 false
161 }
162 }
163
164 fn refill(&mut self, now: DateTimeUtc) {
165 let elapsed = now - self.last_refill;
166 if elapsed >= self.refill_time_per_token {
167 let new_tokens =
168 elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
169 self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
170
171 let unused_refill_time = Duration::milliseconds(
172 elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
173 );
174 self.last_refill = now - unused_refill_time;
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Exercises token consumption, refill over time, isolation between users
    /// and between limits, and persistence across `RateLimiter` instances.
    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();
        let user_1 = db
            .create_user(
                "user-1@zed.dev",
                None,
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let user_2 = db
            .create_user(
                "user-2@zed.dev",
                None,
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        let mut now = Utc::now();

        let rate_limiter = RateLimiter::new(db.clone());
        let rate_limit_a = RateLimitA;
        let rate_limit_b = RateLimitB;

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // User 2 can access resource A and user 1 can access resource B:
        // limits are tracked per user and per resource.
        rate_limiter
            .check_internal(&rate_limit_a, user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_b, user_1, now)
            .await
            .unwrap();

        // After 1.5s, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(1500);
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // After 500ms, user 1 can make another request before being rate-limited again.
        now += Duration::milliseconds(500);
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A.
        let rate_limiter = RateLimiter::new(db.clone());
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap_err();

        // After 1s, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal(&rate_limit_a, user_1, now)
            .await
            .unwrap_err();
    }

    /// Test limit: 2 tokens replenished over 2 seconds (one token per second).
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity(&self) -> usize {
            2
        }

        fn refill_duration(&self) -> Duration {
            Duration::seconds(2)
        }

        fn db_name(&self) -> &'static str {
            "rate-limit-a"
        }
    }

    /// Test limit: 10 tokens replenished over 3 seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity(&self) -> usize {
            10
        }

        fn refill_duration(&self) -> Duration {
            Duration::seconds(3)
        }

        fn db_name(&self) -> &'static str {
            "rate-limit-b"
        }
    }
}