@@ -249,29 +249,31 @@ async fn create_billing_subscription(
));
}
- if app.db.has_overdue_billing_subscriptions(user.id).await? {
- return Err(Error::http(
- StatusCode::PAYMENT_REQUIRED,
- "user has overdue billing subscriptions".into(),
- ));
+ let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
+ if let Some(existing_billing_customer) = &existing_billing_customer {
+ if existing_billing_customer.has_overdue_invoices {
+ return Err(Error::http(
+ StatusCode::PAYMENT_REQUIRED,
+ "user has overdue invoices".into(),
+ ));
+ }
}
- let customer_id =
- if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
- CustomerId::from_str(&existing_customer.stripe_customer_id)
- .context("failed to parse customer ID")?
- } else {
- let customer = Customer::create(
- &stripe_client,
- CreateCustomer {
- email: user.email_address.as_deref(),
- ..Default::default()
- },
- )
- .await?;
+ let customer_id = if let Some(existing_customer) = existing_billing_customer {
+ CustomerId::from_str(&existing_customer.stripe_customer_id)
+ .context("failed to parse customer ID")?
+ } else {
+ let customer = Customer::create(
+ &stripe_client,
+ CreateCustomer {
+ email: user.email_address.as_deref(),
+ ..Default::default()
+ },
+ )
+ .await?;
- customer.id
- };
+ customer.id
+ };
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
@@ -170,40 +170,4 @@ impl Database {
})
.await
}
-
- /// Returns whether the user has any overdue billing subscriptions.
- pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<bool> {
- Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0)
- }
-
- /// Returns the count of the overdue billing subscriptions for the user with the specified ID.
- ///
- /// This includes subscriptions:
- /// - Whose status is `past_due`
- /// - Whose status is `canceled` and the cancellation reason is `payment_failed`
- pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
- self.transaction(|tx| async move {
- let past_due = billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::PastDue);
- let payment_failed = billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::Canceled)
- .and(
- billing_subscription::Column::StripeCancellationReason
- .eq(StripeCancellationReason::PaymentFailed),
- );
-
- let count = billing_subscription::Entity::find()
- .inner_join(billing_customer::Entity)
- .filter(
- billing_customer::Column::UserId
- .eq(user_id)
- .and(past_due.or(payment_failed)),
- )
- .count(&*tx)
- .await?;
-
- Ok(count as usize)
- })
- .await
- }
}
@@ -1,6 +1,6 @@
use std::sync::Arc;
-use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
+use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::tests::new_test_user;
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
@@ -88,113 +88,3 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
assert_eq!(subscription_count, 0);
}
}
-
-test_both_dbs!(
- test_count_overdue_billing_subscriptions,
- test_count_overdue_billing_subscriptions_postgres,
- test_count_overdue_billing_subscriptions_sqlite
-);
-
-async fn test_count_overdue_billing_subscriptions(db: &Arc<Database>) {
- // A user with no subscription has no overdue billing subscriptions.
- {
- let user_id = new_test_user(db, "no-subscription-user@example.com").await;
- let subscription_count = db
- .count_overdue_billing_subscriptions(user_id)
- .await
- .unwrap();
-
- assert_eq!(subscription_count, 0);
- }
-
- // A user with a past-due subscription has an overdue billing subscription.
- {
- let user_id = new_test_user(db, "past-due-user@example.com").await;
- let customer = db
- .create_billing_customer(&CreateBillingCustomerParams {
- user_id,
- stripe_customer_id: "cus_past_due_user".into(),
- })
- .await
- .unwrap();
- assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
-
- db.create_billing_subscription(&CreateBillingSubscriptionParams {
- billing_customer_id: customer.id,
- stripe_subscription_id: "sub_past_due_user".into(),
- stripe_subscription_status: StripeSubscriptionStatus::PastDue,
- stripe_cancellation_reason: None,
- })
- .await
- .unwrap();
-
- let subscription_count = db
- .count_overdue_billing_subscriptions(user_id)
- .await
- .unwrap();
- assert_eq!(subscription_count, 1);
- }
-
- // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription.
- {
- let user_id =
- new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await;
- let customer = db
- .create_billing_customer(&CreateBillingCustomerParams {
- user_id,
- stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(),
- })
- .await
- .unwrap();
- assert_eq!(
- customer.stripe_customer_id,
- "cus_canceled_subscription_payment_failed_user".to_string()
- );
-
- db.create_billing_subscription(&CreateBillingSubscriptionParams {
- billing_customer_id: customer.id,
- stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(),
- stripe_subscription_status: StripeSubscriptionStatus::Canceled,
- stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed),
- })
- .await
- .unwrap();
-
- let subscription_count = db
- .count_overdue_billing_subscriptions(user_id)
- .await
- .unwrap();
- assert_eq!(subscription_count, 1);
- }
-
- // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions.
- {
- let user_id = new_test_user(db, "canceled-subscription-user@example.com").await;
- let customer = db
- .create_billing_customer(&CreateBillingCustomerParams {
- user_id,
- stripe_customer_id: "cus_canceled_subscription_user".into(),
- })
- .await
- .unwrap();
- assert_eq!(
- customer.stripe_customer_id,
- "cus_canceled_subscription_user".to_string()
- );
-
- db.create_billing_subscription(&CreateBillingSubscriptionParams {
- billing_customer_id: customer.id,
- stripe_subscription_id: "sub_canceled_subscription_user".into(),
- stripe_subscription_status: StripeSubscriptionStatus::Canceled,
- stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested),
- })
- .await
- .unwrap();
-
- let subscription_count = db
- .count_overdue_billing_subscriptions(user_id)
- .await
- .unwrap();
- assert_eq!(subscription_count, 0);
- }
-}