1use crate::{db::UserId, executor::Executor, Database, Error, Result};
2use anyhow::anyhow;
3use chrono::{DateTime, Duration, Utc};
4use dashmap::{DashMap, DashSet};
5use sea_orm::prelude::DateTimeUtc;
6use std::sync::Arc;
7use util::ResultExt;
8
9pub trait RateLimit: 'static {
10 fn capacity() -> usize;
11 fn refill_duration() -> Duration;
12 fn db_name() -> &'static str;
13}
14
15/// Used to enforce per-user rate limits
16pub struct RateLimiter {
17 buckets: DashMap<(UserId, String), RateBucket>,
18 dirty_buckets: DashSet<(UserId, String)>,
19 db: Arc<Database>,
20}
21
22impl RateLimiter {
23 pub fn new(db: Arc<Database>) -> Self {
24 RateLimiter {
25 buckets: DashMap::new(),
26 dirty_buckets: DashSet::new(),
27 db,
28 }
29 }
30
31 /// Spawns a new task that periodically saves rate limit data to the database.
32 pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
33 const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
34
35 executor.clone().spawn_detached(async move {
36 loop {
37 executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
38 rate_limiter.save().await.log_err();
39 }
40 });
41 }
42
43 /// Returns an error if the user has exceeded the specified `RateLimit`.
44 /// Attempts to read the from the database if no cached RateBucket currently exists.
45 pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
46 self.check_internal::<T>(user_id, Utc::now()).await
47 }
48
49 async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
50 let bucket_key = (user_id, T::db_name().to_string());
51
52 // Attempt to fetch the bucket from the database if it hasn't been cached.
53 // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
54 // but this enforces limits across restarts so long as the database is reachable.
55 if !self.buckets.contains_key(&bucket_key) {
56 if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
57 self.buckets.insert(bucket_key.clone(), bucket);
58 self.dirty_buckets.insert(bucket_key.clone());
59 }
60 }
61
62 let mut bucket = self
63 .buckets
64 .entry(bucket_key.clone())
65 .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
66
67 if bucket.value_mut().allow(now) {
68 self.dirty_buckets.insert(bucket_key);
69 Ok(())
70 } else {
71 Err(anyhow!("rate limit exceeded"))?
72 }
73 }
74
75 async fn load_bucket<K: RateLimit>(
76 &self,
77 user_id: UserId,
78 ) -> Result<Option<RateBucket>, Error> {
79 Ok(self
80 .db
81 .get_rate_bucket(user_id, K::db_name())
82 .await?
83 .map(|saved_bucket| RateBucket {
84 capacity: K::capacity(),
85 refill_time_per_token: K::refill_duration(),
86 token_count: saved_bucket.token_count as usize,
87 last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
88 }))
89 }
90
91 pub async fn save(&self) -> Result<()> {
92 let mut buckets = Vec::new();
93 self.dirty_buckets.retain(|key| {
94 if let Some(bucket) = self.buckets.get(&key) {
95 buckets.push(crate::db::rate_buckets::Model {
96 user_id: key.0,
97 rate_limit_name: key.1.clone(),
98 token_count: bucket.token_count as i32,
99 last_refill: bucket.last_refill.naive_utc(),
100 });
101 }
102 false
103 });
104
105 match self.db.save_rate_buckets(&buckets).await {
106 Ok(()) => Ok(()),
107 Err(err) => {
108 for bucket in buckets {
109 self.dirty_buckets
110 .insert((bucket.user_id, bucket.rate_limit_name));
111 }
112 Err(err)
113 }
114 }
115 }
116}
117
118#[derive(Clone)]
119struct RateBucket {
120 capacity: usize,
121 token_count: usize,
122 refill_time_per_token: Duration,
123 last_refill: DateTimeUtc,
124}
125
126impl RateBucket {
127 fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
128 RateBucket {
129 capacity,
130 token_count: capacity,
131 refill_time_per_token: refill_duration / capacity as i32,
132 last_refill: now,
133 }
134 }
135
136 fn allow(&mut self, now: DateTimeUtc) -> bool {
137 self.refill(now);
138 if self.token_count > 0 {
139 self.token_count -= 1;
140 true
141 } else {
142 false
143 }
144 }
145
146 fn refill(&mut self, now: DateTimeUtc) {
147 let elapsed = now - self.last_refill;
148 if elapsed >= self.refill_time_per_token {
149 let new_tokens =
150 elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
151
152 self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
153 self.last_refill = now;
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Two requests per two seconds: one token comes back every second.
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn db_name() -> &'static str {
            "rate-limit-a"
        }

        fn capacity() -> usize {
            2
        }

        fn refill_duration() -> Duration {
            Duration::seconds(2)
        }
    }

    /// Ten requests per three seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn db_name() -> &'static str {
            "rate-limit-b"
        }

        fn capacity() -> usize {
            10
        }

        fn refill_duration() -> Duration {
            Duration::seconds(3)
        }
    }

    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();

        let first_user = db
            .create_user(
                "user-1@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-1".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;
        let second_user = db
            .create_user(
                "user-2@zed.dev",
                false,
                NewUserParams {
                    github_login: "user-2".into(),
                    github_user_id: 2,
                },
            )
            .await
            .unwrap()
            .user_id;

        let start = Utc::now();
        let limiter = RateLimiter::new(db.clone());

        // Resource A admits the first user exactly twice before limiting.
        assert!(limiter.check_internal::<RateLimitA>(first_user, start).await.is_ok());
        assert!(limiter.check_internal::<RateLimitA>(first_user, start).await.is_ok());
        assert!(limiter.check_internal::<RateLimitA>(first_user, start).await.is_err());

        // Buckets are independent per user and per limit: the second user can still use
        // resource B, and so can the first user.
        assert!(limiter.check_internal::<RateLimitB>(second_user, start).await.is_ok());
        assert!(limiter.check_internal::<RateLimitB>(first_user, start).await.is_ok());

        // One second later a single resource-A token has been refilled for the first user.
        let one_second_later = start + Duration::seconds(1);
        assert!(limiter
            .check_internal::<RateLimitA>(first_user, one_second_later)
            .await
            .is_ok());
        assert!(limiter
            .check_internal::<RateLimitA>(first_user, one_second_later)
            .await
            .is_err());

        limiter.save().await.unwrap();

        // A fresh limiter reloads the saved bucket from the database, so the first user is
        // still rate-limited on resource A.
        let reloaded = RateLimiter::new(db.clone());
        assert!(reloaded
            .check_internal::<RateLimitA>(first_user, one_second_later)
            .await
            .is_err());
    }
}