collab: Include `subscription_period` in LLM token claims (#28819)

Marshall Bowers created

This PR updates the LLM token claims to include the user's active
subscription period.

Release Notes:

- N/A

Change summary

crates/collab/src/db/queries/billing_subscriptions.rs | 23 +++++++++++++
crates/collab/src/llm/token.rs                        | 22 ++++++++++--
crates/collab/src/rpc.rs                              |  6 ++-
3 files changed, 45 insertions(+), 6 deletions(-)

Detailed changes

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -108,6 +108,28 @@ impl Database {
         .await
     }
 
+    pub async fn get_active_billing_subscription(
+        &self,
+        user_id: UserId,
+    ) -> Result<Option<billing_subscription::Model>> {
+        self.transaction(|tx| async move {
+            Ok(billing_subscription::Entity::find()
+                .inner_join(billing_customer::Entity)
+                .filter(billing_customer::Column::UserId.eq(user_id))
+                .filter(
+                    Condition::all()
+                        .add(
+                            billing_subscription::Column::StripeSubscriptionStatus
+                                .eq(StripeSubscriptionStatus::Active),
+                        )
+                        .add(billing_subscription::Column::Kind.is_not_null()),
+                )
+                .one(&*tx)
+                .await?)
+        })
+        .await
+    }
+
     /// Returns all of the billing subscriptions for the user with the specified ID.
     ///
     /// Note that this returns the subscriptions regardless of their status.
@@ -145,6 +167,7 @@ impl Database {
                         billing_subscription::Column::StripeSubscriptionStatus
                             .eq(StripeSubscriptionStatus::Active),
                     )
+                    .filter(billing_subscription::Column::Kind.is_null())
                     .order_by_asc(billing_subscription::Column::Id)
                     .stream(&*tx)
                     .await?;

crates/collab/src/llm/token.rs 🔗

@@ -1,13 +1,14 @@
 use crate::Cents;
-use crate::db::user;
+use crate::db::{billing_subscription, user};
 use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
 use crate::{Config, db::billing_preference};
 use anyhow::{Result, anyhow};
-use chrono::{NaiveDateTime, Utc};
+use chrono::{DateTime, NaiveDateTime, Utc};
 use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
 use serde::{Deserialize, Serialize};
 use std::time::Duration;
 use thiserror::Error;
+use util::maybe;
 use uuid::Uuid;
 
 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
@@ -29,6 +30,8 @@ pub struct LlmTokenClaims {
     pub max_monthly_spend_in_cents: u32,
     pub custom_llm_monthly_allowance_in_cents: Option<u32>,
     pub plan: rpc::proto::Plan,
+    #[serde(default)]
+    pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
 }
 
 const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
@@ -39,8 +42,9 @@ impl LlmTokenClaims {
         is_staff: bool,
         billing_preferences: Option<billing_preference::Model>,
         feature_flags: &Vec<String>,
-        has_llm_subscription: bool,
+        has_legacy_llm_subscription: bool,
         plan: rpc::proto::Plan,
+        subscription: Option<billing_subscription::Model>,
         system_id: Option<String>,
         config: &Config,
     ) -> Result<String> {
@@ -69,7 +73,7 @@ impl LlmTokenClaims {
             has_predict_edits_feature_flag: feature_flags
                 .iter()
                 .any(|flag| flag == "predict-edits"),
-            has_llm_subscription,
+            has_llm_subscription: has_legacy_llm_subscription,
             max_monthly_spend_in_cents: billing_preferences
                 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
                     preferences.max_monthly_llm_usage_spending_in_cents as u32
@@ -78,6 +82,16 @@ impl LlmTokenClaims {
                 .custom_llm_monthly_allowance_in_cents
                 .map(|allowance| allowance as u32),
             plan,
+            subscription_period: maybe!({
+                let subscription = subscription?;
+                let period_start = subscription.stripe_current_period_start?;
+                let period_start = DateTime::from_timestamp(period_start, 0)?;
+
+                let period_end = subscription.stripe_current_period_end?;
+                let period_end = DateTime::from_timestamp(period_end, 0)?;
+
+                Some((period_start.naive_utc(), period_end.naive_utc()))
+            }),
         };
 
         Ok(jsonwebtoken::encode(

crates/collab/src/rpc.rs 🔗

@@ -4135,7 +4135,8 @@ async fn get_llm_api_token(
         Err(anyhow!("terms of service not accepted"))?
     }
 
-    let has_llm_subscription = session.has_llm_subscription(&db).await?;
+    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
+    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
     let billing_preferences = db.get_billing_preferences(user.id).await?;
 
     let token = LlmTokenClaims::create(
@@ -4143,8 +4144,9 @@ async fn get_llm_api_token(
         session.is_staff(),
         billing_preferences,
         &flags,
-        has_llm_subscription,
+        has_legacy_llm_subscription,
         session.current_plan(&db).await?,
+        billing_subscription,
         session.system_id.clone(),
         &session.app_state.config,
     )?;