collab: Use `billing_customers.has_overdue_invoices` to gate subscription access (#24240)

Marshall Bowers created

This PR updates the check that prevents subscribing with overdue
subscriptions to use the `billing_customers.has_overdue_invoices` field
instead.

This will allow us to set the value of `has_overdue_invoices` to `false`
when the invoices have been paid.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                         |  42 +-
crates/collab/src/db/queries/billing_subscriptions.rs    |  36 ---
crates/collab/src/db/tests/billing_subscription_tests.rs | 112 ---------
3 files changed, 23 insertions(+), 167 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -249,29 +249,31 @@ async fn create_billing_subscription(
         ));
     }
 
-    if app.db.has_overdue_billing_subscriptions(user.id).await? {
-        return Err(Error::http(
-            StatusCode::PAYMENT_REQUIRED,
-            "user has overdue billing subscriptions".into(),
-        ));
+    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
+    if let Some(existing_billing_customer) = &existing_billing_customer {
+        if existing_billing_customer.has_overdue_invoices {
+            return Err(Error::http(
+                StatusCode::PAYMENT_REQUIRED,
+                "user has overdue invoices".into(),
+            ));
+        }
     }
 
-    let customer_id =
-        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
-            CustomerId::from_str(&existing_customer.stripe_customer_id)
-                .context("failed to parse customer ID")?
-        } else {
-            let customer = Customer::create(
-                &stripe_client,
-                CreateCustomer {
-                    email: user.email_address.as_deref(),
-                    ..Default::default()
-                },
-            )
-            .await?;
+    let customer_id = if let Some(existing_customer) = existing_billing_customer {
+        CustomerId::from_str(&existing_customer.stripe_customer_id)
+            .context("failed to parse customer ID")?
+    } else {
+        let customer = Customer::create(
+            &stripe_client,
+            CreateCustomer {
+                email: user.email_address.as_deref(),
+                ..Default::default()
+            },
+        )
+        .await?;
 
-            customer.id
-        };
+        customer.id
+    };
 
     let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
     let stripe_model = stripe_billing.register_model(default_model).await?;

crates/collab/src/db/queries/billing_subscriptions.rs 🔗

@@ -170,40 +170,4 @@ impl Database {
         })
         .await
     }
-
-    /// Returns whether the user has any overdue billing subscriptions.
-    pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<bool> {
-        Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0)
-    }
-
-    /// Returns the count of the overdue billing subscriptions for the user with the specified ID.
-    ///
-    /// This includes subscriptions:
-    /// - Whose status is `past_due`
-    /// - Whose status is `canceled` and the cancellation reason is `payment_failed`
-    pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
-        self.transaction(|tx| async move {
-            let past_due = billing_subscription::Column::StripeSubscriptionStatus
-                .eq(StripeSubscriptionStatus::PastDue);
-            let payment_failed = billing_subscription::Column::StripeSubscriptionStatus
-                .eq(StripeSubscriptionStatus::Canceled)
-                .and(
-                    billing_subscription::Column::StripeCancellationReason
-                        .eq(StripeCancellationReason::PaymentFailed),
-                );
-
-            let count = billing_subscription::Entity::find()
-                .inner_join(billing_customer::Entity)
-                .filter(
-                    billing_customer::Column::UserId
-                        .eq(user_id)
-                        .and(past_due.or(payment_failed)),
-                )
-                .count(&*tx)
-                .await?;
-
-            Ok(count as usize)
-        })
-        .await
-    }
 }

crates/collab/src/db/tests/billing_subscription_tests.rs 🔗

@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
+use crate::db::billing_subscription::StripeSubscriptionStatus;
 use crate::db::tests::new_test_user;
 use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
 use crate::test_both_dbs;
@@ -88,113 +88,3 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
         assert_eq!(subscription_count, 0);
     }
 }
-
-test_both_dbs!(
-    test_count_overdue_billing_subscriptions,
-    test_count_overdue_billing_subscriptions_postgres,
-    test_count_overdue_billing_subscriptions_sqlite
-);
-
-async fn test_count_overdue_billing_subscriptions(db: &Arc<Database>) {
-    // A user with no subscription has no overdue billing subscriptions.
-    {
-        let user_id = new_test_user(db, "no-subscription-user@example.com").await;
-        let subscription_count = db
-            .count_overdue_billing_subscriptions(user_id)
-            .await
-            .unwrap();
-
-        assert_eq!(subscription_count, 0);
-    }
-
-    // A user with a past-due subscription has an overdue billing subscription.
-    {
-        let user_id = new_test_user(db, "past-due-user@example.com").await;
-        let customer = db
-            .create_billing_customer(&CreateBillingCustomerParams {
-                user_id,
-                stripe_customer_id: "cus_past_due_user".into(),
-            })
-            .await
-            .unwrap();
-        assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
-
-        db.create_billing_subscription(&CreateBillingSubscriptionParams {
-            billing_customer_id: customer.id,
-            stripe_subscription_id: "sub_past_due_user".into(),
-            stripe_subscription_status: StripeSubscriptionStatus::PastDue,
-            stripe_cancellation_reason: None,
-        })
-        .await
-        .unwrap();
-
-        let subscription_count = db
-            .count_overdue_billing_subscriptions(user_id)
-            .await
-            .unwrap();
-        assert_eq!(subscription_count, 1);
-    }
-
-    // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription.
-    {
-        let user_id =
-            new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await;
-        let customer = db
-            .create_billing_customer(&CreateBillingCustomerParams {
-                user_id,
-                stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(),
-            })
-            .await
-            .unwrap();
-        assert_eq!(
-            customer.stripe_customer_id,
-            "cus_canceled_subscription_payment_failed_user".to_string()
-        );
-
-        db.create_billing_subscription(&CreateBillingSubscriptionParams {
-            billing_customer_id: customer.id,
-            stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(),
-            stripe_subscription_status: StripeSubscriptionStatus::Canceled,
-            stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed),
-        })
-        .await
-        .unwrap();
-
-        let subscription_count = db
-            .count_overdue_billing_subscriptions(user_id)
-            .await
-            .unwrap();
-        assert_eq!(subscription_count, 1);
-    }
-
-    // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions.
-    {
-        let user_id = new_test_user(db, "canceled-subscription-user@example.com").await;
-        let customer = db
-            .create_billing_customer(&CreateBillingCustomerParams {
-                user_id,
-                stripe_customer_id: "cus_canceled_subscription_user".into(),
-            })
-            .await
-            .unwrap();
-        assert_eq!(
-            customer.stripe_customer_id,
-            "cus_canceled_subscription_user".to_string()
-        );
-
-        db.create_billing_subscription(&CreateBillingSubscriptionParams {
-            billing_customer_id: customer.id,
-            stripe_subscription_id: "sub_canceled_subscription_user".into(),
-            stripe_subscription_status: StripeSubscriptionStatus::Canceled,
-            stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested),
-        })
-        .await
-        .unwrap();
-
-        let subscription_count = db
-            .count_overdue_billing_subscriptions(user_id)
-            .await
-            .unwrap();
-        assert_eq!(subscription_count, 0);
-    }
-}