collab: Limit customers to one free trial (#29232)

Marshall Bowers created

This PR makes it so customers can only subscribe to the trial once.

Release Notes:

- N/A

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql                        |  3 
crates/collab/migrations/20250422194500_add_trial_started_at_to_billing_customers.sql |  2 
crates/collab/src/api/billing.rs                                                      | 29 
crates/collab/src/db/queries/billing_customers.rs                                     |  4 
crates/collab/src/db/tables/billing_customer.rs                                       |  1 
5 files changed, 36 insertions(+), 3 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -492,7 +492,8 @@ CREATE TABLE IF NOT EXISTS billing_customers (
     created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     user_id INTEGER NOT NULL REFERENCES users (id),
     has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
-    stripe_customer_id TEXT NOT NULL
+    stripe_customer_id TEXT NOT NULL,
+    trial_started_at TIMESTAMP
 );
 
 CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);

crates/collab/src/api/billing.rs 🔗

@@ -287,7 +287,7 @@ async fn create_billing_subscription(
         }
     }
 
-    let customer_id = if let Some(existing_customer) = existing_billing_customer {
+    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
         CustomerId::from_str(&existing_customer.stripe_customer_id)
             .context("failed to parse customer ID")?
     } else {
@@ -320,6 +320,15 @@ async fn create_billing_subscription(
                 .await?
         }
         Some(ProductCode::ZedProTrial) => {
+            if let Some(existing_billing_customer) = &existing_billing_customer {
+                if existing_billing_customer.trial_started_at.is_some() {
+                    return Err(Error::http(
+                        StatusCode::FORBIDDEN,
+                        "user already used free trial".into(),
+                    ));
+                }
+            }
+
             stripe_billing
                 .checkout_with_zed_pro_trial(
                     app.config.zed_pro_price_id()?,
@@ -817,6 +826,24 @@ async fn handle_customer_subscription_event(
             .await?
             .ok_or_else(|| anyhow!("billing customer not found"))?;
 
+    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
+        if subscription.status == SubscriptionStatus::Trialing {
+            let current_period_start =
+                DateTime::from_timestamp(subscription.current_period_start, 0)
+                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;
+
+            app.db
+                .update_billing_customer(
+                    billing_customer.id,
+                    &UpdateBillingCustomerParams {
+                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
+                        ..Default::default()
+                    },
+                )
+                .await?;
+        }
+    }
+
     let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
         && subscription
             .cancellation_details

crates/collab/src/db/queries/billing_customers.rs 🔗

@@ -11,6 +11,7 @@ pub struct UpdateBillingCustomerParams {
     pub user_id: ActiveValue<UserId>,
     pub stripe_customer_id: ActiveValue<String>,
     pub has_overdue_invoices: ActiveValue<bool>,
+    pub trial_started_at: ActiveValue<Option<DateTime>>,
 }
 
 impl Database {
@@ -45,7 +46,8 @@ impl Database {
                 user_id: params.user_id.clone(),
                 stripe_customer_id: params.stripe_customer_id.clone(),
                 has_overdue_invoices: params.has_overdue_invoices.clone(),
-                ..Default::default()
+                trial_started_at: params.trial_started_at.clone(),
+                created_at: ActiveValue::not_set(),
             })
             .exec(&*tx)
             .await?;