@@ -249,6 +249,13 @@ async fn create_billing_subscription(
));
}
+ if app.db.has_overdue_billing_subscriptions(user.id).await? {
+ return Err(Error::http(
+ StatusCode::PAYMENT_REQUIRED,
+ "user has overdue billing subscriptions".into(),
+ ));
+ }
+
let customer_id =
if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
CustomerId::from_str(&existing_customer.stripe_customer_id)
@@ -719,6 +726,10 @@ async fn handle_customer_subscription_event(
billing_customer_id: billing_customer.id,
stripe_subscription_id: subscription.id.to_string(),
stripe_subscription_status: subscription.status.into(),
+ stripe_cancellation_reason: subscription
+ .cancellation_details
+ .and_then(|details| details.reason)
+ .map(|reason| reason.into()),
})
.await?;
}
@@ -7,6 +7,7 @@ pub struct CreateBillingSubscriptionParams {
pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
+ pub stripe_cancellation_reason: Option<StripeCancellationReason>,
}
#[derive(Debug, Default)]
@@ -29,6 +30,7 @@ impl Database {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
+ stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
..Default::default()
})
.exec_without_returning(&*tx)
@@ -168,4 +170,40 @@ impl Database {
})
.await
}
+
+ /// Returns whether the user has any overdue billing subscriptions.
+ pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<bool> {
+ Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0)
+ }
+
+ /// Returns the count of the overdue billing subscriptions for the user with the specified ID.
+ ///
+ /// This includes subscriptions:
+ /// - Whose status is `past_due`
+ /// - Whose status is `canceled` and the cancellation reason is `payment_failed`
+ pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
+ self.transaction(|tx| async move {
+ let past_due = billing_subscription::Column::StripeSubscriptionStatus
+ .eq(StripeSubscriptionStatus::PastDue);
+ let payment_failed = billing_subscription::Column::StripeSubscriptionStatus
+ .eq(StripeSubscriptionStatus::Canceled)
+ .and(
+ billing_subscription::Column::StripeCancellationReason
+ .eq(StripeCancellationReason::PaymentFailed),
+ );
+
+ let count = billing_subscription::Entity::find()
+ .inner_join(billing_customer::Entity)
+ .filter(
+ billing_customer::Column::UserId
+ .eq(user_id)
+ .and(past_due.or(payment_failed)),
+ )
+ .count(&*tx)
+ .await?;
+
+ Ok(count as usize)
+ })
+ .await
+ }
}
@@ -1,6 +1,6 @@
use std::sync::Arc;
-use crate::db::billing_subscription::StripeSubscriptionStatus;
+use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
use crate::db::tests::new_test_user;
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
@@ -41,6 +41,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
+ stripe_cancellation_reason: None,
})
.await
.unwrap();
@@ -75,6 +76,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
+ stripe_cancellation_reason: None,
})
.await
.unwrap();
@@ -86,3 +88,113 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
assert_eq!(subscription_count, 0);
}
}
+
+test_both_dbs!(
+ test_count_overdue_billing_subscriptions,
+ test_count_overdue_billing_subscriptions_postgres,
+ test_count_overdue_billing_subscriptions_sqlite
+);
+
+async fn test_count_overdue_billing_subscriptions(db: &Arc<Database>) {
+ // A user with no subscription has no overdue billing subscriptions.
+ {
+ let user_id = new_test_user(db, "no-subscription-user@example.com").await;
+ let subscription_count = db
+ .count_overdue_billing_subscriptions(user_id)
+ .await
+ .unwrap();
+
+ assert_eq!(subscription_count, 0);
+ }
+
+ // A user with a past-due subscription has an overdue billing subscription.
+ {
+ let user_id = new_test_user(db, "past-due-user@example.com").await;
+ let customer = db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id,
+ stripe_customer_id: "cus_past_due_user".into(),
+ })
+ .await
+ .unwrap();
+ assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
+
+ db.create_billing_subscription(&CreateBillingSubscriptionParams {
+ billing_customer_id: customer.id,
+ stripe_subscription_id: "sub_past_due_user".into(),
+ stripe_subscription_status: StripeSubscriptionStatus::PastDue,
+ stripe_cancellation_reason: None,
+ })
+ .await
+ .unwrap();
+
+ let subscription_count = db
+ .count_overdue_billing_subscriptions(user_id)
+ .await
+ .unwrap();
+ assert_eq!(subscription_count, 1);
+ }
+
+ // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription.
+ {
+ let user_id =
+ new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await;
+ let customer = db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id,
+ stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(),
+ })
+ .await
+ .unwrap();
+ assert_eq!(
+ customer.stripe_customer_id,
+ "cus_canceled_subscription_payment_failed_user".to_string()
+ );
+
+ db.create_billing_subscription(&CreateBillingSubscriptionParams {
+ billing_customer_id: customer.id,
+ stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(),
+ stripe_subscription_status: StripeSubscriptionStatus::Canceled,
+ stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed),
+ })
+ .await
+ .unwrap();
+
+ let subscription_count = db
+ .count_overdue_billing_subscriptions(user_id)
+ .await
+ .unwrap();
+ assert_eq!(subscription_count, 1);
+ }
+
+ // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions.
+ {
+ let user_id = new_test_user(db, "canceled-subscription-user@example.com").await;
+ let customer = db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id,
+ stripe_customer_id: "cus_canceled_subscription_user".into(),
+ })
+ .await
+ .unwrap();
+ assert_eq!(
+ customer.stripe_customer_id,
+ "cus_canceled_subscription_user".to_string()
+ );
+
+ db.create_billing_subscription(&CreateBillingSubscriptionParams {
+ billing_customer_id: customer.id,
+ stripe_subscription_id: "sub_canceled_subscription_user".into(),
+ stripe_subscription_status: StripeSubscriptionStatus::Canceled,
+ stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested),
+ })
+ .await
+ .unwrap();
+
+ let subscription_count = db
+ .count_overdue_billing_subscriptions(user_id)
+ .await
+ .unwrap();
+ assert_eq!(subscription_count, 0);
+ }
+}