collab: Add separate `billing_customers` table (#15457)

Marshall Bowers created

This PR adds a new `billing_customers` table to hold the billing
customers.

Previously we were storing both the `stripe_customer_id` and
`stripe_subscription_id` in the `billable_subscriptions` table. However,
this creates problems when we need to correlate subscription events back
to the subscription record, as we don't know the user that the Stripe
event corresponds to.

By moving the `stripe_customer_id` to a separate table we can create the
Stripe customer earlier in the flowβ€”before we create the Stripe Checkout
sessionβ€”and associate that customer with a user. This way when we
receive events down the line we can use the Stripe customer ID to
correlate it back to the user.

We're doing some destructive actions to the `billing_subscriptions`
table, but this is fine, as we haven't started using them yet.

Release Notes:

- N/A

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql   | 16 
crates/collab/migrations/20240730014107_add_billing_customer.sql | 18 
crates/collab/src/api/billing.rs                                 | 47 +
crates/collab/src/db.rs                                          |  1 
crates/collab/src/db/ids.rs                                      |  1 
crates/collab/src/db/queries.rs                                  |  1 
crates/collab/src/db/queries/billing_customers.rs                | 42 +
crates/collab/src/db/queries/billing_subscriptions.rs            | 12 
crates/collab/src/db/tables.rs                                   |  1 
crates/collab/src/db/tables/billing_customer.rs                  | 39 +
crates/collab/src/db/tables/billing_subscription.rs              | 17 
crates/collab/src/db/tables/user.rs                              | 10 
crates/collab/src/db/tests/billing_subscription_tests.rs         | 30 
13 files changed, 183 insertions(+), 52 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql πŸ”—

@@ -420,12 +420,20 @@ CREATE TABLE dev_server_projects (
 CREATE TABLE IF NOT EXISTS billing_subscriptions (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    user_id INTEGER NOT NULL REFERENCES users(id),
-    stripe_customer_id TEXT NOT NULL,
+    billing_customer_id INTEGER NOT NULL REFERENCES billing_customers(id),
     stripe_subscription_id TEXT NOT NULL,
     stripe_subscription_status TEXT NOT NULL
 );
 
-CREATE INDEX "ix_billing_subscriptions_on_user_id" ON billing_subscriptions (user_id);
-CREATE INDEX "ix_billing_subscriptions_on_stripe_customer_id" ON billing_subscriptions (stripe_customer_id);
+CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
 CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id);
+
+CREATE TABLE IF NOT EXISTS billing_customers (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    stripe_customer_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
+CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);

crates/collab/migrations/20240730014107_add_billing_customer.sql πŸ”—

@@ -0,0 +1,18 @@
+CREATE TABLE IF NOT EXISTS billing_customers (
+    id SERIAL PRIMARY KEY,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT now(),
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    stripe_customer_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
+CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
+
+-- Make `billing_subscriptions` reference `billing_customers` instead of having its
+-- own `user_id` and `stripe_customer_id`.
+DROP INDEX IF EXISTS "ix_billing_subscriptions_on_user_id";
+DROP INDEX IF EXISTS "ix_billing_subscriptions_on_stripe_customer_id";
+ALTER TABLE billing_subscriptions DROP COLUMN user_id;
+ALTER TABLE billing_subscriptions DROP COLUMN stripe_customer_id;
+ALTER TABLE billing_subscriptions ADD COLUMN billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id) ON DELETE CASCADE;
+CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);

crates/collab/src/api/billing.rs πŸ”—

@@ -3,7 +3,6 @@ use std::sync::Arc;
 
 use anyhow::{anyhow, Context};
 use axum::{extract, routing::post, Extension, Json, Router};
-use collections::HashSet;
 use reqwest::StatusCode;
 use serde::{Deserialize, Serialize};
 use stripe::{
@@ -11,7 +10,7 @@ use stripe::{
     CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
     CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
     CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
-    CustomerId,
+    CreateCustomer, Customer, CustomerId,
 };
 
 use crate::db::BillingSubscriptionId;
@@ -59,28 +58,27 @@ async fn create_billing_subscription(
         ))?
     };
 
