collab: Adapt rate limits based on plan (#15548)

Marshall Bowers and Max created

This PR updates the rate limits to adapt based on the user's current
plan.

For the free-plan rate limits, I just took one-tenth of the existing rate
limits (which are now the Pro limits). We can adjust them as needed.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/src/api/billing.rs                         |  26 -
crates/collab/src/db/queries/billing_subscriptions.rs    |  19 
crates/collab/src/db/tests/billing_subscription_tests.rs |  16 
crates/collab/src/rate_limiter.rs                        |  85 +++--
crates/collab/src/rpc.rs                                 | 155 +++++++--
5 files changed, 195 insertions(+), 106 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -169,9 +169,7 @@ struct ManageBillingSubscriptionBody {
     github_user_id: i32,
     intent: ManageSubscriptionIntent,
     /// The ID of the subscription to manage.
-    ///
-    /// If not provided, we will try to use the active subscription (if there is only one).
-    subscription_id: Option<BillingSubscriptionId>,
+    subscription_id: BillingSubscriptionId,
 }
 
 #[derive(Debug, Serialize)]
@@ -206,23 +204,11 @@ async fn manage_billing_subscription(
     let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
         .context("failed to parse customer ID")?;
 
-    let subscription = if let Some(subscription_id) = body.subscription_id {
-        app.db
-            .get_billing_subscription_by_id(subscription_id)
-            .await?
-            .ok_or_else(|| anyhow!("subscription not found"))?
-    } else {
-        // If no subscription ID was provided, try to find the only active subscription ID.
-        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
-        if subscriptions.len() > 1 {
-            Err(anyhow!("user has multiple active subscriptions"))?;
-        }
-
-        subscriptions
-            .into_iter()
-            .next()
-            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
-    };
+    let subscription = app
+        .db
+        .get_billing_subscription_by_id(body.subscription_id)
+        .await?
+        .ok_or_else(|| anyhow!("subscription not found"))?;
 
     let flow = match body.intent {
         ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -110,13 +110,15 @@ impl Database {
         .await
     }
 
-    /// Returns all of the active billing subscriptions for the user with the specified ID.
-    pub async fn get_active_billing_subscriptions(
-        &self,
-        user_id: UserId,
-    ) -> Result<Vec<billing_subscription::Model>> {
+    /// Returns whether the user has an active billing subscription.
+    pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
+        Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
+    }
+
+    /// Returns the count of the active billing subscriptions for the user with the specified ID.
+    pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
         self.transaction(|tx| async move {
-            let subscriptions = billing_subscription::Entity::find()
+            let count = billing_subscription::Entity::find()
                 .inner_join(billing_customer::Entity)
                 .filter(
                     billing_customer::Column::UserId.eq(user_id).and(
@@ -124,11 +126,10 @@ impl Database {
                             .eq(StripeSubscriptionStatus::Active),
                     ),
                 )
-                .order_by_asc(billing_subscription::Column::Id)
-                .all(&*tx)
+                .count(&*tx)
                 .await?;
 
-            Ok(subscriptions)
+            Ok(count as usize)
         })
         .await
     }

crates/collab/src/db/tests/billing_subscription_tests.rs 🔗

@@ -17,9 +17,12 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
     // A user with no subscription has no active billing subscriptions.
     {
         let user_id = new_test_user(db, "no-subscription-user@example.com").await;
-        let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
+        let subscription_count = db
+            .count_active_billing_subscriptions(user_id)
+            .await
+            .unwrap();
 
-        assert_eq!(subscriptions.len(), 0);
+        assert_eq!(subscription_count, 0);
     }
 
     // A user with an active subscription has one active billing subscription.
@@ -42,7 +45,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
         .await
         .unwrap();
 
-        let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
+        let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
         assert_eq!(subscriptions.len(), 1);
 
         let subscription = &subscriptions[0];
@@ -76,7 +79,10 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
         .await
         .unwrap();
 
-        let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
-        assert_eq!(subscriptions.len(), 0);
+        let subscription_count = db
+            .count_active_billing_subscriptions(user_id)
+            .await
+            .unwrap();
+        assert_eq!(subscription_count, 0);
     }
 }

crates/collab/src/rate_limiter.rs 🔗

@@ -6,10 +6,10 @@ use sea_orm::prelude::DateTimeUtc;
 use std::sync::Arc;
 use util::ResultExt;
 
-pub trait RateLimit: 'static {
-    fn capacity() -> usize;
-    fn refill_duration() -> Duration;
-    fn db_name() -> &'static str;
+pub trait RateLimit: Send + Sync {
+    fn capacity(&self) -> usize;
+    fn refill_duration(&self) -> Duration;
+    fn db_name(&self) -> &'static str;
 }
 
 /// Used to enforce per-user rate limits
@@ -42,18 +42,23 @@ impl RateLimiter {
 
     /// Returns an error if the user has exceeded the specified `RateLimit`.
     /// Attempts to read the from the database if no cached RateBucket currently exists.
-    pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
-        self.check_internal::<T>(user_id, Utc::now()).await
+    pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
+        self.check_internal(limit, user_id, Utc::now()).await
     }
 
-    async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
-        let bucket_key = (user_id, T::db_name().to_string());
+    async fn check_internal(
+        &self,
+        limit: &dyn RateLimit,
+        user_id: UserId,
+        now: DateTimeUtc,
+    ) -> Result<()> {
+        let bucket_key = (user_id, limit.db_name().to_string());
 
         // Attempt to fetch the bucket from the database if it hasn't been cached.
         // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
         // but this enforces limits across restarts so long as the database is reachable.
         if !self.buckets.contains_key(&bucket_key) {
-            if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
+            if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
                 self.buckets.insert(bucket_key.clone(), bucket);
                 self.dirty_buckets.insert(bucket_key.clone());
             }
@@ -62,7 +67,7 @@ impl RateLimiter {
         let mut bucket = self
             .buckets
             .entry(bucket_key.clone())
-            .or_insert_with(|| RateBucket::new::<T>(now));
+            .or_insert_with(|| RateBucket::new(limit, now));
 
         if bucket.value_mut().allow(now) {
             self.dirty_buckets.insert(bucket_key);
@@ -72,16 +77,18 @@ impl RateLimiter {
         }
     }
 
-    async fn load_bucket<T: RateLimit>(
+    async fn load_bucket(
         &self,
+        limit: &dyn RateLimit,
         user_id: UserId,
     ) -> Result<Option<RateBucket>, Error> {
         Ok(self
             .db
-            .get_rate_bucket(user_id, T::db_name())
+            .get_rate_bucket(user_id, limit.db_name())
             .await?
             .map(|saved_bucket| {
-                RateBucket::from_db::<T>(
+                RateBucket::from_db(
+                    limit,
                     saved_bucket.token_count as usize,
                     DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
                 )
@@ -124,20 +131,20 @@ struct RateBucket {
 }
 
 impl RateBucket {
-    fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
+    fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
         Self {
-            capacity: T::capacity(),
-            token_count: T::capacity(),
-            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+            capacity: limit.capacity(),
+            token_count: limit.capacity(),
+            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
             last_refill: now,
         }
     }
 
-    fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
+    fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
         Self {
-            capacity: T::capacity(),
+            capacity: limit.capacity(),
             token_count,
-            refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
             last_refill,
         }
     }
@@ -205,50 +212,52 @@ mod tests {
         let mut now = Utc::now();
 
         let rate_limiter = RateLimiter::new(db.clone());
+        let rate_limit_a = Box::new(RateLimitA);
+        let rate_limit_b = Box::new(RateLimitB);
 
         // User 1 can access resource A two times before being rate-limited.
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap_err();
 
         // User 2 can access resource A and user 1 can access resource B.
         rate_limiter
-            .check_internal::<RateLimitB>(user_2, now)
+            .check_internal(&*rate_limit_b, user_2, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitB>(user_1, now)
+            .check_internal(&*rate_limit_b, user_1, now)
             .await
             .unwrap();
 
         // After 1.5s, user 1 can make another request before being rate-limited again.
         now += Duration::milliseconds(1500);
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap_err();
 
         // After 500ms, user 1 can make another request before being rate-limited again.
         now += Duration::milliseconds(500);
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap_err();
 
@@ -258,18 +267,18 @@ mod tests {
         // for resource A.
         let rate_limiter = RateLimiter::new(db.clone());
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap_err();
 
         // After 1s, user 1 can make another request before being rate-limited again.
         now += Duration::seconds(1);
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap();
         rate_limiter
-            .check_internal::<RateLimitA>(user_1, now)
+            .check_internal(&*rate_limit_a, user_1, now)
             .await
             .unwrap_err();
     }
@@ -277,15 +286,15 @@ mod tests {
     struct RateLimitA;
 
     impl RateLimit for RateLimitA {
-        fn capacity() -> usize {
+        fn capacity(&self) -> usize {
             2
         }
 
-        fn refill_duration() -> Duration {
+        fn refill_duration(&self) -> Duration {
             Duration::seconds(2)
         }
 
-        fn db_name() -> &'static str {
+        fn db_name(&self) -> &'static str {
             "rate-limit-a"
         }
     }
@@ -293,15 +302,15 @@ mod tests {
     struct RateLimitB;
 
     impl RateLimit for RateLimitB {
-        fn capacity() -> usize {
+        fn capacity(&self) -> usize {
             10
         }
 
-        fn refill_duration() -> Duration {
+        fn refill_duration(&self) -> Duration {
             Duration::seconds(3)
         }
 
-        fn db_name() -> &'static str {
+        fn db_name(&self) -> &'static str {
             "rate-limit-b"
         }
     }

crates/collab/src/rpc.rs 🔗

@@ -199,6 +199,23 @@ impl Session {
         }
     }
 
+    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
+        if self.is_staff() {
+            return Ok(proto::Plan::ZedPro);
+        }
+
+        let Some(user_id) = self.user_id() else {
+            return Ok(proto::Plan::Free);
+        };
+
+        let db = self.db().await;
+        if db.has_active_billing_subscription(user_id).await? {
+            Ok(proto::Plan::ZedPro)
+        } else {
+            Ok(proto::Plan::Free)
+        }
+    }
+
     fn dev_server_id(&self) -> Option<DevServerId> {
         match &self.principal {
             Principal::User(_) | Principal::Impersonated { .. } => None,
@@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
     version.0.minor() < 139
 }
 
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
-    let db = session.db().await;
-    let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
-
-    let plan = if session.is_staff() || !active_subscriptions.is_empty() {
-        proto::Plan::ZedPro
-    } else {
-        proto::Plan::Free
-    };
+async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
+    let plan = session.current_plan().await?;
 
     session
         .peer
@@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
-struct CompleteWithLanguageModelRateLimit;
+struct ZedProCompleteWithLanguageModelRateLimit;
 
-impl RateLimit for CompleteWithLanguageModelRateLimit {
-    fn capacity() -> usize {
+impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
+    fn capacity(&self) -> usize {
         std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(120) // Picked arbitrarily
     }
 
-    fn refill_duration() -> chrono::Duration {
+    fn refill_duration(&self) -> chrono::Duration {
         chrono::Duration::hours(1)
     }
 
-    fn db_name() -> &'static str {
-        "complete-with-language-model"
+    fn db_name(&self) -> &'static str {
+        "zed-pro:complete-with-language-model"
+    }
+}
+
+struct FreeCompleteWithLanguageModelRateLimit;
+
+impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
+    fn capacity(&self) -> usize {
+        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(120 / 10) // Picked arbitrarily
+    }
+
+    fn refill_duration(&self) -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name(&self) -> &'static str {
+        "free:complete-with-language-model"
     }
 }
 
@@ -4562,9 +4591,14 @@ async fn complete_with_language_model(
     };
     authorize_access_to_language_models(&session).await?;
 
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
+        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
+    };
+
     session
         .rate_limiter
-        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .check(&*rate_limit, session.user_id())
         .await?;
 
     let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model(
     };
     authorize_access_to_language_models(&session).await?;
 
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
+        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
+    };
+
     session
         .rate_limiter
-        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .check(&*rate_limit, session.user_id())
         .await?;
 
     match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4684,9 +4723,14 @@ async fn count_language_model_tokens(
     };
     authorize_access_to_language_models(&session).await?;
 
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
+        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
+    };
+
     session
         .rate_limiter
-        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+        .check(&*rate_limit, session.user_id())
         .await?;
 
     let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4713,41 +4757,79 @@ async fn count_language_model_tokens(
     Ok(())
 }
 
-struct CountLanguageModelTokensRateLimit;
+struct ZedProCountLanguageModelTokensRateLimit;
 
-impl RateLimit for CountLanguageModelTokensRateLimit {
-    fn capacity() -> usize {
+impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
+    fn capacity(&self) -> usize {
         std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(600) // Picked arbitrarily
     }
 
-    fn refill_duration() -> chrono::Duration {
+    fn refill_duration(&self) -> chrono::Duration {
         chrono::Duration::hours(1)
     }
 
-    fn db_name() -> &'static str {
-        "count-language-model-tokens"
+    fn db_name(&self) -> &'static str {
+        "zed-pro:count-language-model-tokens"
     }
 }
 
-struct ComputeEmbeddingsRateLimit;
+struct FreeCountLanguageModelTokensRateLimit;
 
-impl RateLimit for ComputeEmbeddingsRateLimit {
-    fn capacity() -> usize {
+impl RateLimit for FreeCountLanguageModelTokensRateLimit {
+    fn capacity(&self) -> usize {
+        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(600 / 10) // Picked arbitrarily
+    }
+
+    fn refill_duration(&self) -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name(&self) -> &'static str {
+        "free:count-language-model-tokens"
+    }
+}
+
+struct ZedProComputeEmbeddingsRateLimit;
+
+impl RateLimit for ZedProComputeEmbeddingsRateLimit {
+    fn capacity(&self) -> usize {
         std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(5000) // Picked arbitrarily
     }
 
-    fn refill_duration() -> chrono::Duration {
+    fn refill_duration(&self) -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name(&self) -> &'static str {
+        "zed-pro:compute-embeddings"
+    }
+}
+
+struct FreeComputeEmbeddingsRateLimit;
+
+impl RateLimit for FreeComputeEmbeddingsRateLimit {
+    fn capacity(&self) -> usize {
+        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(5000 / 10) // Picked arbitrarily
+    }
+
+    fn refill_duration(&self) -> chrono::Duration {
         chrono::Duration::hours(1)
     }
 
-    fn db_name() -> &'static str {
-        "compute-embeddings"
+    fn db_name(&self) -> &'static str {
+        "free:compute-embeddings"
     }
 }
 
@@ -4760,9 +4842,14 @@ async fn compute_embeddings(
     let api_key = api_key.context("no OpenAI API key configured on the server")?;
     authorize_access_to_language_models(&session).await?;
 
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
+        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
+    };
+
     session
         .rate_limiter
-        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
+        .check(&*rate_limit, session.user_id())
         .await?;
 
     let embeddings = match request.model.as_str() {
@@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
     let db = session.db().await;
     let flags = db.get_user_flags(session.user_id()).await?;
     if flags.iter().any(|flag| flag == "language-models") {
-        Ok(())
-    } else {
-        Err(anyhow!("permission denied"))?
+        return Ok(());
     }
+
+    Err(anyhow!("permission denied"))?
 }
 
 /// Get a Supermaven API key for the user