Reduce the amount of queries performed when updating plan (#31268)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/collab/src/db/tables/user.rs |  5 ++
crates/collab/src/rpc.rs            | 54 ++++++++++++++----------------
2 files changed, 30 insertions(+), 29 deletions(-)

Detailed changes

crates/collab/src/db/tables/user.rs 🔗

@@ -55,6 +55,11 @@ impl Model {
 
         account_created_at
     }
+
+    /// Returns the age of the user's account.
+    pub fn account_age(&self) -> chrono::Duration {
+        chrono::Utc::now().naive_utc() - self.account_created_at()
+    }
 }
 
 impl Related<super::access_token::Entity> for Entity {

crates/collab/src/rpc.rs 🔗

@@ -110,6 +110,13 @@ pub enum Principal {
 }
 
 impl Principal {
+    fn user(&self) -> &User {
+        match self {
+            Principal::User(user) => user,
+            Principal::Impersonated { user, .. } => user,
+        }
+    }
+
     fn update_span(&self, span: &tracing::Span) {
         match &self {
             Principal::User(user) => {
@@ -741,7 +748,7 @@ impl Server {
                 supermaven_client,
             };
 
-            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
+            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                 tracing::error!(?error, "failed to send initial client update");
                 return;
             }
@@ -825,7 +832,6 @@ impl Server {
     async fn send_initial_client_update(
         &self,
         connection_id: ConnectionId,
-        principal: &Principal,
         zed_version: ZedVersion,
         mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
         session: &Session,
@@ -841,7 +847,7 @@ impl Server {
             let _ = send_connection_id.send(connection_id);
         }
 
-        match principal {
+        match &session.principal {
             Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                 if !user.connected_once {
                     self.peer.send(connection_id, proto::ShowContacts {})?;
@@ -851,7 +857,7 @@ impl Server {
                         .await?;
                 }
 
-                update_user_plan(user.id, session).await?;
+                update_user_plan(session).await?;
 
                 let contacts = self.app_state.db.get_contacts(user.id).await?;
 
@@ -941,10 +947,10 @@ impl Server {
             .context("user not found")?;
 
         let update_user_plan = make_update_user_plan_message(
+            &user,
+            user.admin,
             &self.app_state.db,
             self.app_state.llm_db.clone(),
-            user_id,
-            user.admin,
         )
         .await?;
 
@@ -2707,26 +2713,25 @@ async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Re
 }
 
 async fn make_update_user_plan_message(
+    user: &User,
+    is_staff: bool,
     db: &Arc<Database>,
     llm_db: Option<Arc<LlmDatabase>>,
-    user_id: UserId,
-    is_staff: bool,
 ) -> Result<proto::UpdateUserPlan> {
-    let feature_flags = db.get_user_flags(user_id).await?;
-    let plan = current_plan(db, user_id, is_staff).await?;
-    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
-    let billing_preferences = db.get_billing_preferences(user_id).await?;
-    let user = db.get_user_by_id(user_id).await?;
+    let feature_flags = db.get_user_flags(user.id).await?;
+    let plan = current_plan(db, user.id, is_staff).await?;
+    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
+    let billing_preferences = db.get_billing_preferences(user.id).await?;
 
     let (subscription_period, usage) = if let Some(llm_db) = llm_db {
-        let subscription = db.get_active_billing_subscription(user_id).await?;
+        let subscription = db.get_active_billing_subscription(user.id).await?;
 
         let subscription_period =
             crate::db::billing_subscription::Model::current_period(subscription, is_staff);
 
         let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
             llm_db
-                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
+                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                 .await?
         } else {
             None
@@ -2737,17 +2742,8 @@ async fn make_update_user_plan_message(
         (None, None)
     };
 
-    // Calculate account_too_young
-    let account_too_young = if matches!(plan, proto::Plan::ZedPro) {
-        // If they have paid, then we allow them to use all of the features
-        false
-    } else if let Some(user) = user {
-        // If we have access to the profile age, we use that
-        chrono::Utc::now().naive_utc() - user.account_created_at() < MIN_ACCOUNT_AGE_FOR_LLM_USE
-    } else {
-        // Default to false otherwise
-        false
-    };
+    let account_too_young =
+        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
 
     Ok(proto::UpdateUserPlan {
         plan: plan.into(),
@@ -2822,14 +2818,14 @@ async fn make_update_user_plan_message(
     })
 }
 
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
+async fn update_user_plan(session: &Session) -> Result<()> {
     let db = session.db().await;
 
     let update_user_plan = make_update_user_plan_message(
+        session.principal.user(),
+        session.is_staff(),
         &db.0,
         session.app_state.llm_db.clone(),
-        user_id,
-        session.is_staff(),
     )
     .await?;