-    let existing_customer_id = {
-        let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
-        let distinct_customer_ids = existing_subscriptions
-            .iter()
-            .map(|subscription| subscription.stripe_customer_id.as_str())
-            .collect::<HashSet<_>>();
-        // Sanity: Make sure we can determine a single Stripe customer ID for the user.
-        if distinct_customer_ids.len() > 1 {
-            Err(anyhow!("user has multiple existing customer IDs"))?;
-        }
+    let customer_id =
+        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
+            CustomerId::from_str(&existing_customer.stripe_customer_id)
+                .context("failed to parse customer ID")?
+        } else {
+            let customer = Customer::create(
+                &stripe_client,
+                CreateCustomer {
+                    email: user.email_address.as_deref(),
+                    ..Default::default()
+                },
+            )
+            .await?;
 
-        distinct_customer_ids
-            .into_iter()
-            .next()
-            .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
-            .transpose()
-    }?;
+            customer.id
+        };
 
     let checkout_session = {
         let mut params = CreateCheckoutSession::new();
         params.mode = Some(stripe::CheckoutSessionMode::Subscription);
-        params.customer = existing_customer_id;
+        params.customer = Some(customer_id);
         params.client_reference_id = Some(user.github_login.as_str());
         params.line_items = Some(vec![CreateCheckoutSessionLineItems {
             price: Some(stripe_price_id.to_string()),
@@ -140,6 +138,14 @@ async fn manage_billing_subscription(
         ))?
     };
 
+    let customer = app
+        .db
+        .get_billing_customer_by_user_id(user.id)
+        .await?
+        .ok_or_else(|| anyhow!("billing customer not found"))?;
+    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
+        .context("failed to parse customer ID")?;
+
     let subscription = if let Some(subscription_id) = body.subscription_id {
         app.db
             .get_billing_subscription_by_id(subscription_id)
@@ -158,9 +164,6 @@ async fn manage_billing_subscription(
             .ok_or_else(|| anyhow!("user has no active subscriptions"))?
     };
 
-    let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
-        .context("failed to parse customer ID")?;
-
     let flow = match body.intent {
         ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
             type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,

crates/collab/src/db.rs πŸ”—

@@ -45,6 +45,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard};
 pub use tests::TestDb;
 
 pub use ids::*;
+pub use queries::billing_customers::CreateBillingCustomerParams;
 pub use queries::billing_subscriptions::CreateBillingSubscriptionParams;
 pub use queries::contributors::ContributorSelector;
 pub use sea_orm::ConnectOptions;

crates/collab/src/db/ids.rs πŸ”—

@@ -68,6 +68,7 @@ macro_rules! id_type {
 }
 
 id_type!(AccessTokenId);
+id_type!(BillingCustomerId);
 id_type!(BillingSubscriptionId);
 id_type!(BufferId);
 id_type!(ChannelBufferCollaboratorId);

crates/collab/src/db/queries/billing_customers.rs πŸ”—

@@ -0,0 +1,42 @@
+use super::*;
+
+#[derive(Debug)]
+pub struct CreateBillingCustomerParams {
+    pub user_id: UserId,
+    pub stripe_customer_id: String,
+}
+
+impl Database {
+    /// Creates a new billing customer.
+    pub async fn create_billing_customer(
+        &self,
+        params: &CreateBillingCustomerParams,
+    ) -> Result<billing_customer::Model> {
+        self.transaction(|tx| async move {
+            let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
+                user_id: ActiveValue::set(params.user_id),
+                stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
+                ..Default::default()
+            })
+            .exec_with_returning(&*tx)
+            .await?;
+
+            Ok(customer)
+        })
+        .await
+    }
+
+    /// Returns the billing customer for the user with the specified ID.
+    pub async fn get_billing_customer_by_user_id(
+        &self,
+        user_id: UserId,
+    ) -> Result<Option<billing_customer::Model>> {
+        self.transaction(|tx| async move {
+            Ok(billing_customer::Entity::find()
+                .filter(billing_customer::Column::UserId.eq(user_id))
+                .one(&*tx)
+                .await?)
+        })
+        .await
+    }
+}

crates/collab/src/db/queries/billing_subscriptions.rs πŸ”—

@@ -4,8 +4,7 @@ use super::*;
 
 #[derive(Debug)]
 pub struct CreateBillingSubscriptionParams {
-    pub user_id: UserId,
-    pub stripe_customer_id: String,
+    pub billing_customer_id: BillingCustomerId,
     pub stripe_subscription_id: String,
     pub stripe_subscription_status: StripeSubscriptionStatus,
 }
@@ -18,8 +17,7 @@ impl Database {
     ) -> Result<()> {
         self.transaction(|tx| async move {
             billing_subscription::Entity::insert(billing_subscription::ActiveModel {
-                user_id: ActiveValue::set(params.user_id),
-                stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
+                billing_customer_id: ActiveValue::set(params.billing_customer_id),
                 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
                 stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
                 ..Default::default()
@@ -56,7 +54,8 @@ impl Database {
     ) -> Result<Vec<billing_subscription::Model>> {
         self.transaction(|tx| async move {
             let subscriptions = billing_subscription::Entity::find()
-                .filter(billing_subscription::Column::UserId.eq(user_id))
+                .inner_join(billing_customer::Entity)
+                .filter(billing_customer::Column::UserId.eq(user_id))
                 .order_by_asc(billing_subscription::Column::Id)
                 .all(&*tx)
                 .await?;
@@ -73,8 +72,9 @@ impl Database {
     ) -> Result<Vec<billing_subscription::Model>> {
         self.transaction(|tx| async move {
             let subscriptions = billing_subscription::Entity::find()
+                .inner_join(billing_customer::Entity)
                 .filter(
-                    billing_subscription::Column::UserId.eq(user_id).and(
+                    billing_customer::Column::UserId.eq(user_id).and(
                         billing_subscription::Column::StripeSubscriptionStatus
                             .eq(StripeSubscriptionStatus::Active),
                     ),

crates/collab/src/db/tables/billing_customer.rs πŸ”—

@@ -0,0 +1,39 @@
+use crate::db::{BillingCustomerId, UserId};
+use sea_orm::entity::prelude::*;
+
+/// A billing customer.
+#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_customers")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: BillingCustomerId,
+    pub user_id: UserId,
+    pub stripe_customer_id: String,
+    pub created_at: DateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+    #[sea_orm(has_many = "super::billing_subscription::Entity")]
+    BillingSubscription,
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+impl Related<super::billing_subscription::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::BillingSubscription.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tables/billing_subscription.rs πŸ”—

@@ -1,4 +1,4 @@
-use crate::db::{BillingSubscriptionId, UserId};
+use crate::db::{BillingCustomerId, BillingSubscriptionId};
 use sea_orm::entity::prelude::*;
 
 /// A billing subscription.
@@ -7,8 +7,7 @@ use sea_orm::entity::prelude::*;
 pub struct Model {
     #[sea_orm(primary_key)]
     pub id: BillingSubscriptionId,
-    pub user_id: UserId,
-    pub stripe_customer_id: String,
+    pub billing_customer_id: BillingCustomerId,
     pub stripe_subscription_id: String,
     pub stripe_subscription_status: StripeSubscriptionStatus,
     pub created_at: DateTime,
@@ -17,16 +16,16 @@ pub struct Model {
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
 pub enum Relation {
     #[sea_orm(
-        belongs_to = "super::user::Entity",
-        from = "Column::UserId",
-        to = "super::user::Column::Id"
+        belongs_to = "super::billing_customer::Entity",
+        from = "Column::BillingCustomerId",
+        to = "super::billing_customer::Column::Id"
     )]
-    User,
+    BillingCustomer,
 }
 
-impl Related<super::user::Entity> for Entity {
+impl Related<super::billing_customer::Entity> for Entity {
     fn to() -> RelationDef {
-        Relation::User.def()
+        Relation::BillingCustomer.def()
     }
 }
 

crates/collab/src/db/tables/user.rs πŸ”—

@@ -24,8 +24,8 @@ pub struct Model {
 pub enum Relation {
     #[sea_orm(has_many = "super::access_token::Entity")]
     AccessToken,
-    #[sea_orm(has_many = "super::billing_subscription::Entity")]
-    BillingSubscription,
+    #[sea_orm(has_one = "super::billing_customer::Entity")]
+    BillingCustomer,
     #[sea_orm(has_one = "super::room_participant::Entity")]
     RoomParticipant,
     #[sea_orm(has_many = "super::project::Entity")]
@@ -44,6 +44,12 @@ impl Related<super::access_token::Entity> for Entity {
     }
 }
 
+impl Related<super::billing_customer::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::BillingCustomer.def()
+    }
+}
+
 impl Related<super::room_participant::Entity> for Entity {
     fn to() -> RelationDef {
         Relation::RoomParticipant.def()

crates/collab/src/db/tests/billing_subscription_tests.rs πŸ”—

@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use crate::db::billing_subscription::StripeSubscriptionStatus;
 use crate::db::tests::new_test_user;
-use crate::db::CreateBillingSubscriptionParams;
+use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
 use crate::test_both_dbs;
 
 use super::Database;
@@ -25,9 +25,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
     // A user with an active subscription has one active billing subscription.
     {
         let user_id = new_test_user(db, "active-user@example.com").await;
+        let customer = db
+            .create_billing_customer(&CreateBillingCustomerParams {
+                user_id,
+                stripe_customer_id: "cus_active_user".into(),
+            })
+            .await
+            .unwrap();
+        assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
+
         db.create_billing_subscription(&CreateBillingSubscriptionParams {
-            user_id,
-            stripe_customer_id: "cus_active_user".into(),
+            billing_customer_id: customer.id,
             stripe_subscription_id: "sub_active_user".into(),
             stripe_subscription_status: StripeSubscriptionStatus::Active,
         })
@@ -38,10 +46,6 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
         assert_eq!(subscriptions.len(), 1);
 
         let subscription = &subscriptions[0];
-        assert_eq!(
-            subscription.stripe_customer_id,
-            "cus_active_user".to_string()
-        );
         assert_eq!(
             subscription.stripe_subscription_id,
             "sub_active_user".to_string()
@@ -55,9 +59,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
     // A user with a past-due subscription has no active billing subscriptions.
     {
         let user_id = new_test_user(db, "past-due-user@example.com").await;
+        let customer = db
+            .create_billing_customer(&CreateBillingCustomerParams {
+                user_id,
+                stripe_customer_id: "cus_past_due_user".into(),
+            })
+            .await
+            .unwrap();
+        assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
+
         db.create_billing_subscription(&CreateBillingSubscriptionParams {
-            user_id,
-            stripe_customer_id: "cus_past_due_user".into(),
+            billing_customer_id: customer.id,
             stripe_subscription_id: "sub_past_due_user".into(),
             stripe_subscription_status: StripeSubscriptionStatus::PastDue,
         })