crates/collab/src/db/queries.rs 🔗
@@ -14,7 +14,6 @@ pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
-pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;
Marshall Bowers created
This PR removes the `RateLimiter` from the collab codebase, as it is no
longer used.
Release Notes:
- N/A
crates/collab/src/db/queries.rs | 1
crates/collab/src/db/queries/rate_buckets.rs | 58 ---
crates/collab/src/db/tables.rs | 1
crates/collab/src/db/tables/rate_buckets.rs | 31 --
crates/collab/src/lib.rs | 4
crates/collab/src/main.rs | 8
crates/collab/src/rate_limiter.rs | 321 ----------------------
crates/collab/src/tests/test_server.rs | 3
8 files changed, 3 insertions(+), 424 deletions(-)
@@ -14,7 +14,6 @@ pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
-pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;
@@ -1,58 +0,0 @@
-use super::*;
-use crate::db::tables::rate_buckets;
-use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
-
-impl Database {
- /// Saves the rate limit for the given user and rate limit name if the last_refill is later
- /// than the currently saved timestamp.
- pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
- if buckets.is_empty() {
- return Ok(());
- }
-
- self.transaction(|tx| async move {
- rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
- rate_buckets::ActiveModel {
- user_id: ActiveValue::Set(bucket.user_id),
- rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
- token_count: ActiveValue::Set(bucket.token_count),
- last_refill: ActiveValue::Set(bucket.last_refill),
- }
- }))
- .on_conflict(
- OnConflict::columns([
- rate_buckets::Column::UserId,
- rate_buckets::Column::RateLimitName,
- ])
- .update_columns([
- rate_buckets::Column::TokenCount,
- rate_buckets::Column::LastRefill,
- ])
- .to_owned(),
- )
- .exec(&*tx)
- .await?;
-
- Ok(())
- })
- .await
- }
-
- /// Retrieves the rate limit for the given user and rate limit name.
- pub async fn get_rate_bucket(
- &self,
- user_id: UserId,
- rate_limit_name: &str,
- ) -> Result<Option<rate_buckets::Model>> {
- self.transaction(|tx| async move {
- let rate_limit = rate_buckets::Entity::find()
- .filter(rate_buckets::Column::UserId.eq(user_id))
- .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
- .one(&*tx)
- .await?;
-
- Ok(rate_limit)
- })
- .await
- }
-}
@@ -28,7 +28,6 @@ pub mod project;
pub mod project_collaborator;
pub mod project_repository;
pub mod project_repository_statuses;
-pub mod rate_buckets;
pub mod room;
pub mod room_participant;
pub mod server;
@@ -1,31 +0,0 @@
-use crate::db::UserId;
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
-#[sea_orm(table_name = "rate_buckets")]
-pub struct Model {
- #[sea_orm(primary_key, auto_increment = false)]
- pub user_id: UserId,
- #[sea_orm(primary_key, auto_increment = false)]
- pub rate_limit_name: String,
- pub token_count: i32,
- pub last_refill: DateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::user::Entity",
- from = "Column::UserId",
- to = "super::user::Column::Id"
- )]
- User,
-}
-
-impl Related<super::user::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::User.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -6,7 +6,6 @@ pub mod env;
pub mod executor;
pub mod llm;
pub mod migrations;
-mod rate_limiter;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
@@ -25,7 +24,6 @@ pub use cents::*;
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
-pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@@ -295,7 +293,6 @@ pub struct AppState {
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
- pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
pub config: Config,
@@ -348,7 +345,6 @@ impl AppState {
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
- rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
@@ -13,8 +13,8 @@ use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
- AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
- env, executor::Executor, rpc::ResultExt,
+ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
+ executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
@@ -111,10 +111,6 @@ async fn main() -> Result<()> {
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();
- RateLimiter::save_periodically(
- state.rate_limiter.clone(),
- state.executor.clone(),
- );
let epoch = state
.db
@@ -1,321 +0,0 @@
-use crate::{Database, Error, Result, db::UserId, executor::Executor};
-use chrono::{DateTime, Duration, Utc};
-use dashmap::{DashMap, DashSet};
-use rpc::ErrorCodeExt;
-use sea_orm::prelude::DateTimeUtc;
-use std::sync::Arc;
-use util::ResultExt;
-
-pub trait RateLimit: Send + Sync {
- fn capacity(&self) -> usize;
- fn refill_duration(&self) -> Duration;
- fn db_name(&self) -> &'static str;
-}
-
-/// Used to enforce per-user rate limits
-pub struct RateLimiter {
- buckets: DashMap<(UserId, String), RateBucket>,
- dirty_buckets: DashSet<(UserId, String)>,
- db: Arc<Database>,
-}
-
-impl RateLimiter {
- pub fn new(db: Arc<Database>) -> Self {
- RateLimiter {
- buckets: DashMap::new(),
- dirty_buckets: DashSet::new(),
- db,
- }
- }
-
- /// Spawns a new task that periodically saves rate limit data to the database.
- pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
- const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
-
- executor.clone().spawn_detached(async move {
- loop {
- executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
- rate_limiter.save().await.log_err();
- }
- });
- }
-
- /// Returns an error if the user has exceeded the specified `RateLimit`.
- /// Attempts to read the from the database if no cached RateBucket currently exists.
- pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
- self.check_internal(limit, user_id, Utc::now()).await
- }
-
- async fn check_internal(
- &self,
- limit: &dyn RateLimit,
- user_id: UserId,
- now: DateTimeUtc,
- ) -> Result<()> {
- let bucket_key = (user_id, limit.db_name().to_string());
-
- // Attempt to fetch the bucket from the database if it hasn't been cached.
- // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
- // but this enforces limits across restarts so long as the database is reachable.
- if !self.buckets.contains_key(&bucket_key) {
- if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
- self.buckets.insert(bucket_key.clone(), bucket);
- self.dirty_buckets.insert(bucket_key.clone());
- }
- }
-
- let mut bucket = self
- .buckets
- .entry(bucket_key.clone())
- .or_insert_with(|| RateBucket::new(limit, now));
-
- if bucket.value_mut().allow(now) {
- self.dirty_buckets.insert(bucket_key);
- Ok(())
- } else {
- Err(rpc::proto::ErrorCode::RateLimitExceeded
- .message("rate limit exceeded".into())
- .anyhow())?
- }
- }
-
- async fn load_bucket(
- &self,
- limit: &dyn RateLimit,
- user_id: UserId,
- ) -> Result<Option<RateBucket>, Error> {
- Ok(self
- .db
- .get_rate_bucket(user_id, limit.db_name())
- .await?
- .map(|saved_bucket| {
- RateBucket::from_db(
- limit,
- saved_bucket.token_count as usize,
- DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
- )
- }))
- }
-
- pub async fn save(&self) -> Result<()> {
- let mut buckets = Vec::new();
- self.dirty_buckets.retain(|key| {
- if let Some(bucket) = self.buckets.get(key) {
- buckets.push(crate::db::rate_buckets::Model {
- user_id: key.0,
- rate_limit_name: key.1.clone(),
- token_count: bucket.token_count as i32,
- last_refill: bucket.last_refill.naive_utc(),
- });
- }
- false
- });
-
- match self.db.save_rate_buckets(&buckets).await {
- Ok(()) => Ok(()),
- Err(err) => {
- for bucket in buckets {
- self.dirty_buckets
- .insert((bucket.user_id, bucket.rate_limit_name));
- }
- Err(err)
- }
- }
- }
-}
-
-#[derive(Clone, Debug)]
-struct RateBucket {
- capacity: usize,
- token_count: usize,
- refill_time_per_token: Duration,
- last_refill: DateTimeUtc,
-}
-
-impl RateBucket {
- fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
- Self {
- capacity: limit.capacity(),
- token_count: limit.capacity(),
- refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
- last_refill: now,
- }
- }
-
- fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
- Self {
- capacity: limit.capacity(),
- token_count,
- refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
- last_refill,
- }
- }
-
- fn allow(&mut self, now: DateTimeUtc) -> bool {
- self.refill(now);
- if self.token_count > 0 {
- self.token_count -= 1;
- true
- } else {
- false
- }
- }
-
- fn refill(&mut self, now: DateTimeUtc) {
- let elapsed = now - self.last_refill;
- if elapsed >= self.refill_time_per_token {
- let new_tokens =
- elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
- self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
-
- let unused_refill_time = Duration::milliseconds(
- elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
- );
- self.last_refill = now - unused_refill_time;
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::db::{NewUserParams, TestDb};
- use gpui::TestAppContext;
-
- #[gpui::test]
- async fn test_rate_limiter(cx: &mut TestAppContext) {
- let test_db = TestDb::sqlite(cx.executor().clone());
- let db = test_db.db().clone();
- let user_1 = db
- .create_user(
- "user-1@zed.dev",
- None,
- false,
- NewUserParams {
- github_login: "user-1".into(),
- github_user_id: 1,
- },
- )
- .await
- .unwrap()
- .user_id;
- let user_2 = db
- .create_user(
- "user-2@zed.dev",
- None,
- false,
- NewUserParams {
- github_login: "user-2".into(),
- github_user_id: 2,
- },
- )
- .await
- .unwrap()
- .user_id;
-
- let mut now = Utc::now();
-
- let rate_limiter = RateLimiter::new(db.clone());
- let rate_limit_a = Box::new(RateLimitA);
- let rate_limit_b = Box::new(RateLimitB);
-
- // User 1 can access resource A two times before being rate-limited.
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap_err();
-
- // User 2 can access resource A and user 1 can access resource B.
- rate_limiter
- .check_internal(&*rate_limit_b, user_2, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_b, user_1, now)
- .await
- .unwrap();
-
- // After 1.5s, user 1 can make another request before being rate-limited again.
- now += Duration::milliseconds(1500);
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap_err();
-
- // After 500ms, user 1 can make another request before being rate-limited again.
- now += Duration::milliseconds(500);
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap_err();
-
- rate_limiter.save().await.unwrap();
-
- // Rate limits are reloaded from the database, so user A is still rate-limited
- // for resource A.
- let rate_limiter = RateLimiter::new(db.clone());
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap_err();
-
- // After 1s, user 1 can make another request before being rate-limited again.
- now += Duration::seconds(1);
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap();
- rate_limiter
- .check_internal(&*rate_limit_a, user_1, now)
- .await
- .unwrap_err();
- }
-
- struct RateLimitA;
-
- impl RateLimit for RateLimitA {
- fn capacity(&self) -> usize {
- 2
- }
-
- fn refill_duration(&self) -> Duration {
- Duration::seconds(2)
- }
-
- fn db_name(&self) -> &'static str {
- "rate-limit-a"
- }
- }
-
- struct RateLimitB;
-
- impl RateLimit for RateLimitB {
- fn capacity(&self) -> usize {
- 10
- }
-
- fn refill_duration(&self) -> Duration {
- Duration::seconds(3)
- }
-
- fn db_name(&self) -> &'static str {
- "rate-limit-b"
- }
- }
-}
@@ -1,5 +1,5 @@
use crate::{
- AppState, Config, RateLimiter,
+ AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
executor::Executor,
rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
@@ -517,7 +517,6 @@ impl TestServer {
blob_store_client: None,
stripe_client: None,
stripe_billing: None,
- rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
kinesis_client: None,
config: Config {