collab: Push down plan changes to the client (#30447)

Marshall Bowers created

This PR makes it so we push down plan updates from the server when the
user's subscription changes.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs |   6 
crates/collab/src/rpc.rs         | 248 ++++++++++++++++++---------------
2 files changed, 143 insertions(+), 111 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -1137,6 +1137,12 @@ async fn handle_customer_subscription_event(
             .await?;
     }
 
+    // When the user's subscription changes, push down any changes to their plan.
+    rpc_server
+        .update_plan_for_user(billing_customer.user_id)
+        .await
+        .trace_err();
+
     // When the user's subscription changes, we want to refresh their LLM tokens
     // to either grant/revoke access.
     rpc_server

crates/collab/src/rpc.rs 🔗

@@ -2,6 +2,7 @@ mod connection_pool;
 
 use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::db::billing_subscription::SubscriptionKind;
+use crate::llm::db::LlmDatabase;
 use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
 use crate::{
     AppState, Error, Result, auth,
@@ -67,7 +68,7 @@ use std::{
     time::{Duration, Instant},
 };
 use time::OffsetDateTime;
-use tokio::sync::{MutexGuard, Semaphore, watch};
+use tokio::sync::{Semaphore, watch};
 use tower::ServiceBuilder;
 use tracing::{
     Instrument,
@@ -166,29 +167,6 @@ impl Session {
         }
     }
 
-    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
-        if self.is_staff() {
-            return Ok(proto::Plan::ZedPro);
-        }
-
-        let user_id = self.user_id();
-
-        let subscription = db.get_active_billing_subscription(user_id).await?;
-        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
-
-        let plan = if let Some(subscription_kind) = subscription_kind {
-            match subscription_kind {
-                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
-                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
-                SubscriptionKind::ZedFree => proto::Plan::Free,
-            }
-        } else {
-            proto::Plan::Free
-        };
-
-        Ok(plan)
-    }
-
     fn user_id(&self) -> UserId {
         match &self.principal {
             Principal::User(user) => user.id,
@@ -953,6 +931,32 @@ impl Server {
         Ok(())
     }
 
+    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
+        let user = self
+            .app_state
+            .db
+            .get_user_by_id(user_id)
+            .await?
+            .ok_or_else(|| anyhow!("user not found"))?;
+
+        let update_user_plan = make_update_user_plan_message(
+            &self.app_state.db,
+            self.app_state.llm_db.clone(),
+            user_id,
+            user.admin,
+        )
+        .await?;
+
+        let pool = self.connection_pool.lock();
+        for connection_id in pool.user_connection_ids(user_id) {
+            self.peer
+                .send(connection_id, update_user_plan.clone())
+                .trace_err();
+        }
+
+        Ok(())
+    }
+
     pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
         let pool = self.connection_pool.lock();
         for connection_id in pool.user_connection_ids(user_id) {
@@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
     version.0.minor() < 139
 }
 
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
-    let db = session.db().await;
+async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
+    if is_staff {
+        return Ok(proto::Plan::ZedPro);
+    }
+
+    let subscription = db.get_active_billing_subscription(user_id).await?;
+    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
+
+    let plan = if let Some(subscription_kind) = subscription_kind {
+        match subscription_kind {
+            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
+            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
+            SubscriptionKind::ZedFree => proto::Plan::Free,
+        }
+    } else {
+        proto::Plan::Free
+    };
 
+    Ok(plan)
+}
+
+async fn make_update_user_plan_message(
+    db: &Arc<Database>,
+    llm_db: Option<Arc<LlmDatabase>>,
+    user_id: UserId,
+    is_staff: bool,
+) -> Result<proto::UpdateUserPlan> {
     let feature_flags = db.get_user_flags(user_id).await?;
-    let plan = session.current_plan(&db).await?;
+    let plan = current_plan(db, user_id, is_staff).await?;
     let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
     let billing_preferences = db.get_billing_preferences(user_id).await?;
 
-    let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
+    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
         let subscription = db.get_active_billing_subscription(user_id).await?;
 
-        let subscription_period = crate::db::billing_subscription::Model::current_period(
-            subscription,
-            session.is_staff(),
-        );
+        let subscription_period =
+            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
 
         let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
             llm_db
@@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
         (None, None)
     };
 
-    session
-        .peer
-        .send(
-            session.connection_id,
-            proto::UpdateUserPlan {
-                plan: plan.into(),
-                trial_started_at: billing_customer
-                    .and_then(|billing_customer| billing_customer.trial_started_at)
-                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
-                is_usage_based_billing_enabled: if session.is_staff() {
-                    Some(true)
-                } else {
-                    billing_preferences
-                        .map(|preferences| preferences.model_request_overages_enabled)
-                },
-                subscription_period: subscription_period.map(|(started_at, ended_at)| {
-                    proto::SubscriptionPeriod {
-                        started_at: started_at.timestamp() as u64,
-                        ended_at: ended_at.timestamp() as u64,
-                    }
-                }),
-                usage: usage.map(|usage| {
-                    let plan = match plan {
-                        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
-                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
-                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+    Ok(proto::UpdateUserPlan {
+        plan: plan.into(),
+        trial_started_at: billing_customer
+            .and_then(|billing_customer| billing_customer.trial_started_at)
+            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
+        is_usage_based_billing_enabled: if is_staff {
+            Some(true)
+        } else {
+            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
+        },
+        subscription_period: subscription_period.map(|(started_at, ended_at)| {
+            proto::SubscriptionPeriod {
+                started_at: started_at.timestamp() as u64,
+                ended_at: ended_at.timestamp() as u64,
+            }
+        }),
+        usage: usage.map(|usage| {
+            let plan = match plan {
+                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
+                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
+                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
+            };
+
+            let model_requests_limit = match plan.model_requests_limit() {
+                zed_llm_client::UsageLimit::Limited(limit) => {
+                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
+                        && feature_flags
+                            .iter()
+                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
+                    {
+                        1_000
+                    } else {
+                        limit
                     };
 
-                    let model_requests_limit = match plan.model_requests_limit() {
-                        zed_llm_client::UsageLimit::Limited(limit) => {
-                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
-                                && feature_flags
-                                    .iter()
-                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
-                            {
-                                1_000
-                            } else {
-                                limit
-                            };
+                    zed_llm_client::UsageLimit::Limited(limit)
+                }
+                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
+            };
 
-                            zed_llm_client::UsageLimit::Limited(limit)
+            proto::SubscriptionUsage {
+                model_requests_usage_amount: usage.model_requests as u32,
+                model_requests_usage_limit: Some(proto::UsageLimit {
+                    variant: Some(match model_requests_limit {
+                        zed_llm_client::UsageLimit::Limited(limit) => {
+                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+                                limit: limit as u32,
+                            })
                         }
                         zed_llm_client::UsageLimit::Unlimited => {
-                            zed_llm_client::UsageLimit::Unlimited
+                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                         }
-                    };
-
-                    proto::SubscriptionUsage {
-                        model_requests_usage_amount: usage.model_requests as u32,
-                        model_requests_usage_limit: Some(proto::UsageLimit {
-                            variant: Some(match model_requests_limit {
-                                zed_llm_client::UsageLimit::Limited(limit) => {
-                                    proto::usage_limit::Variant::Limited(
-                                        proto::usage_limit::Limited {
-                                            limit: limit as u32,
-                                        },
-                                    )
-                                }
-                                zed_llm_client::UsageLimit::Unlimited => {
-                                    proto::usage_limit::Variant::Unlimited(
-                                        proto::usage_limit::Unlimited {},
-                                    )
-                                }
-                            }),
-                        }),
-                        edit_predictions_usage_amount: usage.edit_predictions as u32,
-                        edit_predictions_usage_limit: Some(proto::UsageLimit {
-                            variant: Some(match plan.edit_predictions_limit() {
-                                zed_llm_client::UsageLimit::Limited(limit) => {
-                                    proto::usage_limit::Variant::Limited(
-                                        proto::usage_limit::Limited {
-                                            limit: limit as u32,
-                                        },
-                                    )
-                                }
-                                zed_llm_client::UsageLimit::Unlimited => {
-                                    proto::usage_limit::Variant::Unlimited(
-                                        proto::usage_limit::Unlimited {},
-                                    )
-                                }
-                            }),
-                        }),
-                    }
+                    }),
                 }),
-            },
-        )
+                edit_predictions_usage_amount: usage.edit_predictions as u32,
+                edit_predictions_usage_limit: Some(proto::UsageLimit {
+                    variant: Some(match plan.edit_predictions_limit() {
+                        zed_llm_client::UsageLimit::Limited(limit) => {
+                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+                                limit: limit as u32,
+                            })
+                        }
+                        zed_llm_client::UsageLimit::Unlimited => {
+                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
+                        }
+                    }),
+                }),
+            }
+        }),
+    })
+}
+
+async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+    let db = session.db().await;
+
+    let update_user_plan = make_update_user_plan_message(
+        &db.0,
+        session.app_state.llm_db.clone(),
+        user_id,
+        session.is_staff(),
+    )
+    .await?;
+
+    session
+        .peer
+        .send(session.connection_id, update_user_plan)
         .trace_err();
 
     Ok(())