1use std::sync::Arc;
2
3use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
4use crate::db::tests::new_test_user;
5use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
6use crate::test_both_dbs;
7
8use super::Database;
9
10test_both_dbs!(
11 test_get_active_billing_subscriptions,
12 test_get_active_billing_subscriptions_postgres,
13 test_get_active_billing_subscriptions_sqlite
14);
15
16async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
17 // A user with no subscription has no active billing subscriptions.
18 {
19 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
20 let subscription_count = db
21 .count_active_billing_subscriptions(user_id)
22 .await
23 .unwrap();
24
25 assert_eq!(subscription_count, 0);
26 }
27
28 // A user with an active subscription has one active billing subscription.
29 {
30 let user_id = new_test_user(db, "active-user@example.com").await;
31 let customer = db
32 .create_billing_customer(&CreateBillingCustomerParams {
33 user_id,
34 stripe_customer_id: "cus_active_user".into(),
35 })
36 .await
37 .unwrap();
38 assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
39
40 db.create_billing_subscription(&CreateBillingSubscriptionParams {
41 billing_customer_id: customer.id,
42 stripe_subscription_id: "sub_active_user".into(),
43 stripe_subscription_status: StripeSubscriptionStatus::Active,
44 stripe_cancellation_reason: None,
45 })
46 .await
47 .unwrap();
48
49 let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
50 assert_eq!(subscriptions.len(), 1);
51
52 let subscription = &subscriptions[0];
53 assert_eq!(
54 subscription.stripe_subscription_id,
55 "sub_active_user".to_string()
56 );
57 assert_eq!(
58 subscription.stripe_subscription_status,
59 StripeSubscriptionStatus::Active
60 );
61 }
62
63 // A user with a past-due subscription has no active billing subscriptions.
64 {
65 let user_id = new_test_user(db, "past-due-user@example.com").await;
66 let customer = db
67 .create_billing_customer(&CreateBillingCustomerParams {
68 user_id,
69 stripe_customer_id: "cus_past_due_user".into(),
70 })
71 .await
72 .unwrap();
73 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
74
75 db.create_billing_subscription(&CreateBillingSubscriptionParams {
76 billing_customer_id: customer.id,
77 stripe_subscription_id: "sub_past_due_user".into(),
78 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
79 stripe_cancellation_reason: None,
80 })
81 .await
82 .unwrap();
83
84 let subscription_count = db
85 .count_active_billing_subscriptions(user_id)
86 .await
87 .unwrap();
88 assert_eq!(subscription_count, 0);
89 }
90}
91
92test_both_dbs!(
93 test_count_overdue_billing_subscriptions,
94 test_count_overdue_billing_subscriptions_postgres,
95 test_count_overdue_billing_subscriptions_sqlite
96);
97
98async fn test_count_overdue_billing_subscriptions(db: &Arc<Database>) {
99 // A user with no subscription has no overdue billing subscriptions.
100 {
101 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
102 let subscription_count = db
103 .count_overdue_billing_subscriptions(user_id)
104 .await
105 .unwrap();
106
107 assert_eq!(subscription_count, 0);
108 }
109
110 // A user with a past-due subscription has an overdue billing subscription.
111 {
112 let user_id = new_test_user(db, "past-due-user@example.com").await;
113 let customer = db
114 .create_billing_customer(&CreateBillingCustomerParams {
115 user_id,
116 stripe_customer_id: "cus_past_due_user".into(),
117 })
118 .await
119 .unwrap();
120 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
121
122 db.create_billing_subscription(&CreateBillingSubscriptionParams {
123 billing_customer_id: customer.id,
124 stripe_subscription_id: "sub_past_due_user".into(),
125 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
126 stripe_cancellation_reason: None,
127 })
128 .await
129 .unwrap();
130
131 let subscription_count = db
132 .count_overdue_billing_subscriptions(user_id)
133 .await
134 .unwrap();
135 assert_eq!(subscription_count, 1);
136 }
137
138 // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription.
139 {
140 let user_id =
141 new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await;
142 let customer = db
143 .create_billing_customer(&CreateBillingCustomerParams {
144 user_id,
145 stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(),
146 })
147 .await
148 .unwrap();
149 assert_eq!(
150 customer.stripe_customer_id,
151 "cus_canceled_subscription_payment_failed_user".to_string()
152 );
153
154 db.create_billing_subscription(&CreateBillingSubscriptionParams {
155 billing_customer_id: customer.id,
156 stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(),
157 stripe_subscription_status: StripeSubscriptionStatus::Canceled,
158 stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed),
159 })
160 .await
161 .unwrap();
162
163 let subscription_count = db
164 .count_overdue_billing_subscriptions(user_id)
165 .await
166 .unwrap();
167 assert_eq!(subscription_count, 1);
168 }
169
170 // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions.
171 {
172 let user_id = new_test_user(db, "canceled-subscription-user@example.com").await;
173 let customer = db
174 .create_billing_customer(&CreateBillingCustomerParams {
175 user_id,
176 stripe_customer_id: "cus_canceled_subscription_user".into(),
177 })
178 .await
179 .unwrap();
180 assert_eq!(
181 customer.stripe_customer_id,
182 "cus_canceled_subscription_user".to_string()
183 );
184
185 db.create_billing_subscription(&CreateBillingSubscriptionParams {
186 billing_customer_id: customer.id,
187 stripe_subscription_id: "sub_canceled_subscription_user".into(),
188 stripe_subscription_status: StripeSubscriptionStatus::Canceled,
189 stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested),
190 })
191 .await
192 .unwrap();
193
194 let subscription_count = db
195 .count_overdue_billing_subscriptions(user_id)
196 .await
197 .unwrap();
198 assert_eq!(subscription_count, 0);
199 }
200}