1use std::sync::Arc;
2
3use crate::db::billing_subscription::StripeSubscriptionStatus;
4use crate::db::tests::new_test_user;
5use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
6use crate::test_both_dbs;
7
8use super::Database;
9
// Registers `test_get_active_billing_subscriptions` with the `test_both_dbs!`
// harness. NOTE(review): the two trailing identifiers look like the names of
// the generated Postgres and SQLite test entry points (going by their
// `_postgres` / `_sqlite` suffixes) — confirm against the macro definition.
test_both_dbs!(
    test_get_active_billing_subscriptions,
    test_get_active_billing_subscriptions_postgres,
    test_get_active_billing_subscriptions_sqlite
);
15
16async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
17 // A user with no subscription has no active billing subscriptions.
18 {
19 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
20 let subscription_count = db
21 .count_active_billing_subscriptions(user_id)
22 .await
23 .unwrap();
24
25 assert_eq!(subscription_count, 0);
26 }
27
28 // A user with an active subscription has one active billing subscription.
29 {
30 let user_id = new_test_user(db, "active-user@example.com").await;
31 let customer = db
32 .create_billing_customer(&CreateBillingCustomerParams {
33 user_id,
34 stripe_customer_id: "cus_active_user".into(),
35 })
36 .await
37 .unwrap();
38 assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
39
40 db.create_billing_subscription(&CreateBillingSubscriptionParams {
41 billing_customer_id: customer.id,
42 stripe_subscription_id: "sub_active_user".into(),
43 stripe_subscription_status: StripeSubscriptionStatus::Active,
44 stripe_cancellation_reason: None,
45 })
46 .await
47 .unwrap();
48
49 let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
50 assert_eq!(subscriptions.len(), 1);
51
52 let subscription = &subscriptions[0];
53 assert_eq!(
54 subscription.stripe_subscription_id,
55 "sub_active_user".to_string()
56 );
57 assert_eq!(
58 subscription.stripe_subscription_status,
59 StripeSubscriptionStatus::Active
60 );
61 }
62
63 // A user with a past-due subscription has no active billing subscriptions.
64 {
65 let user_id = new_test_user(db, "past-due-user@example.com").await;
66 let customer = db
67 .create_billing_customer(&CreateBillingCustomerParams {
68 user_id,
69 stripe_customer_id: "cus_past_due_user".into(),
70 })
71 .await
72 .unwrap();
73 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
74
75 db.create_billing_subscription(&CreateBillingSubscriptionParams {
76 billing_customer_id: customer.id,
77 stripe_subscription_id: "sub_past_due_user".into(),
78 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
79 stripe_cancellation_reason: None,
80 })
81 .await
82 .unwrap();
83
84 let subscription_count = db
85 .count_active_billing_subscriptions(user_id)
86 .await
87 .unwrap();
88 assert_eq!(subscription_count, 0);
89 }
90}