1use std::sync::Arc;
2
3use crate::db::billing_subscription::StripeSubscriptionStatus;
4use crate::db::tests::new_test_user;
5use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
6use crate::test_both_dbs;
7
8use super::Database;
9
// Registers `test_get_active_billing_subscriptions` with the crate's
// `test_both_dbs!` macro. The two extra identifiers name the generated
// per-backend test functions; presumably one runs the shared body against
// Postgres and the other against SQLite (verify against the macro definition).
test_both_dbs!(
    test_get_active_billing_subscriptions,
    test_get_active_billing_subscriptions_postgres,
    test_get_active_billing_subscriptions_sqlite
);
15
16async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
17 // A user with no subscription has no active billing subscriptions.
18 {
19 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
20 let subscription_count = db
21 .count_active_billing_subscriptions(user_id)
22 .await
23 .unwrap();
24
25 assert_eq!(subscription_count, 0);
26 }
27
28 // A user with an active subscription has one active billing subscription.
29 {
30 let user_id = new_test_user(db, "active-user@example.com").await;
31 let customer = db
32 .create_billing_customer(&CreateBillingCustomerParams {
33 user_id,
34 stripe_customer_id: "cus_active_user".into(),
35 })
36 .await
37 .unwrap();
38 assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
39
40 db.create_billing_subscription(&CreateBillingSubscriptionParams {
41 billing_customer_id: customer.id,
42 stripe_subscription_id: "sub_active_user".into(),
43 stripe_subscription_status: StripeSubscriptionStatus::Active,
44 })
45 .await
46 .unwrap();
47
48 let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
49 assert_eq!(subscriptions.len(), 1);
50
51 let subscription = &subscriptions[0];
52 assert_eq!(
53 subscription.stripe_subscription_id,
54 "sub_active_user".to_string()
55 );
56 assert_eq!(
57 subscription.stripe_subscription_status,
58 StripeSubscriptionStatus::Active
59 );
60 }
61
62 // A user with a past-due subscription has no active billing subscriptions.
63 {
64 let user_id = new_test_user(db, "past-due-user@example.com").await;
65 let customer = db
66 .create_billing_customer(&CreateBillingCustomerParams {
67 user_id,
68 stripe_customer_id: "cus_past_due_user".into(),
69 })
70 .await
71 .unwrap();
72 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
73
74 db.create_billing_subscription(&CreateBillingSubscriptionParams {
75 billing_customer_id: customer.id,
76 stripe_subscription_id: "sub_past_due_user".into(),
77 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
78 })
79 .await
80 .unwrap();
81
82 let subscription_count = db
83 .count_active_billing_subscriptions(user_id)
84 .await
85 .unwrap();
86 assert_eq!(subscription_count, 0);
87 }
88}